diff --git a/packages/event-handler/src/rest/middleware/cors.ts b/packages/event-handler/src/rest/middleware/cors.ts index 6809a1d2d4..983f805530 100644 --- a/packages/event-handler/src/rest/middleware/cors.ts +++ b/packages/event-handler/src/rest/middleware/cors.ts @@ -1,26 +1,10 @@ -import type { - CorsOptions, - Middleware, -} from '../../types/rest.js'; +import type { CorsOptions, Middleware } from '../../types/rest.js'; import { DEFAULT_CORS_OPTIONS, HttpErrorCodes, HttpVerbs, } from '../constants.js'; -/** - * Resolves the origin value based on the configuration - */ -const resolveOrigin = ( - originConfig: NonNullable, - requestOrigin: string | null, -): string => { - if (Array.isArray(originConfig)) { - return requestOrigin && originConfig.includes(requestOrigin) ? requestOrigin : ''; - } - return originConfig; -}; - /** * Creates a CORS middleware that adds appropriate CORS headers to responses * and handles OPTIONS preflight requests. @@ -29,9 +13,9 @@ const resolveOrigin = ( * ```typescript * import { Router } from '@aws-lambda-powertools/event-handler/experimental-rest'; * import { cors } from '@aws-lambda-powertools/event-handler/experimental-rest/middleware'; - * + * * const app = new Router(); - * + * * // Use default configuration * app.use(cors()); * @@ -50,7 +34,7 @@ const resolveOrigin = ( * } * })); * ``` - * + * * @param options.origin - The origin to allow requests from * @param options.allowMethods - The HTTP methods to allow * @param options.allowHeaders - The headers to allow @@ -61,38 +45,93 @@ const resolveOrigin = ( export const cors = (options?: CorsOptions): Middleware => { const config = { ...DEFAULT_CORS_OPTIONS, - ...options + ...options, }; + const allowedOrigins = + typeof config.origin === 'string' ? [config.origin] : config.origin; + const allowsWildcard = allowedOrigins.includes('*'); + const allowedMethods = config.allowMethods.map((method) => + method.toUpperCase() + ); + const allowedHeaders = config.allowHeaders.map((header) => + header.toLowerCase() + ); - return async (_params, reqCtx, next) => { - const requestOrigin = reqCtx.request.headers.get('Origin'); - const resolvedOrigin = resolveOrigin(config.origin, requestOrigin); + const isOriginAllowed = ( + requestOrigin: string | null + ): requestOrigin is string => { + return ( + requestOrigin !== null && + (allowsWildcard || allowedOrigins.includes(requestOrigin)) + ); + }; - reqCtx.res.headers.set('access-control-allow-origin', resolvedOrigin); - if (resolvedOrigin !== '*') { - reqCtx.res.headers.set('Vary', 'Origin'); + const isValidPreflightRequest = (requestHeaders: Headers) => { + const accessControlRequestMethod = requestHeaders + .get('Access-Control-Request-Method') + ?.toUpperCase(); + const accessControlRequestHeaders = requestHeaders + .get('Access-Control-Request-Headers') + ?.toLowerCase(); + return ( + accessControlRequestMethod && + allowedMethods.includes(accessControlRequestMethod) && + accessControlRequestHeaders + ?.split(',') + .every((header) => allowedHeaders.includes(header.trim())) + ); + }; + + const setCORSBaseHeaders = ( + requestOrigin: string, + responseHeaders: Headers + ) => { + const resolvedOrigin = allowsWildcard ? '*' : requestOrigin; + responseHeaders.set('access-control-allow-origin', resolvedOrigin); + if (!allowsWildcard && Array.isArray(config.origin)) { + responseHeaders.set('vary', 'Origin'); } - config.allowMethods.forEach(method => { - reqCtx.res.headers.append('access-control-allow-methods', method); - }); - config.allowHeaders.forEach(header => { - reqCtx.res.headers.append('access-control-allow-headers', header); - }); - config.exposeHeaders.forEach(header => { - reqCtx.res.headers.append('access-control-expose-headers', header); - }); - reqCtx.res.headers.set('access-control-allow-credentials', config.credentials.toString()); - if (config.maxAge !== undefined) { - reqCtx.res.headers.set('access-control-max-age', config.maxAge.toString()); + if (config.credentials) { + responseHeaders.set('access-control-allow-credentials', 'true'); + } + }; + + return async (_params, reqCtx, next) => { + const requestOrigin = reqCtx.request.headers.get('Origin'); + if (!isOriginAllowed(requestOrigin)) { + await next(); + return; } // Handle preflight OPTIONS request - if (reqCtx.request.method === HttpVerbs.OPTIONS && reqCtx.request.headers.has('Access-Control-Request-Method')) { + if (reqCtx.request.method === HttpVerbs.OPTIONS) { + if (!isValidPreflightRequest(reqCtx.request.headers)) { + await next(); + return; + } + setCORSBaseHeaders(requestOrigin, reqCtx.res.headers); + if (config.maxAge !== undefined) { + reqCtx.res.headers.set( + 'access-control-max-age', + config.maxAge.toString() + ); + } + for (const method of allowedMethods) { + reqCtx.res.headers.append('access-control-allow-methods', method); + } + for (const header of allowedHeaders) { + reqCtx.res.headers.append('access-control-allow-headers', header); + } return new Response(null, { status: HttpErrorCodes.NO_CONTENT, headers: reqCtx.res.headers, }); } + + setCORSBaseHeaders(requestOrigin, reqCtx.res.headers); + for (const header of config.exposeHeaders) { + reqCtx.res.headers.append('access-control-expose-headers', header); + } await next(); }; }; diff --git a/packages/event-handler/tests/unit/rest/middleware/cors.test.ts b/packages/event-handler/tests/unit/rest/middleware/cors.test.ts index 989e875d31..e907eef410 100644 --- a/packages/event-handler/tests/unit/rest/middleware/cors.test.ts +++ b/packages/event-handler/tests/unit/rest/middleware/cors.test.ts @@ -1,111 +1,206 @@ -import { beforeEach, describe, expect, it } from 'vitest'; import context from '@aws-lambda-powertools/testing-utils/context'; +import { beforeEach, describe, expect, it } from 'vitest'; +import { DEFAULT_CORS_OPTIONS } from '../../../../src/rest/constants.js'; import { cors } from '../../../../src/rest/middleware/cors.js'; -import { createTestEvent, createHeaderCheckMiddleware } from '../helpers.js'; import { Router } from '../../../../src/rest/Router.js'; -import { DEFAULT_CORS_OPTIONS } from 'src/rest/constants.js'; +import { createHeaderCheckMiddleware, createTestEvent } from '../helpers.js'; describe('CORS Middleware', () => { - const getRequestEvent = createTestEvent('/test', 'GET'); - const optionsRequestEvent = createTestEvent('/test', 'OPTIONS'); + const origin = 'https://example.com'; + const getRequestEvent = createTestEvent('/test', 'GET', { + Origin: origin, + }); let app: Router; - const customCorsOptions = { - origin: 'https://example.com', - allowMethods: ['GET', 'POST'], - allowHeaders: ['Authorization', 'Content-Type'], - credentials: true, - exposeHeaders: ['Authorization', 'X-Custom-Header'], - maxAge: 86400, - }; - - const expectedDefaultHeaders = { - "access-control-allow-credentials": "false", - "access-control-allow-headers": "Authorization, Content-Type, X-Amz-Date, X-Api-Key, X-Amz-Security-Token", - "access-control-allow-methods": "DELETE, GET, HEAD, PATCH, POST, PUT", - "access-control-allow-origin": "*", - }; - beforeEach(() => { app = new Router(); app.use(cors()); }); - it('uses default configuration when no options are provided', async () => { + it('does not set CORS headers when request has no origin header', async () => { + // Prepare + app.get('/test', async () => ({ success: true })); + + // Act + const result = await app.resolve(createTestEvent('/test', 'GET'), context); + + // Assess + expect(result.headers?.['access-control-allow-origin']).toBeUndefined(); + }); + + it('does not set CORS headers when request origin does not match with allowed origin', async () => { + // Prepare + const app = new Router(); + app.get( + '/test', + [ + cors({ + origin: 'https://another-origin.com', + }), + ], + async () => ({ success: true }) + ); + + // Act + const result = await app.resolve(getRequestEvent, context); + + // Assess + expect(result.headers?.['access-control-allow-origin']).toBeUndefined(); + }); + + it('uses default CORS configuration when no options are provided', async () => { // Prepare const corsHeaders: { [key: string]: string } = {}; - app.get('/test', [createHeaderCheckMiddleware(corsHeaders)], async () => ({ success: true })); + app.get('/test', [createHeaderCheckMiddleware(corsHeaders)], async () => ({ + success: true, + })); // Act const result = await app.resolve(getRequestEvent, context); // Assess - expect(result.headers?.['access-control-allow-origin']).toEqual(DEFAULT_CORS_OPTIONS.origin); - expect(result.multiValueHeaders?.['access-control-allow-methods']).toEqual(DEFAULT_CORS_OPTIONS.allowMethods); - expect(result.multiValueHeaders?.['access-control-allow-headers']).toEqual(DEFAULT_CORS_OPTIONS.allowHeaders); - expect(result.headers?.['access-control-allow-credentials']).toEqual(DEFAULT_CORS_OPTIONS.credentials.toString()); - expect(corsHeaders).toMatchObject(expectedDefaultHeaders); + expect(result.headers?.['access-control-allow-origin']).toEqual( + DEFAULT_CORS_OPTIONS.origin + ); + expect(corsHeaders['access-control-allow-origin']).toEqual( + DEFAULT_CORS_OPTIONS.origin + ); }); - it('merges user options with defaults', async () => { + it('uses custom CORS configuration when provided', async () => { // Prepare const corsHeaders: { [key: string]: string } = {}; const app = new Router(); - app.get('/test', [cors(customCorsOptions), createHeaderCheckMiddleware(corsHeaders)], async () => ({ success: true })); + const customConfig = { + origin, + credentials: true, + exposeHeaders: ['Authorization', 'X-Custom-Header'], + }; + app.get( + '/test', + [cors(customConfig), createHeaderCheckMiddleware(corsHeaders)], + async () => ({ success: true }) + ); // Act const result = await app.resolve(getRequestEvent, context); // Assess - expect(result.headers?.['access-control-allow-origin']).toEqual('https://example.com'); - expect(result.multiValueHeaders?.['access-control-allow-methods']).toEqual(['GET', 'POST']); - expect(result.multiValueHeaders?.['access-control-allow-headers']).toEqual(['Authorization', 'Content-Type']); - expect(result.headers?.['access-control-allow-credentials']).toEqual('true'); - expect(result.multiValueHeaders?.['access-control-expose-headers']).toEqual(['Authorization', 'X-Custom-Header']); - expect(result.headers?.['access-control-max-age']).toEqual('86400'); - expect(corsHeaders).toMatchObject({ - "access-control-allow-credentials": "true", - "access-control-allow-headers": "Authorization, Content-Type", - "access-control-allow-methods": "GET, POST", - "access-control-allow-origin": "https://example.com", - }); + expect(result.headers?.['access-control-allow-origin']).toEqual(origin); + expect(result.headers?.['access-control-allow-credentials']).toEqual( + customConfig.credentials.toString() + ); + expect(result.multiValueHeaders?.['access-control-expose-headers']).toEqual( + customConfig.exposeHeaders + ); + expect(corsHeaders['access-control-allow-origin']).toEqual(origin); + expect(corsHeaders['access-control-allow-credentials']).toEqual( + customConfig.credentials.toString() + ); + expect(corsHeaders['access-control-expose-headers']).toEqual( + customConfig.exposeHeaders.join(', ') + ); }); - it.each([ - ['matching', 'https://app.com', 'https://app.com'], - ['non-matching', 'https://non-matching.com', ''] - ])('handles array origin with %s request', async (_, origin, expected) => { + it('sets the vary header if the response is dynamic based on origin', async () => { // Prepare const app = new Router(); - app.get('/test', [cors({ origin: ['https://app.com', 'https://admin.app.com'] })], async () => ({ success: true })); + app.get( + '/test', + [ + cors({ + origin: ['https://example.com', 'https://another-example.com'], + }), + ], + async () => ({ success: true }) + ); // Act - const result = await app.resolve(createTestEvent('/test', 'GET', { 'Origin': origin }), context); + const result = await app.resolve(getRequestEvent, context); // Assess - expect(result.headers?.['access-control-allow-origin']).toEqual(expected); + expect(result.headers?.['access-control-allow-origin']).toEqual(origin); + expect(result.headers?.vary).toEqual('Origin'); }); - it('handles OPTIONS preflight requests', async () => { + it('does not set CORS headers when preflight request method does not match allowed method', async () => { // Prepare - app.options('/test', async () => ({ foo: 'bar' })); + const app = new Router(); + app.use( + cors({ + allowMethods: ['POST'], + }) + ); // Act - const result = await app.resolve(createTestEvent('/test', 'OPTIONS', { 'Access-Control-Request-Method': 'GET' }), context); + const result = await app.resolve( + createTestEvent('/test', 'OPTIONS', { + Origin: origin, + 'Access-Control-Request-Method': 'GET', + }), + context + ); // Assess - expect(result.statusCode).toBe(204); + expect(result.headers?.['access-control-allow-origin']).toBeUndefined(); }); - it('calls the next middleware if the Access-Control-Request-Method is not present', async () => { + it('does not set CORS headers when preflight request header does not match allowed header', async () => { // Prepare - const corsHeaders: { [key: string]: string } = {}; - app.options('/test', [createHeaderCheckMiddleware(corsHeaders)], async () => ({ success: true })); + const app = new Router(); + app.use( + cors({ + allowHeaders: ['Content-Type'], + }) + ); + + // Act + const result = await app.resolve( + createTestEvent('/test', 'OPTIONS', { + Origin: origin, + 'Access-Control-Request-Header': 'x-test-header', + }), + context + ); + + // Assess + expect(result.headers?.['access-control-allow-origin']).toBeUndefined(); + }); + + it('handles OPTIONS preflight requests', async () => { + // Prepare + const app = new Router(); + const corsConfig = { + origin, + allowMethods: ['GET', 'POST'], + allowHeaders: ['Authorization', 'Content-Type'], + maxAge: 3600, + }; + app.use(cors(corsConfig)); // Act - await app.resolve(optionsRequestEvent, context); + const result = await app.resolve( + createTestEvent('/test', 'OPTIONS', { + Origin: origin, + 'Access-Control-Request-Method': 'GET', + 'Access-Control-Request-Headers': 'Authorization', + }), + context + ); // Assess - expect(corsHeaders).toMatchObject(expectedDefaultHeaders); + expect(result.statusCode).toBe(204); + expect(result.headers?.['access-control-allow-origin']).toEqual( + corsConfig.origin + ); + expect(result.multiValueHeaders?.['access-control-allow-methods']).toEqual( + corsConfig.allowMethods + ); + expect(result.multiValueHeaders?.['access-control-allow-headers']).toEqual( + corsConfig.allowHeaders.map((header) => header.toLowerCase()) + ); + expect(result.headers?.['access-control-max-age']).toEqual( + corsConfig.maxAge.toString() + ); }); });