diff --git a/django_async_extensions/amiddleware/__init__.py b/django_async_extensions/amiddleware/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/django_async_extensions/amiddleware/base.py b/django_async_extensions/amiddleware/base.py new file mode 100644 index 0000000..730c33e --- /dev/null +++ b/django_async_extensions/amiddleware/base.py @@ -0,0 +1,43 @@ +from asgiref.sync import iscoroutinefunction, markcoroutinefunction + +from django.core.exceptions import ImproperlyConfigured + + +class AsyncMiddlewareMixin: + sync_capable = False + async_capable = True + + def __init__(self, get_response): + if get_response is None: + raise ValueError("get_response must be provided.") + self.get_response = get_response + # If get_response is not an async function, raise an error. + self.async_mode = iscoroutinefunction(self.get_response) or iscoroutinefunction( + getattr(self.get_response, "__call__", None) + ) + if self.async_mode: + # Mark the class as async-capable. + markcoroutinefunction(self) + else: + raise ImproperlyConfigured("get_response must be async") + + super().__init__() + + def __repr__(self): + return "<%s get_response=%s>" % ( + self.__class__.__qualname__, + getattr( + self.get_response, + "__qualname__", + self.get_response.__class__.__name__, + ), + ) + + async def __call__(self, request): + response = None + if hasattr(self, "process_request"): + response = await self.process_request(request) + response = response or await self.get_response(request) + if hasattr(self, "process_response"): + response = await self.process_response(request, response) + return response diff --git a/tests/test_middlewares/__init__.py b/tests/test_middlewares/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_middlewares/test_middleware_mixin.py b/tests/test_middlewares/test_middleware_mixin.py new file mode 100644 index 0000000..4fcb4a1 --- /dev/null +++ b/tests/test_middlewares/test_middleware_mixin.py @@ -0,0 +1,88 @@ +from inspect import iscoroutinefunction + +import pytest + +from django.core.exceptions import ImproperlyConfigured +from django.http.response import HttpResponse + +from django_async_extensions.amiddleware.base import AsyncMiddlewareMixin + +req = HttpResponse() +resp = HttpResponse() +resp_for_get_response = HttpResponse() + + +async def async_get_response(request): + return resp_for_get_response + + +class ResponseMiddleware(AsyncMiddlewareMixin): + async def process_request(self, request): + return req + + async def process_response(self, request, response): + return resp + + +class RequestMiddleware(AsyncMiddlewareMixin): + async def process_request(self, request): + return resp + + +class TestMiddlewareMixin: + def test_repr(self): + class GetResponse: + async def __call__(self): + return HttpResponse() + + async def get_response(): + return HttpResponse() + + assert ( + repr(AsyncMiddlewareMixin(GetResponse())) + == "" + ) + assert ( + repr(AsyncMiddlewareMixin(get_response)) + == ".get_response>" + ) + + def test_call_is_async(self): + assert iscoroutinefunction(AsyncMiddlewareMixin.__call__) + + def test_middleware_raises_if_get_response_is_sync(self): + def get_response(): + return HttpResponse() + + with pytest.raises(ImproperlyConfigured): + AsyncMiddlewareMixin(get_response) + + async def test_middleware_get_response(self, client): + middleware = AsyncMiddlewareMixin(async_get_response) + assert await middleware(client) is resp_for_get_response + + async def test_middleware_process_request(self, client, mocker): + spy = mocker.spy(RequestMiddleware, "process_request") + + middleware = RequestMiddleware(async_get_response) + result = await middleware(client) + assert result is resp is spy.spy_return + assert result is not resp_for_get_response + assert spy.call_count == 1 + spy.assert_called_once_with(middleware, client) + + async def test_middleware_process_response(self, client, mocker): + spy1 = mocker.spy(ResponseMiddleware, "process_request") + spy2 = mocker.spy(ResponseMiddleware, "process_response") + + middleware = ResponseMiddleware(async_get_response) + result = await middleware(client) + + assert result is resp is spy2.spy_return + assert result is not resp_for_get_response + assert spy2.call_count == 1 + spy2.assert_called_once_with(middleware, client, req) + + assert spy1.call_count == 1 + assert spy1.spy_return == req