|
1 | 1 | from typing import TYPE_CHECKING, cast |
2 | | -from posthog import contexts, capture_exception |
| 2 | +from posthog import contexts |
3 | 3 | from posthog.client import Client |
4 | 4 |
|
| 5 | +try: |
| 6 | + from asgiref.sync import iscoroutinefunction |
| 7 | +except ImportError: |
| 8 | + # Fallback for older Django versions |
| 9 | + import asyncio |
| 10 | + |
| 11 | + iscoroutinefunction = asyncio.iscoroutinefunction |
| 12 | + |
5 | 13 | if TYPE_CHECKING: |
6 | 14 | from django.http import HttpRequest, HttpResponse # noqa: F401 |
7 | | - from typing import Callable, Dict, Any, Optional # noqa: F401 |
| 15 | + from typing import Callable, Dict, Any, Optional, Union, Awaitable # noqa: F401 |
8 | 16 |
|
9 | 17 |
|
10 | 18 | class PosthogContextMiddleware: |
@@ -33,9 +41,24 @@ class PosthogContextMiddleware: |
33 | 41 | frontend. See the documentation for `set_context_session` and `identify_context` for more details. |
34 | 42 | """ |
35 | 43 |
|
| 44 | + # Django middleware capability flags |
| 45 | + sync_capable = True |
| 46 | + async_capable = True |
| 47 | + |
36 | 48 | def __init__(self, get_response): |
37 | | - # type: (Callable[[HttpRequest], HttpResponse]) -> None |
38 | | - self.get_response = get_response |
| 49 | + # type: (Union[Callable[[HttpRequest], HttpResponse], Callable[[HttpRequest], Awaitable[HttpResponse]]]) -> None |
| 50 | + self._is_coroutine = iscoroutinefunction(get_response) |
| 51 | + self._async_get_response = None # type: Optional[Callable[[HttpRequest], Awaitable[HttpResponse]]] |
| 52 | + self._sync_get_response = None # type: Optional[Callable[[HttpRequest], HttpResponse]] |
| 53 | + |
| 54 | + if self._is_coroutine: |
| 55 | + self._async_get_response = cast( |
| 56 | + "Callable[[HttpRequest], Awaitable[HttpResponse]]", get_response |
| 57 | + ) |
| 58 | + else: |
| 59 | + self._sync_get_response = cast( |
| 60 | + "Callable[[HttpRequest], HttpResponse]", get_response |
| 61 | + ) |
39 | 62 |
|
40 | 63 | from django.conf import settings |
41 | 64 |
|
@@ -159,23 +182,39 @@ def extract_request_user(self, request): |
159 | 182 |
|
160 | 183 | def __call__(self, request): |
161 | 184 | # type: (HttpRequest) -> HttpResponse |
| 185 | + # Purely defensive around django's internal sync/async handling - this should be unreachable, but if it's reached, we may |
| 186 | + # as well return something semi-meaningful |
| 187 | + if self._is_coroutine: |
| 188 | + raise RuntimeError( |
| 189 | + "PosthogContextMiddleware received sync call but get_response is async" |
| 190 | + ) |
| 191 | + |
162 | 192 | if self.request_filter and not self.request_filter(request): |
163 | | - return self.get_response(request) |
| 193 | + assert self._sync_get_response is not None |
| 194 | + return self._sync_get_response(request) |
164 | 195 |
|
165 | 196 | with contexts.new_context(self.capture_exceptions, client=self.client): |
166 | 197 | for k, v in self.extract_tags(request).items(): |
167 | 198 | contexts.tag(k, v) |
168 | 199 |
|
169 | | - return self.get_response(request) |
| 200 | + assert self._sync_get_response is not None |
| 201 | + return self._sync_get_response(request) |
170 | 202 |
|
171 | | - def process_exception(self, request, exception): |
| 203 | + async def __acall__(self, request): |
| 204 | + # type: (HttpRequest) -> HttpResponse |
172 | 205 | if self.request_filter and not self.request_filter(request): |
173 | | - return |
| 206 | + if self._async_get_response is not None: |
| 207 | + return await self._async_get_response(request) |
| 208 | + else: |
| 209 | + assert self._sync_get_response is not None |
| 210 | + return self._sync_get_response(request) |
174 | 211 |
|
175 | | - if not self.capture_exceptions: |
176 | | - return |
| 212 | + with contexts.new_context(self.capture_exceptions, client=self.client): |
| 213 | + for k, v in self.extract_tags(request).items(): |
| 214 | + contexts.tag(k, v) |
177 | 215 |
|
178 | | - if self.client: |
179 | | - self.client.capture_exception(exception) |
180 | | - else: |
181 | | - capture_exception(exception) |
| 216 | + if self._async_get_response is not None: |
| 217 | + return await self._async_get_response(request) |
| 218 | + else: |
| 219 | + assert self._sync_get_response is not None |
| 220 | + return self._sync_get_response(request) |
0 commit comments