1
+ from __future__ import annotations
2
+
1
3
import base64
2
4
import json
3
5
import logging
@@ -191,10 +193,10 @@ def __init__(
191
193
A boolean value that sets the value of `Access-Control-Allow-Credentials`
192
194
"""
193
195
194
- self .allowed_origins = [allow_origin ]
196
+ self ._allowed_origins = [allow_origin ]
195
197
196
198
if extra_origins :
197
- self .allowed_origins .extend (extra_origins )
199
+ self ._allowed_origins .extend (extra_origins )
198
200
199
201
self .allow_headers = set (self ._REQUIRED_HEADERS + (allow_headers or []))
200
202
self .expose_headers = expose_headers or []
@@ -210,7 +212,7 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
210
212
211
213
# If the origin doesn't match any of the allowed origins, and we don't allow all origins ("*"),
212
214
# don't add any CORS headers
213
- if origin not in self .allowed_origins and "*" not in self .allowed_origins :
215
+ if origin not in self ._allowed_origins and "*" not in self ._allowed_origins :
214
216
return {}
215
217
216
218
# The origin matched an allowed origin, so return the CORS headers
@@ -227,6 +229,14 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
227
229
headers ["Access-Control-Allow-Credentials" ] = "true"
228
230
return headers
229
231
232
+ def allowed_origin (self , extracted_origin : str ) -> str | None :
233
+ if extracted_origin in self ._allowed_origins :
234
+ return extracted_origin
235
+ if extracted_origin is not None and "*" in self ._allowed_origins :
236
+ return "*"
237
+
238
+ return None
239
+
230
240
@staticmethod
231
241
def build_allow_methods (methods : Set [str ]) -> str :
232
242
"""Build sorted comma delimited methods for Access-Control-Allow-Methods header
@@ -812,10 +822,9 @@ def _add_cors(self, event: ResponseEventT, cors: CORSConfig):
812
822
"""Update headers to include the configured Access-Control headers"""
813
823
extracted_origin_header = extract_origin_header (event .resolved_headers_field )
814
824
815
- if extracted_origin_header in cors .allowed_origins :
816
- self .response .headers .update (cors .to_dict (extracted_origin_header ))
817
- if extracted_origin_header is not None and "*" in cors .allowed_origins :
818
- self .response .headers .update (cors .to_dict ("*" ))
825
+ origin = cors .allowed_origin (extracted_origin_header )
826
+ if origin is not None :
827
+ self .response .headers .update (cors .to_dict (origin ))
819
828
820
829
def _add_cache_control (self , cache_control : str ):
821
830
"""Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used."""
0 commit comments