Skip to content

Commit 28b7cad

Browse files
authored
Merge commit from fork
1 parent df5cf40 commit 28b7cad

File tree

2 files changed

+47
-3
lines changed

2 files changed

+47
-3
lines changed

litestar/config/cors.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

33
import re
4+
import uuid
45
from dataclasses import dataclass, field
56
from functools import cached_property
67
from re import Pattern
7-
from typing import TYPE_CHECKING, Literal
8+
from typing import TYPE_CHECKING, Final, Literal
89

910
from litestar.constants import DEFAULT_ALLOWED_CORS_HEADERS
1011

@@ -15,6 +16,11 @@
1516
from litestar.types import Method
1617

1718

19+
# this is just a UUID, so we can be sure it's not contained within the string we're
20+
# calling '.replace' on
21+
_RE_ESCAPE_PLACEHOLDER: Final = uuid.uuid4().hex
22+
23+
1824
@dataclass
1925
class CORSConfig:
2026
"""Configuration for CORS (Cross-Origin Resource Sharing).
@@ -60,10 +66,14 @@ def allowed_origins_regex(self) -> Pattern[str]:
6066
Returns:
6167
A compiled regex of the allowed path.
6268
"""
63-
origins = self.allow_origins
69+
# escape the allowed origins, while turning '*' into wildcard '.*' matches
70+
origins = [
71+
re.escape(o.replace("*", _RE_ESCAPE_PLACEHOLDER)).replace(_RE_ESCAPE_PLACEHOLDER, ".*")
72+
for o in self.allow_origins
73+
]
6474
if self.allow_origin_regex:
6575
origins.append(self.allow_origin_regex)
66-
return re.compile("|".join([origin.replace("*.", r".*\.") for origin in origins]))
76+
return re.compile("|".join(origins))
6777

6878
@cached_property
6979
def is_allow_all_origins(self) -> bool:

tests/unit/test_middleware/test_cors_middleware.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,37 @@ def handler() -> dict[str, str]:
121121
assert response.headers.get("Access-Control-Allow-Origin") == origin
122122
else:
123123
assert not response.headers.get("Access-Control-Allow-Origin")
124+
125+
126+
@pytest.mark.parametrize(
127+
"allow_origin,origin,host,should_allow",
128+
[
129+
("httpx://good.example", "https://goodXexample", "example.com", False),
130+
("https://*good.example", "https://very.good.example", "very.good.example", True),
131+
("https://*good.example", "https://verygood.example", "vergood.example", True),
132+
("https://*good.example", "https://good.example", "good.example", True),
133+
("https://*good.example", "https://bad.example", "bad.example", False),
134+
("https://*.good.example", "https://very.good.example", "very.good.example", True),
135+
("https://*.good.example", "https://verygood.example", "verygood.example", False),
136+
("https://*.good.example", "https://some.verygood.example", "verygood.example", False),
137+
("https://*.good.example", "https://good.example", "good.example", False),
138+
],
139+
)
140+
def test_cors_test_regex_escape(allow_origin: str, origin: str, host: str, should_allow: bool) -> None:
141+
@get("/")
142+
async def handler() -> None:
143+
return None
144+
145+
with create_test_client(
146+
[handler],
147+
cors_config=CORSConfig(
148+
allow_origins=[allow_origin],
149+
allow_credentials=True,
150+
),
151+
) as client:
152+
res = client.get("/", headers={"Origin": origin, "Host": host})
153+
154+
if should_allow:
155+
assert "Access-Control-Allow-Origin" in res.headers
156+
else:
157+
assert "Access-Control-Allow-Origin" not in res.headers

0 commit comments

Comments
 (0)