diff --git a/packages/event-handler/src/rest/constants.ts b/packages/event-handler/src/rest/constants.ts index b647b78fb9..ae9a80103f 100644 --- a/packages/event-handler/src/rest/constants.ts +++ b/packages/event-handler/src/rest/constants.ts @@ -88,6 +88,23 @@ export const SAFE_CHARS = "-._~()'!*:@,;=+&$"; export const UNSAFE_CHARS = '%<> \\[\\]{}|^'; +/** + * Default CORS configuration + */ +export const DEFAULT_CORS_OPTIONS = { + origin: '*', + allowMethods: ['DELETE', 'GET', 'HEAD', 'PATCH', 'POST', 'PUT'], + allowHeaders: [ + 'Authorization', + 'Content-Type', + 'X-Amz-Date', + 'X-Api-Key', + 'X-Amz-Security-Token', + ], + exposeHeaders: [], + credentials: false +}; + export const DEFAULT_COMPRESSION_RESPONSE_THRESHOLD = 1024; export const CACHE_CONTROL_NO_TRANSFORM_REGEX = diff --git a/packages/event-handler/src/rest/middleware/cors.ts b/packages/event-handler/src/rest/middleware/cors.ts new file mode 100644 index 0000000000..6809a1d2d4 --- /dev/null +++ b/packages/event-handler/src/rest/middleware/cors.ts @@ -0,0 +1,98 @@ +import type { + CorsOptions, + Middleware, +} from '../../types/rest.js'; +import { + DEFAULT_CORS_OPTIONS, + HttpErrorCodes, + HttpVerbs, +} from '../constants.js'; + +/** + * Resolves the origin value based on the configuration + */ +const resolveOrigin = ( + originConfig: NonNullable, + requestOrigin: string | null, +): string => { + if (Array.isArray(originConfig)) { + return requestOrigin && originConfig.includes(requestOrigin) ? requestOrigin : ''; + } + return originConfig; +}; + +/** + * Creates a CORS middleware that adds appropriate CORS headers to responses + * and handles OPTIONS preflight requests. + * + * @example + * ```typescript + * import { Router } from '@aws-lambda-powertools/event-handler/experimental-rest'; + * import { cors } from '@aws-lambda-powertools/event-handler/experimental-rest/middleware'; + * + * const app = new Router(); + * + * // Use default configuration + * app.use(cors()); + * + * // Custom configuration + * app.use(cors({ + * origin: 'https://example.com', + * allowMethods: ['GET', 'POST'], + * credentials: true, + * })); + * + * // Dynamic origin with function + * app.use(cors({ + * origin: (origin, reqCtx) => { + * const allowedOrigins = ['https://app.com', 'https://admin.app.com']; + * return origin && allowedOrigins.includes(origin); + * } + * })); + * ``` + * + * @param options.origin - The origin to allow requests from + * @param options.allowMethods - The HTTP methods to allow + * @param options.allowHeaders - The headers to allow + * @param options.exposeHeaders - The headers to expose + * @param options.credentials - Whether to allow credentials + * @param options.maxAge - The maximum age for the preflight response + */ +export const cors = (options?: CorsOptions): Middleware => { + const config = { + ...DEFAULT_CORS_OPTIONS, + ...options + }; + + return async (_params, reqCtx, next) => { + const requestOrigin = reqCtx.request.headers.get('Origin'); + const resolvedOrigin = resolveOrigin(config.origin, requestOrigin); + + reqCtx.res.headers.set('access-control-allow-origin', resolvedOrigin); + if (resolvedOrigin !== '*') { + reqCtx.res.headers.set('Vary', 'Origin'); + } + config.allowMethods.forEach(method => { + reqCtx.res.headers.append('access-control-allow-methods', method); + }); + config.allowHeaders.forEach(header => { + reqCtx.res.headers.append('access-control-allow-headers', header); + }); + config.exposeHeaders.forEach(header => { + reqCtx.res.headers.append('access-control-expose-headers', header); + }); + reqCtx.res.headers.set('access-control-allow-credentials', config.credentials.toString()); + if (config.maxAge !== undefined) { + reqCtx.res.headers.set('access-control-max-age', config.maxAge.toString()); + } + + // Handle preflight OPTIONS request + if (reqCtx.request.method === HttpVerbs.OPTIONS && reqCtx.request.headers.has('Access-Control-Request-Method')) { + return new Response(null, { + status: HttpErrorCodes.NO_CONTENT, + headers: reqCtx.res.headers, + }); + } + await next(); + }; +}; diff --git a/packages/event-handler/src/rest/middleware/index.ts b/packages/event-handler/src/rest/middleware/index.ts index 2b65a8ee8e..96686c578b 100644 --- a/packages/event-handler/src/rest/middleware/index.ts +++ b/packages/event-handler/src/rest/middleware/index.ts @@ -1 +1,2 @@ export { compress } from './compress.js'; +export { cors } from './cors.js'; diff --git a/packages/event-handler/src/types/index.ts b/packages/event-handler/src/types/index.ts index 4625fcf1fe..56e017c563 100644 --- a/packages/event-handler/src/types/index.ts +++ b/packages/event-handler/src/types/index.ts @@ -32,6 +32,7 @@ export type { } from './common.js'; export type { + CorsOptions, ErrorHandler, ErrorResolveOptions, ErrorResponse, diff --git a/packages/event-handler/src/types/rest.ts b/packages/event-handler/src/types/rest.ts index 0ba650c0dc..e4924d1d2c 100644 --- a/packages/event-handler/src/types/rest.ts +++ b/packages/event-handler/src/types/rest.ts @@ -111,6 +111,48 @@ type ValidationResult = { issues: string[]; }; +/** + * Configuration options for CORS middleware + */ +type CorsOptions = { + /** + * The Access-Control-Allow-Origin header value. + * Can be a string, array of strings. + * @default '*' + */ + origin?: string | string[]; + + /** + * The Access-Control-Allow-Methods header value. + * @default ['DELETE', 'GET', 'HEAD', 'PATCH', 'POST', 'PUT'] + */ + allowMethods?: string[]; + + /** + * The Access-Control-Allow-Headers header value. + * @default ['Authorization', 'Content-Type', 'X-Amz-Date', 'X-Api-Key', 'X-Amz-Security-Token'] + */ + allowHeaders?: string[]; + + /** + * The Access-Control-Expose-Headers header value. + * @default [] + */ + exposeHeaders?: string[]; + + /** + * The Access-Control-Allow-Credentials header value. + * @default false + */ + credentials?: boolean; + + /** + * The Access-Control-Max-Age header value in seconds. + * Only applicable for preflight requests. + */ + maxAge?: number; +}; + type CompressionOptions = { encoding?: 'gzip' | 'deflate'; threshold?: number; @@ -118,6 +160,7 @@ type CompressionOptions = { export type { CompiledRoute, + CorsOptions, DynamicRoute, ErrorResponse, ErrorConstructor, diff --git a/packages/event-handler/tests/unit/rest/helpers.ts b/packages/event-handler/tests/unit/rest/helpers.ts index 93255f79c1..6edd2c4678 100644 --- a/packages/event-handler/tests/unit/rest/helpers.ts +++ b/packages/event-handler/tests/unit/rest/helpers.ts @@ -77,3 +77,14 @@ export const createSettingHeadersMiddleware = (headers: { }); }; }; + +export const createHeaderCheckMiddleware = (headers: { + [key: string]: string; +}): Middleware => { + return async (_params, options, next) => { + options.res.headers.forEach((value, key) => { + headers[key] = value; + }); + await next(); + }; +}; \ No newline at end of file diff --git a/packages/event-handler/tests/unit/rest/middleware/cors.test.ts b/packages/event-handler/tests/unit/rest/middleware/cors.test.ts new file mode 100644 index 0000000000..989e875d31 --- /dev/null +++ b/packages/event-handler/tests/unit/rest/middleware/cors.test.ts @@ -0,0 +1,111 @@ +import { beforeEach, describe, expect, it } from 'vitest'; +import context from '@aws-lambda-powertools/testing-utils/context'; +import { cors } from '../../../../src/rest/middleware/cors.js'; +import { createTestEvent, createHeaderCheckMiddleware } from '../helpers.js'; +import { Router } from '../../../../src/rest/Router.js'; +import { DEFAULT_CORS_OPTIONS } from 'src/rest/constants.js'; + +describe('CORS Middleware', () => { + const getRequestEvent = createTestEvent('/test', 'GET'); + const optionsRequestEvent = createTestEvent('/test', 'OPTIONS'); + let app: Router; + + const customCorsOptions = { + origin: 'https://example.com', + allowMethods: ['GET', 'POST'], + allowHeaders: ['Authorization', 'Content-Type'], + credentials: true, + exposeHeaders: ['Authorization', 'X-Custom-Header'], + maxAge: 86400, + }; + + const expectedDefaultHeaders = { + "access-control-allow-credentials": "false", + "access-control-allow-headers": "Authorization, Content-Type, X-Amz-Date, X-Api-Key, X-Amz-Security-Token", + "access-control-allow-methods": "DELETE, GET, HEAD, PATCH, POST, PUT", + "access-control-allow-origin": "*", + }; + + beforeEach(() => { + app = new Router(); + app.use(cors()); + }); + + it('uses default configuration when no options are provided', async () => { + // Prepare + const corsHeaders: { [key: string]: string } = {}; + app.get('/test', [createHeaderCheckMiddleware(corsHeaders)], async () => ({ success: true })); + + // Act + const result = await app.resolve(getRequestEvent, context); + + // Assess + expect(result.headers?.['access-control-allow-origin']).toEqual(DEFAULT_CORS_OPTIONS.origin); + expect(result.multiValueHeaders?.['access-control-allow-methods']).toEqual(DEFAULT_CORS_OPTIONS.allowMethods); + expect(result.multiValueHeaders?.['access-control-allow-headers']).toEqual(DEFAULT_CORS_OPTIONS.allowHeaders); + expect(result.headers?.['access-control-allow-credentials']).toEqual(DEFAULT_CORS_OPTIONS.credentials.toString()); + expect(corsHeaders).toMatchObject(expectedDefaultHeaders); + }); + + it('merges user options with defaults', async () => { + // Prepare + const corsHeaders: { [key: string]: string } = {}; + const app = new Router(); + app.get('/test', [cors(customCorsOptions), createHeaderCheckMiddleware(corsHeaders)], async () => ({ success: true })); + + // Act + const result = await app.resolve(getRequestEvent, context); + + // Assess + expect(result.headers?.['access-control-allow-origin']).toEqual('https://example.com'); + expect(result.multiValueHeaders?.['access-control-allow-methods']).toEqual(['GET', 'POST']); + expect(result.multiValueHeaders?.['access-control-allow-headers']).toEqual(['Authorization', 'Content-Type']); + expect(result.headers?.['access-control-allow-credentials']).toEqual('true'); + expect(result.multiValueHeaders?.['access-control-expose-headers']).toEqual(['Authorization', 'X-Custom-Header']); + expect(result.headers?.['access-control-max-age']).toEqual('86400'); + expect(corsHeaders).toMatchObject({ + "access-control-allow-credentials": "true", + "access-control-allow-headers": "Authorization, Content-Type", + "access-control-allow-methods": "GET, POST", + "access-control-allow-origin": "https://example.com", + }); + }); + + it.each([ + ['matching', 'https://app.com', 'https://app.com'], + ['non-matching', 'https://non-matching.com', ''] + ])('handles array origin with %s request', async (_, origin, expected) => { + // Prepare + const app = new Router(); + app.get('/test', [cors({ origin: ['https://app.com', 'https://admin.app.com'] })], async () => ({ success: true })); + + // Act + const result = await app.resolve(createTestEvent('/test', 'GET', { 'Origin': origin }), context); + + // Assess + expect(result.headers?.['access-control-allow-origin']).toEqual(expected); + }); + + it('handles OPTIONS preflight requests', async () => { + // Prepare + app.options('/test', async () => ({ foo: 'bar' })); + + // Act + const result = await app.resolve(createTestEvent('/test', 'OPTIONS', { 'Access-Control-Request-Method': 'GET' }), context); + + // Assess + expect(result.statusCode).toBe(204); + }); + + it('calls the next middleware if the Access-Control-Request-Method is not present', async () => { + // Prepare + const corsHeaders: { [key: string]: string } = {}; + app.options('/test', [createHeaderCheckMiddleware(corsHeaders)], async () => ({ success: true })); + + // Act + await app.resolve(optionsRequestEvent, context); + + // Assess + expect(corsHeaders).toMatchObject(expectedDefaultHeaders); + }); +});