1
- import type {
2
- CorsOptions ,
3
- Middleware ,
4
- } from '../../types/rest.js' ;
1
+ import type { CorsOptions , Middleware } from '../../types/rest.js' ;
5
2
import {
6
3
DEFAULT_CORS_OPTIONS ,
7
4
HttpErrorCodes ,
8
5
HttpVerbs ,
9
6
} from '../constants.js' ;
10
7
11
- /**
12
- * Resolves the origin value based on the configuration
13
- */
14
- const resolveOrigin = (
15
- originConfig : NonNullable < CorsOptions [ 'origin' ] > ,
16
- requestOrigin : string | null ,
17
- ) : string => {
18
- if ( Array . isArray ( originConfig ) ) {
19
- return requestOrigin && originConfig . includes ( requestOrigin ) ? requestOrigin : '' ;
20
- }
21
- return originConfig ;
22
- } ;
23
-
24
8
/**
25
9
* Creates a CORS middleware that adds appropriate CORS headers to responses
26
10
* and handles OPTIONS preflight requests.
@@ -29,9 +13,9 @@ const resolveOrigin = (
29
13
* ```typescript
30
14
* import { Router } from '@aws-lambda-powertools/event-handler/experimental-rest';
31
15
* import { cors } from '@aws-lambda-powertools/event-handler/experimental-rest/middleware';
32
- *
16
+ *
33
17
* const app = new Router();
34
- *
18
+ *
35
19
* // Use default configuration
36
20
* app.use(cors());
37
21
*
@@ -50,7 +34,7 @@ const resolveOrigin = (
50
34
* }
51
35
* }));
52
36
* ```
53
- *
37
+ *
54
38
* @param options.origin - The origin to allow requests from
55
39
* @param options.allowMethods - The HTTP methods to allow
56
40
* @param options.allowHeaders - The headers to allow
@@ -61,38 +45,76 @@ const resolveOrigin = (
61
45
export const cors = ( options ?: CorsOptions ) : Middleware => {
62
46
const config = {
63
47
...DEFAULT_CORS_OPTIONS ,
64
- ...options
48
+ ...options ,
65
49
} ;
50
+ const allowedOrigins =
51
+ typeof config . origin === 'string' ? [ config . origin ] : config . origin ;
52
+ const allowsWildcard = allowedOrigins . includes ( '*' ) ;
66
53
67
54
return async ( _params , reqCtx , next ) => {
68
55
const requestOrigin = reqCtx . request . headers . get ( 'Origin' ) ;
69
- const resolvedOrigin = resolveOrigin ( config . origin , requestOrigin ) ;
56
+ if (
57
+ ! requestOrigin ||
58
+ ( ! allowsWildcard && ! allowedOrigins . includes ( requestOrigin ) )
59
+ ) {
60
+ await next ( ) ;
61
+ return ;
62
+ }
63
+
64
+ const isOptions = reqCtx . request . method === HttpVerbs . OPTIONS ;
65
+ // Handle preflight OPTIONS request
66
+ if ( isOptions ) {
67
+ const requestMethod = reqCtx . request . headers . get (
68
+ 'Access-Control-Request-Method'
69
+ ) ;
70
+ const requestHeaders = reqCtx . request . headers . get (
71
+ 'Access-Control-Request-Headers'
72
+ ) ;
73
+ if (
74
+ ! requestMethod ||
75
+ ! config . allowMethods . includes ( requestMethod ) ||
76
+ ! requestHeaders ||
77
+ requestHeaders
78
+ . split ( ',' )
79
+ . some ( ( header ) => ! config . allowHeaders . includes ( header . trim ( ) ) )
80
+ ) {
81
+ await next ( ) ;
82
+ return ;
83
+ }
84
+ }
70
85
86
+ const resolvedOrigin = allowsWildcard ? '*' : requestOrigin ;
71
87
reqCtx . res . headers . set ( 'access-control-allow-origin' , resolvedOrigin ) ;
72
- if ( resolvedOrigin !== '*' ) {
73
- reqCtx . res . headers . set ( 'Vary ' , 'Origin' ) ;
88
+ if ( ! allowsWildcard && Array . isArray ( config . origin ) ) {
89
+ reqCtx . res . headers . set ( 'vary ' , 'Origin' ) ;
74
90
}
75
- config . allowMethods . forEach ( method => {
76
- reqCtx . res . headers . append ( 'access-control-allow-methods' , method ) ;
77
- } ) ;
78
- config . allowHeaders . forEach ( header => {
79
- reqCtx . res . headers . append ( 'access-control-allow-headers' , header ) ;
80
- } ) ;
81
- config . exposeHeaders . forEach ( header => {
82
- reqCtx . res . headers . append ( 'access-control-expose-headers' , header ) ;
83
- } ) ;
84
- reqCtx . res . headers . set ( 'access-control-allow-credentials' , config . credentials . toString ( ) ) ;
85
- if ( config . maxAge !== undefined ) {
86
- reqCtx . res . headers . set ( 'access-control-max-age' , config . maxAge . toString ( ) ) ;
91
+ if ( config . credentials ) {
92
+ reqCtx . res . headers . set ( 'access-control-allow-credentials' , 'true' ) ;
87
93
}
88
94
89
- // Handle preflight OPTIONS request
90
- if ( reqCtx . request . method === HttpVerbs . OPTIONS && reqCtx . request . headers . has ( 'Access-Control-Request-Method' ) ) {
95
+ if ( isOptions ) {
96
+ if ( config . maxAge !== undefined ) {
97
+ reqCtx . res . headers . set (
98
+ 'access-control-max-age' ,
99
+ config . maxAge . toString ( )
100
+ ) ;
101
+ }
102
+ config . allowMethods . forEach ( ( method ) => {
103
+ reqCtx . res . headers . append ( 'access-control-allow-methods' , method ) ;
104
+ } ) ;
105
+ config . allowHeaders . forEach ( ( header ) => {
106
+ reqCtx . res . headers . append ( 'access-control-allow-headers' , header ) ;
107
+ } ) ;
91
108
return new Response ( null , {
92
109
status : HttpErrorCodes . NO_CONTENT ,
93
110
headers : reqCtx . res . headers ,
94
111
} ) ;
95
112
}
113
+
114
+ config . exposeHeaders . forEach ( ( header ) => {
115
+ reqCtx . res . headers . append ( 'access-control-expose-headers' , header ) ;
116
+ } ) ;
117
+
96
118
await next ( ) ;
97
119
} ;
98
120
} ;
0 commit comments