Skip to content

Commit cdbd01e

Browse files
committed
add allowed_origins method to CORSConfig
1 parent 4ee67b2 commit cdbd01e

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import base64
24
import json
35
import logging
@@ -191,10 +193,10 @@ def __init__(
191193
A boolean value that sets the value of `Access-Control-Allow-Credentials`
192194
"""
193195

194-
self.allowed_origins = [allow_origin]
196+
self._allowed_origins = [allow_origin]
195197

196198
if extra_origins:
197-
self.allowed_origins.extend(extra_origins)
199+
self._allowed_origins.extend(extra_origins)
198200

199201
self.allow_headers = set(self._REQUIRED_HEADERS + (allow_headers or []))
200202
self.expose_headers = expose_headers or []
@@ -210,7 +212,7 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
210212

211213
# If the origin doesn't match any of the allowed origins, and we don't allow all origins ("*"),
212214
# don't add any CORS headers
213-
if origin not in self.allowed_origins and "*" not in self.allowed_origins:
215+
if origin not in self._allowed_origins and "*" not in self._allowed_origins:
214216
return {}
215217

216218
# The origin matched an allowed origin, so return the CORS headers
@@ -227,6 +229,14 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
227229
headers["Access-Control-Allow-Credentials"] = "true"
228230
return headers
229231

232+
def allowed_origin(self, extracted_origin: str) -> str | None:
233+
if extracted_origin in self._allowed_origins:
234+
return extracted_origin
235+
if extracted_origin is not None and "*" in self._allowed_origins:
236+
return "*"
237+
238+
return None
239+
230240
@staticmethod
231241
def build_allow_methods(methods: Set[str]) -> str:
232242
"""Build sorted comma delimited methods for Access-Control-Allow-Methods header
@@ -812,10 +822,9 @@ def _add_cors(self, event: ResponseEventT, cors: CORSConfig):
812822
"""Update headers to include the configured Access-Control headers"""
813823
extracted_origin_header = extract_origin_header(event.resolved_headers_field)
814824

815-
if extracted_origin_header in cors.allowed_origins:
816-
self.response.headers.update(cors.to_dict(extracted_origin_header))
817-
if extracted_origin_header is not None and "*" in cors.allowed_origins:
818-
self.response.headers.update(cors.to_dict("*"))
825+
origin = cors.allowed_origin(extracted_origin_header)
826+
if origin is not None:
827+
self.response.headers.update(cors.to_dict(origin))
819828

820829
def _add_cache_control(self, cache_control: str):
821830
"""Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used."""

0 commit comments

Comments
 (0)