Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions packages/event-handler/src/rest/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,23 @@ export const SAFE_CHARS = "-._~()'!*:@,;=+&$";

export const UNSAFE_CHARS = '%<> \\[\\]{}|^';

/**
* Default CORS configuration
*/
export const DEFAULT_CORS_OPTIONS = {
origin: '*',
allowMethods: ['DELETE', 'GET', 'HEAD', 'PATCH', 'POST', 'PUT'],
allowHeaders: [
'Authorization',
'Content-Type',
'X-Amz-Date',
'X-Api-Key',
'X-Amz-Security-Token',
],
exposeHeaders: [],
credentials: false
};

export const DEFAULT_COMPRESSION_RESPONSE_THRESHOLD = 1024;

export const CACHE_CONTROL_NO_TRANSFORM_REGEX =
Expand Down
98 changes: 98 additions & 0 deletions packages/event-handler/src/rest/middleware/cors.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import type {
CorsOptions,
Middleware,
} from '../../types/rest.js';
import {
DEFAULT_CORS_OPTIONS,
HttpErrorCodes,
HttpVerbs,
} from '../constants.js';

/**
* Resolves the origin value based on the configuration
*/
const resolveOrigin = (
originConfig: NonNullable<CorsOptions['origin']>,
requestOrigin: string | null,
): string => {
if (Array.isArray(originConfig)) {
return requestOrigin && originConfig.includes(requestOrigin) ? requestOrigin : '';
}
return originConfig;
};

/**
* Creates a CORS middleware that adds appropriate CORS headers to responses
* and handles OPTIONS preflight requests.
*
* @example
* ```typescript
* import { Router } from '@aws-lambda-powertools/event-handler/experimental-rest';
* import { cors } from '@aws-lambda-powertools/event-handler/experimental-rest/middleware';
*
* const app = new Router();
*
* // Use default configuration
* app.use(cors());
*
* // Custom configuration
* app.use(cors({
* origin: 'https://example.com',
* allowMethods: ['GET', 'POST'],
* credentials: true,
* }));
*
* // Dynamic origin with function
* app.use(cors({
* origin: (origin, reqCtx) => {
* const allowedOrigins = ['https://app.com', 'https://admin.app.com'];
* return origin && allowedOrigins.includes(origin);
* }
* }));
* ```
*
* @param options.origin - The origin to allow requests from
* @param options.allowMethods - The HTTP methods to allow
* @param options.allowHeaders - The headers to allow
* @param options.exposeHeaders - The headers to expose
* @param options.credentials - Whether to allow credentials
* @param options.maxAge - The maximum age for the preflight response
*/
export const cors = (options?: CorsOptions): Middleware => {
const config = {
...DEFAULT_CORS_OPTIONS,
...options
};

return async (_params, reqCtx, next) => {
const requestOrigin = reqCtx.request.headers.get('Origin');
const resolvedOrigin = resolveOrigin(config.origin, requestOrigin);

reqCtx.res.headers.set('access-control-allow-origin', resolvedOrigin);
if (resolvedOrigin !== '*') {
reqCtx.res.headers.set('Vary', 'Origin');
}
config.allowMethods.forEach(method => {
reqCtx.res.headers.append('access-control-allow-methods', method);
});
config.allowHeaders.forEach(header => {
reqCtx.res.headers.append('access-control-allow-headers', header);
});
config.exposeHeaders.forEach(header => {
reqCtx.res.headers.append('access-control-expose-headers', header);
});
reqCtx.res.headers.set('access-control-allow-credentials', config.credentials.toString());
if (config.maxAge !== undefined) {
reqCtx.res.headers.set('access-control-max-age', config.maxAge.toString());
}

// Handle preflight OPTIONS request
if (reqCtx.request.method === HttpVerbs.OPTIONS && reqCtx.request.headers.has('Access-Control-Request-Method')) {
return new Response(null, {
status: HttpErrorCodes.NO_CONTENT,
headers: reqCtx.res.headers,
});
}
await next();
};
};
1 change: 1 addition & 0 deletions packages/event-handler/src/rest/middleware/index.ts
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
export { compress } from './compress.js';
export { cors } from './cors.js';
1 change: 1 addition & 0 deletions packages/event-handler/src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ export type {
} from './common.js';

export type {
CorsOptions,
ErrorHandler,
ErrorResolveOptions,
ErrorResponse,
Expand Down
43 changes: 43 additions & 0 deletions packages/event-handler/src/types/rest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,56 @@ type ValidationResult = {
issues: string[];
};

/**
* Configuration options for CORS middleware
*/
type CorsOptions = {
/**
* The Access-Control-Allow-Origin header value.
* Can be a string, array of strings.
* @default '*'
*/
origin?: string | string[];

/**
* The Access-Control-Allow-Methods header value.
* @default ['DELETE', 'GET', 'HEAD', 'PATCH', 'POST', 'PUT']
*/
allowMethods?: string[];

/**
* The Access-Control-Allow-Headers header value.
* @default ['Authorization', 'Content-Type', 'X-Amz-Date', 'X-Api-Key', 'X-Amz-Security-Token']
*/
allowHeaders?: string[];

/**
* The Access-Control-Expose-Headers header value.
* @default []
*/
exposeHeaders?: string[];

/**
* The Access-Control-Allow-Credentials header value.
* @default false
*/
credentials?: boolean;

/**
* The Access-Control-Max-Age header value in seconds.
* Only applicable for preflight requests.
*/
maxAge?: number;
};

type CompressionOptions = {
encoding?: 'gzip' | 'deflate';
threshold?: number;
};

export type {
CompiledRoute,
CorsOptions,
DynamicRoute,
ErrorResponse,
ErrorConstructor,
Expand Down
11 changes: 11 additions & 0 deletions packages/event-handler/tests/unit/rest/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,14 @@ export const createSettingHeadersMiddleware = (headers: {
});
};
};

export const createHeaderCheckMiddleware = (headers: {
[key: string]: string;
}): Middleware => {
return async (_params, options, next) => {
options.res.headers.forEach((value, key) => {
headers[key] = value;
});
await next();
};
};
96 changes: 96 additions & 0 deletions packages/event-handler/tests/unit/rest/middleware/cors.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import { beforeEach, describe, expect, it } from 'vitest';
import context from '@aws-lambda-powertools/testing-utils/context';
import { cors } from '../../../../src/rest/middleware/cors.js';
import { createTestEvent, createHeaderCheckMiddleware } from '../helpers.js';
import { Router } from '../../../../src/rest/Router.js';
import { DEFAULT_CORS_OPTIONS } from 'src/rest/constants.js';

describe('CORS Middleware', () => {
const getRequestEvent = createTestEvent('/test', 'GET');
const optionsRequestEvent = createTestEvent('/test', 'OPTIONS');
let app: Router;

const customCorsOptions = {
origin: 'https://example.com',
allowMethods: ['GET', 'POST'],
allowHeaders: ['Authorization', 'Content-Type'],
credentials: true,
exposeHeaders: ['Authorization', 'X-Custom-Header'],
maxAge: 86400,
};

const expectedDefaultHeaders = {
"access-control-allow-credentials": "false",
"access-control-allow-headers": "Authorization, Content-Type, X-Amz-Date, X-Api-Key, X-Amz-Security-Token",
"access-control-allow-methods": "DELETE, GET, HEAD, PATCH, POST, PUT",
"access-control-allow-origin": "*",
};

beforeEach(() => {
app = new Router();
app.use(cors());
});

it('uses default configuration when no options are provided', async () => {
const corsHeaders: { [key: string]: string } = {};
app.get('/test', [createHeaderCheckMiddleware(corsHeaders)], async () => ({ success: true }));

const result = await app.resolve(getRequestEvent, context);

expect(result.headers?.['access-control-allow-origin']).toEqual(DEFAULT_CORS_OPTIONS.origin);
expect(result.multiValueHeaders?.['access-control-allow-methods']).toEqual(DEFAULT_CORS_OPTIONS.allowMethods);
expect(result.multiValueHeaders?.['access-control-allow-headers']).toEqual(DEFAULT_CORS_OPTIONS.allowHeaders);
expect(result.headers?.['access-control-allow-credentials']).toEqual(DEFAULT_CORS_OPTIONS.credentials.toString());
expect(corsHeaders).toMatchObject(expectedDefaultHeaders);
});

it('merges user options with defaults', async () => {
const corsHeaders: { [key: string]: string } = {};
const customApp = new Router();
customApp.get('/test', [cors(customCorsOptions), createHeaderCheckMiddleware(corsHeaders)], async () => ({ success: true }));

const result = await customApp.resolve(getRequestEvent, context);

expect(result.headers?.['access-control-allow-origin']).toEqual('https://example.com');
expect(result.multiValueHeaders?.['access-control-allow-methods']).toEqual(['GET', 'POST']);
expect(result.multiValueHeaders?.['access-control-allow-headers']).toEqual(['Authorization', 'Content-Type']);
expect(result.headers?.['access-control-allow-credentials']).toEqual('true');
expect(result.multiValueHeaders?.['access-control-expose-headers']).toEqual(['Authorization', 'X-Custom-Header']);
expect(result.headers?.['access-control-max-age']).toEqual('86400');
expect(corsHeaders).toMatchObject({
"access-control-allow-credentials": "true",
"access-control-allow-headers": "Authorization, Content-Type",
"access-control-allow-methods": "GET, POST",
"access-control-allow-origin": "https://example.com",
});
});

it.each([
['matching', 'https://app.com', 'https://app.com'],
['non-matching', 'https://non-matching.com', '']
])('handles array origin with %s request', async (_, origin, expected) => {
const customApp = new Router();
customApp.get('/test', [cors({ origin: ['https://app.com', 'https://admin.app.com'] })], async () => ({ success: true }));

const result = await customApp.resolve(createTestEvent('/test', 'GET', { 'Origin': origin }), context);

expect(result.headers?.['access-control-allow-origin']).toEqual(expected);
});

it('handles OPTIONS preflight requests', async () => {
app.options('/test', async () => ({ foo: 'bar' }));

const result = await app.resolve(createTestEvent('/test', 'OPTIONS', { 'Access-Control-Request-Method': 'GET' }), context);

expect(result.statusCode).toBe(204);
});

it('calls the next middleware if the Access-Control-Request-Method is not present', async () => {
const corsHeaders: { [key: string]: string } = {};
app.options('/test', [createHeaderCheckMiddleware(corsHeaders)], async () => ({ success: true }));

await app.resolve(optionsRequestEvent, context);

expect(corsHeaders).toMatchObject(expectedDefaultHeaders);
});
});