Skip to content

Commit 75cc53c

Browse files
author
Jaap Roes
committed
Add conf argument to CorsMiddleware
1 parent 3dc7093 commit 75cc53c

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

src/corsheaders/middleware.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from django.utils.cache import patch_vary_headers
1414

1515
from corsheaders.conf import conf
16+
from corsheaders.conf import Settings
1617
from corsheaders.signals import check_request_enabled
1718

1819
ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin"
@@ -35,8 +36,10 @@ def __init__(
3536
Callable[[HttpRequest], HttpResponseBase]
3637
| Callable[[HttpRequest], Awaitable[HttpResponseBase]]
3738
),
39+
conf: Settings = conf,
3840
) -> None:
3941
self.get_response = get_response
42+
self.conf = conf
4043
if asyncio.iscoroutinefunction(self.get_response):
4144
# Mark the class as async-capable, but do the actual switch
4245
# inside __call__ to avoid swapping out dunder methods
@@ -105,34 +108,38 @@ def add_response_headers(
105108
except ValueError:
106109
return response
107110

108-
if conf.CORS_ALLOW_CREDENTIALS:
111+
if self.conf.CORS_ALLOW_CREDENTIALS:
109112
response[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true"
110113

111114
if (
112-
not conf.CORS_ALLOW_ALL_ORIGINS
115+
not self.conf.CORS_ALLOW_ALL_ORIGINS
113116
and not self.origin_found_in_white_lists(origin, url)
114117
and not self.check_signal(request)
115118
):
116119
return response
117120

118-
if conf.CORS_ALLOW_ALL_ORIGINS and not conf.CORS_ALLOW_CREDENTIALS:
121+
if self.conf.CORS_ALLOW_ALL_ORIGINS and not self.conf.CORS_ALLOW_CREDENTIALS:
119122
response[ACCESS_CONTROL_ALLOW_ORIGIN] = "*"
120123
else:
121124
response[ACCESS_CONTROL_ALLOW_ORIGIN] = origin
122125

123-
if len(conf.CORS_EXPOSE_HEADERS):
126+
if len(self.conf.CORS_EXPOSE_HEADERS):
124127
response[ACCESS_CONTROL_EXPOSE_HEADERS] = ", ".join(
125-
conf.CORS_EXPOSE_HEADERS
128+
self.conf.CORS_EXPOSE_HEADERS
126129
)
127130

128131
if request.method == "OPTIONS":
129-
response[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join(conf.CORS_ALLOW_HEADERS)
130-
response[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join(conf.CORS_ALLOW_METHODS)
131-
if conf.CORS_PREFLIGHT_MAX_AGE:
132-
response[ACCESS_CONTROL_MAX_AGE] = str(conf.CORS_PREFLIGHT_MAX_AGE)
132+
response[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join(
133+
self.conf.CORS_ALLOW_HEADERS
134+
)
135+
response[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join(
136+
self.conf.CORS_ALLOW_METHODS
137+
)
138+
if self.conf.CORS_PREFLIGHT_MAX_AGE:
139+
response[ACCESS_CONTROL_MAX_AGE] = str(self.conf.CORS_PREFLIGHT_MAX_AGE)
133140

134141
if (
135-
conf.CORS_ALLOW_PRIVATE_NETWORK
142+
self.conf.CORS_ALLOW_PRIVATE_NETWORK
136143
and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true"
137144
):
138145
response[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"
@@ -141,28 +148,28 @@ def add_response_headers(
141148

142149
def origin_found_in_white_lists(self, origin: str, url: SplitResult) -> bool:
143150
return (
144-
(origin == "null" and origin in conf.CORS_ALLOWED_ORIGINS)
151+
(origin == "null" and origin in self.conf.CORS_ALLOWED_ORIGINS)
145152
or self._url_in_whitelist(url)
146153
or self.regex_domain_match(origin)
147154
)
148155

149156
def regex_domain_match(self, origin: str) -> bool:
150157
return any(
151158
re.match(domain_pattern, origin)
152-
for domain_pattern in conf.CORS_ALLOWED_ORIGIN_REGEXES
159+
for domain_pattern in self.conf.CORS_ALLOWED_ORIGIN_REGEXES
153160
)
154161

155162
def is_enabled(self, request: HttpRequest) -> bool:
156163
return bool(
157-
re.match(conf.CORS_URLS_REGEX, request.path_info)
164+
re.match(self.conf.CORS_URLS_REGEX, request.path_info)
158165
) or self.check_signal(request)
159166

160167
def check_signal(self, request: HttpRequest) -> bool:
161168
signal_responses = check_request_enabled.send(sender=None, request=request)
162169
return any(return_value for function, return_value in signal_responses)
163170

164171
def _url_in_whitelist(self, url: SplitResult) -> bool:
165-
origins = [urlsplit(o) for o in conf.CORS_ALLOWED_ORIGINS]
172+
origins = [urlsplit(o) for o in self.conf.CORS_ALLOWED_ORIGINS]
166173
return any(
167174
origin.scheme == url.scheme and origin.netloc == url.netloc
168175
for origin in origins

0 commit comments

Comments
 (0)