|
12 | 12 |
|
13 | 13 | if TYPE_CHECKING: |
14 | 14 | from django.http import HttpRequest, HttpResponse # noqa: F401 |
15 | | - from typing import Callable, Dict, Any, Optional # noqa: F401 |
| 15 | + from typing import Callable, Dict, Any, Optional, Union, Awaitable # noqa: F401 |
16 | 16 |
|
17 | 17 |
|
18 | 18 | class PosthogContextMiddleware: |
@@ -46,9 +46,19 @@ class PosthogContextMiddleware: |
46 | 46 | async_capable = True |
47 | 47 |
|
48 | 48 | def __init__(self, get_response): |
49 | | - # type: (Callable[[HttpRequest], HttpResponse]) -> None |
50 | | - self.get_response = get_response |
| 49 | + # type: (Union[Callable[[HttpRequest], HttpResponse], Callable[[HttpRequest], Awaitable[HttpResponse]]]) -> None |
51 | 50 | self._is_coroutine = iscoroutinefunction(get_response) |
| 51 | + self._async_get_response = None # type: Optional[Callable[[HttpRequest], Awaitable[HttpResponse]]] |
| 52 | + self._sync_get_response = None # type: Optional[Callable[[HttpRequest], HttpResponse]] |
| 53 | + |
| 54 | + if self._is_coroutine: |
| 55 | + self._async_get_response = cast( |
| 56 | + "Callable[[HttpRequest], Awaitable[HttpResponse]]", get_response |
| 57 | + ) |
| 58 | + else: |
| 59 | + self._sync_get_response = cast( |
| 60 | + "Callable[[HttpRequest], HttpResponse]", get_response |
| 61 | + ) |
52 | 62 |
|
53 | 63 | from django.conf import settings |
54 | 64 |
|
@@ -180,27 +190,31 @@ def __call__(self, request): |
180 | 190 | ) |
181 | 191 |
|
182 | 192 | if self.request_filter and not self.request_filter(request): |
183 | | - return self.get_response(request) |
| 193 | + assert self._sync_get_response is not None |
| 194 | + return self._sync_get_response(request) |
184 | 195 |
|
185 | 196 | with contexts.new_context(self.capture_exceptions, client=self.client): |
186 | 197 | for k, v in self.extract_tags(request).items(): |
187 | 198 | contexts.tag(k, v) |
188 | 199 |
|
189 | | - return self.get_response(request) |
| 200 | + assert self._sync_get_response is not None |
| 201 | + return self._sync_get_response(request) |
190 | 202 |
|
191 | 203 | async def __acall__(self, request): |
192 | 204 | # type: (HttpRequest) -> HttpResponse |
193 | 205 | if self.request_filter and not self.request_filter(request): |
194 | | - if self._is_coroutine: |
195 | | - return await self.get_response(request) # type: ignore |
| 206 | + if self._async_get_response is not None: |
| 207 | + return await self._async_get_response(request) |
196 | 208 | else: |
197 | | - return self.get_response(request) # type: ignore |
| 209 | + assert self._sync_get_response is not None |
| 210 | + return self._sync_get_response(request) |
198 | 211 |
|
199 | 212 | with contexts.new_context(self.capture_exceptions, client=self.client): |
200 | 213 | for k, v in self.extract_tags(request).items(): |
201 | 214 | contexts.tag(k, v) |
202 | 215 |
|
203 | | - if self._is_coroutine: |
204 | | - return await self.get_response(request) # type: ignore |
| 216 | + if self._async_get_response is not None: |
| 217 | + return await self._async_get_response(request) |
205 | 218 | else: |
206 | | - return self.get_response(request) # type: ignore |
| 219 | + assert self._sync_get_response is not None |
| 220 | + return self._sync_get_response(request) |
0 commit comments