Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 79 additions & 40 deletions packages/event-handler/src/rest/middleware/cors.ts
Original file line number Diff line number Diff line change
@@ -1,26 +1,10 @@
import type {
CorsOptions,
Middleware,
} from '../../types/rest.js';
import type { CorsOptions, Middleware } from '../../types/rest.js';
import {
DEFAULT_CORS_OPTIONS,
HttpErrorCodes,
HttpVerbs,
} from '../constants.js';

/**
* Resolves the origin value based on the configuration
*/
const resolveOrigin = (
originConfig: NonNullable<CorsOptions['origin']>,
requestOrigin: string | null,
): string => {
if (Array.isArray(originConfig)) {
return requestOrigin && originConfig.includes(requestOrigin) ? requestOrigin : '';
}
return originConfig;
};

/**
* Creates a CORS middleware that adds appropriate CORS headers to responses
* and handles OPTIONS preflight requests.
Expand All @@ -29,9 +13,9 @@ const resolveOrigin = (
* ```typescript
* import { Router } from '@aws-lambda-powertools/event-handler/experimental-rest';
* import { cors } from '@aws-lambda-powertools/event-handler/experimental-rest/middleware';
*
*
* const app = new Router();
*
*
* // Use default configuration
* app.use(cors());
*
Expand All @@ -50,7 +34,7 @@ const resolveOrigin = (
* }
* }));
* ```
*
*
* @param options.origin - The origin to allow requests from
* @param options.allowMethods - The HTTP methods to allow
* @param options.allowHeaders - The headers to allow
Expand All @@ -61,38 +45,93 @@ const resolveOrigin = (
export const cors = (options?: CorsOptions): Middleware => {
const config = {
...DEFAULT_CORS_OPTIONS,
...options
...options,
};
const allowedOrigins =
typeof config.origin === 'string' ? [config.origin] : config.origin;
const allowsWildcard = allowedOrigins.includes('*');
const allowedMethods = config.allowMethods.map((method) =>
method.toUpperCase()
);
const allowedHeaders = config.allowHeaders.map((header) =>
header.toLowerCase()
);

return async (_params, reqCtx, next) => {
const requestOrigin = reqCtx.request.headers.get('Origin');
const resolvedOrigin = resolveOrigin(config.origin, requestOrigin);
const isOriginAllowed = (
requestOrigin: string | null
): requestOrigin is string => {
return (
requestOrigin !== null &&
(allowsWildcard || allowedOrigins.includes(requestOrigin))
);
};

reqCtx.res.headers.set('access-control-allow-origin', resolvedOrigin);
if (resolvedOrigin !== '*') {
reqCtx.res.headers.set('Vary', 'Origin');
const isValidPreflightRequest = (requestHeaders: Headers) => {
const accessControlRequestMethod = requestHeaders
.get('Access-Control-Request-Method')
?.toUpperCase();
const accessControlRequestHeaders = requestHeaders
.get('Access-Control-Request-Headers')
?.toLowerCase();
return (
accessControlRequestMethod &&
allowedMethods.includes(accessControlRequestMethod) &&
accessControlRequestHeaders
?.split(',')
.every((header) => allowedHeaders.includes(header.trim()))
);
};

const setCORSBaseHeaders = (
requestOrigin: string,
responseHeaders: Headers
) => {
const resolvedOrigin = allowsWildcard ? '*' : requestOrigin;
responseHeaders.set('access-control-allow-origin', resolvedOrigin);
if (!allowsWildcard && Array.isArray(config.origin)) {
responseHeaders.set('vary', 'Origin');
}
config.allowMethods.forEach(method => {
reqCtx.res.headers.append('access-control-allow-methods', method);
});
config.allowHeaders.forEach(header => {
reqCtx.res.headers.append('access-control-allow-headers', header);
});
config.exposeHeaders.forEach(header => {
reqCtx.res.headers.append('access-control-expose-headers', header);
});
reqCtx.res.headers.set('access-control-allow-credentials', config.credentials.toString());
if (config.maxAge !== undefined) {
reqCtx.res.headers.set('access-control-max-age', config.maxAge.toString());
if (config.credentials) {
responseHeaders.set('access-control-allow-credentials', 'true');
}
};

return async (_params, reqCtx, next) => {
const requestOrigin = reqCtx.request.headers.get('Origin');
if (!isOriginAllowed(requestOrigin)) {
await next();
return;
}

// Handle preflight OPTIONS request
if (reqCtx.request.method === HttpVerbs.OPTIONS && reqCtx.request.headers.has('Access-Control-Request-Method')) {
if (reqCtx.request.method === HttpVerbs.OPTIONS) {
if (!isValidPreflightRequest(reqCtx.request.headers)) {
await next();
return;
}
setCORSBaseHeaders(requestOrigin, reqCtx.res.headers);
if (config.maxAge !== undefined) {
reqCtx.res.headers.set(
'access-control-max-age',
config.maxAge.toString()
);
}
for (const method of allowedMethods) {
reqCtx.res.headers.append('access-control-allow-methods', method);
}
for (const header of allowedHeaders) {
reqCtx.res.headers.append('access-control-allow-headers', header);
}
return new Response(null, {
status: HttpErrorCodes.NO_CONTENT,
headers: reqCtx.res.headers,
});
}

setCORSBaseHeaders(requestOrigin, reqCtx.res.headers);
for (const header of config.exposeHeaders) {
reqCtx.res.headers.append('access-control-expose-headers', header);
}
await next();
};
};
Loading