Skip to content

Commit b2293a4

Browse files
committed
fix types
1 parent 31f6d16 commit b2293a4

File tree

1 file changed

+25
-11
lines changed

1 file changed

+25
-11
lines changed

posthog/integrations/django.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
if TYPE_CHECKING:
1414
from django.http import HttpRequest, HttpResponse # noqa: F401
15-
from typing import Callable, Dict, Any, Optional # noqa: F401
15+
from typing import Callable, Dict, Any, Optional, Union, Awaitable # noqa: F401
1616

1717

1818
class PosthogContextMiddleware:
@@ -46,9 +46,19 @@ class PosthogContextMiddleware:
4646
async_capable = True
4747

4848
def __init__(self, get_response):
49-
# type: (Callable[[HttpRequest], HttpResponse]) -> None
50-
self.get_response = get_response
49+
# type: (Union[Callable[[HttpRequest], HttpResponse], Callable[[HttpRequest], Awaitable[HttpResponse]]]) -> None
5150
self._is_coroutine = iscoroutinefunction(get_response)
51+
self._async_get_response = None # type: Optional[Callable[[HttpRequest], Awaitable[HttpResponse]]]
52+
self._sync_get_response = None # type: Optional[Callable[[HttpRequest], HttpResponse]]
53+
54+
if self._is_coroutine:
55+
self._async_get_response = cast(
56+
"Callable[[HttpRequest], Awaitable[HttpResponse]]", get_response
57+
)
58+
else:
59+
self._sync_get_response = cast(
60+
"Callable[[HttpRequest], HttpResponse]", get_response
61+
)
5262

5363
from django.conf import settings
5464

@@ -180,27 +190,31 @@ def __call__(self, request):
180190
)
181191

182192
if self.request_filter and not self.request_filter(request):
183-
return self.get_response(request)
193+
assert self._sync_get_response is not None
194+
return self._sync_get_response(request)
184195

185196
with contexts.new_context(self.capture_exceptions, client=self.client):
186197
for k, v in self.extract_tags(request).items():
187198
contexts.tag(k, v)
188199

189-
return self.get_response(request)
200+
assert self._sync_get_response is not None
201+
return self._sync_get_response(request)
190202

191203
async def __acall__(self, request):
192204
# type: (HttpRequest) -> HttpResponse
193205
if self.request_filter and not self.request_filter(request):
194-
if self._is_coroutine:
195-
return await self.get_response(request) # type: ignore
206+
if self._async_get_response is not None:
207+
return await self._async_get_response(request)
196208
else:
197-
return self.get_response(request) # type: ignore
209+
assert self._sync_get_response is not None
210+
return self._sync_get_response(request)
198211

199212
with contexts.new_context(self.capture_exceptions, client=self.client):
200213
for k, v in self.extract_tags(request).items():
201214
contexts.tag(k, v)
202215

203-
if self._is_coroutine:
204-
return await self.get_response(request) # type: ignore
216+
if self._async_get_response is not None:
217+
return await self._async_get_response(request)
205218
else:
206-
return self.get_response(request) # type: ignore
219+
assert self._sync_get_response is not None
220+
return self._sync_get_response(request)

0 commit comments

Comments
 (0)