Skip to content

Commit dd29199

Browse files
authored
Merge pull request #2 from amirreza8002/middleware2
Implement base async middleware
2 parents d622a7a + 52bd543 commit dd29199

File tree

4 files changed

+131
-0
lines changed

4 files changed

+131
-0
lines changed

django_async_extensions/amiddleware/__init__.py

Whitespace-only changes.
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from asgiref.sync import iscoroutinefunction, markcoroutinefunction
2+
3+
from django.core.exceptions import ImproperlyConfigured
4+
5+
6+
class AsyncMiddlewareMixin:
7+
sync_capable = False
8+
async_capable = True
9+
10+
def __init__(self, get_response):
11+
if get_response is None:
12+
raise ValueError("get_response must be provided.")
13+
self.get_response = get_response
14+
# If get_response is not an async function, raise an error.
15+
self.async_mode = iscoroutinefunction(self.get_response) or iscoroutinefunction(
16+
getattr(self.get_response, "__call__", None)
17+
)
18+
if self.async_mode:
19+
# Mark the class as async-capable.
20+
markcoroutinefunction(self)
21+
else:
22+
raise ImproperlyConfigured("get_response must be async")
23+
24+
super().__init__()
25+
26+
def __repr__(self):
27+
return "<%s get_response=%s>" % (
28+
self.__class__.__qualname__,
29+
getattr(
30+
self.get_response,
31+
"__qualname__",
32+
self.get_response.__class__.__name__,
33+
),
34+
)
35+
36+
async def __call__(self, request):
37+
response = None
38+
if hasattr(self, "process_request"):
39+
response = await self.process_request(request)
40+
response = response or await self.get_response(request)
41+
if hasattr(self, "process_response"):
42+
response = await self.process_response(request, response)
43+
return response

tests/test_middlewares/__init__.py

Whitespace-only changes.
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from inspect import iscoroutinefunction
2+
3+
import pytest
4+
5+
from django.core.exceptions import ImproperlyConfigured
6+
from django.http.response import HttpResponse
7+
8+
from django_async_extensions.amiddleware.base import AsyncMiddlewareMixin
9+
10+
req = HttpResponse()
11+
resp = HttpResponse()
12+
resp_for_get_response = HttpResponse()
13+
14+
15+
async def async_get_response(request):
16+
return resp_for_get_response
17+
18+
19+
class ResponseMiddleware(AsyncMiddlewareMixin):
20+
async def process_request(self, request):
21+
return req
22+
23+
async def process_response(self, request, response):
24+
return resp
25+
26+
27+
class RequestMiddleware(AsyncMiddlewareMixin):
28+
async def process_request(self, request):
29+
return resp
30+
31+
32+
class TestMiddlewareMixin:
33+
def test_repr(self):
34+
class GetResponse:
35+
async def __call__(self):
36+
return HttpResponse()
37+
38+
async def get_response():
39+
return HttpResponse()
40+
41+
assert (
42+
repr(AsyncMiddlewareMixin(GetResponse()))
43+
== "<AsyncMiddlewareMixin get_response=GetResponse>"
44+
)
45+
assert (
46+
repr(AsyncMiddlewareMixin(get_response))
47+
== "<AsyncMiddlewareMixin get_response="
48+
"TestMiddlewareMixin.test_repr.<locals>.get_response>"
49+
)
50+
51+
def test_call_is_async(self):
52+
assert iscoroutinefunction(AsyncMiddlewareMixin.__call__)
53+
54+
def test_middleware_raises_if_get_response_is_sync(self):
55+
def get_response():
56+
return HttpResponse()
57+
58+
with pytest.raises(ImproperlyConfigured):
59+
AsyncMiddlewareMixin(get_response)
60+
61+
async def test_middleware_get_response(self, client):
62+
middleware = AsyncMiddlewareMixin(async_get_response)
63+
assert await middleware(client) is resp_for_get_response
64+
65+
async def test_middleware_process_request(self, client, mocker):
66+
spy = mocker.spy(RequestMiddleware, "process_request")
67+
68+
middleware = RequestMiddleware(async_get_response)
69+
result = await middleware(client)
70+
assert result is resp is spy.spy_return
71+
assert result is not resp_for_get_response
72+
assert spy.call_count == 1
73+
spy.assert_called_once_with(middleware, client)
74+
75+
async def test_middleware_process_response(self, client, mocker):
76+
spy1 = mocker.spy(ResponseMiddleware, "process_request")
77+
spy2 = mocker.spy(ResponseMiddleware, "process_response")
78+
79+
middleware = ResponseMiddleware(async_get_response)
80+
result = await middleware(client)
81+
82+
assert result is resp is spy2.spy_return
83+
assert result is not resp_for_get_response
84+
assert spy2.call_count == 1
85+
spy2.assert_called_once_with(middleware, client, req)
86+
87+
assert spy1.call_count == 1
88+
assert spy1.spy_return == req

0 commit comments

Comments
 (0)