Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
43 changes: 43 additions & 0 deletions django_async_extensions/amiddleware/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from asgiref.sync import iscoroutinefunction, markcoroutinefunction

from django.core.exceptions import ImproperlyConfigured


class AsyncMiddlewareMixin:
sync_capable = False
async_capable = True

def __init__(self, get_response):
if get_response is None:
raise ValueError("get_response must be provided.")
self.get_response = get_response
# If get_response is not an async function, raise an error.
self.async_mode = iscoroutinefunction(self.get_response) or iscoroutinefunction(
getattr(self.get_response, "__call__", None)
)
if self.async_mode:
# Mark the class as async-capable.
markcoroutinefunction(self)
else:
raise ImproperlyConfigured("get_response must be async")

super().__init__()

def __repr__(self):
return "<%s get_response=%s>" % (
self.__class__.__qualname__,
getattr(
self.get_response,
"__qualname__",
self.get_response.__class__.__name__,
),
)

async def __call__(self, request):
response = None
if hasattr(self, "process_request"):
response = await self.process_request(request)
response = response or await self.get_response(request)
if hasattr(self, "process_response"):
response = await self.process_response(request, response)
return response
Empty file.
88 changes: 88 additions & 0 deletions tests/test_middlewares/test_middleware_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from inspect import iscoroutinefunction

import pytest

from django.core.exceptions import ImproperlyConfigured
from django.http.response import HttpResponse

from django_async_extensions.amiddleware.base import AsyncMiddlewareMixin

req = HttpResponse()
resp = HttpResponse()
resp_for_get_response = HttpResponse()


async def async_get_response(request):
return resp_for_get_response


class ResponseMiddleware(AsyncMiddlewareMixin):
async def process_request(self, request):
return req

async def process_response(self, request, response):
return resp


class RequestMiddleware(AsyncMiddlewareMixin):
async def process_request(self, request):
return resp


class TestMiddlewareMixin:
def test_repr(self):
class GetResponse:
async def __call__(self):
return HttpResponse()

async def get_response():
return HttpResponse()

assert (
repr(AsyncMiddlewareMixin(GetResponse()))
== "<AsyncMiddlewareMixin get_response=GetResponse>"
)
assert (
repr(AsyncMiddlewareMixin(get_response))
== "<AsyncMiddlewareMixin get_response="
"TestMiddlewareMixin.test_repr.<locals>.get_response>"
)

def test_call_is_async(self):
assert iscoroutinefunction(AsyncMiddlewareMixin.__call__)

def test_middleware_raises_if_get_response_is_sync(self):
def get_response():
return HttpResponse()

with pytest.raises(ImproperlyConfigured):
AsyncMiddlewareMixin(get_response)

async def test_middleware_get_response(self, client):
middleware = AsyncMiddlewareMixin(async_get_response)
assert await middleware(client) is resp_for_get_response

async def test_middleware_process_request(self, client, mocker):
spy = mocker.spy(RequestMiddleware, "process_request")

middleware = RequestMiddleware(async_get_response)
result = await middleware(client)
assert result is resp is spy.spy_return
assert result is not resp_for_get_response
assert spy.call_count == 1
spy.assert_called_once_with(middleware, client)

async def test_middleware_process_response(self, client, mocker):
spy1 = mocker.spy(ResponseMiddleware, "process_request")
spy2 = mocker.spy(ResponseMiddleware, "process_response")

middleware = ResponseMiddleware(async_get_response)
result = await middleware(client)

assert result is resp is spy2.spy_return
assert result is not resp_for_get_response
assert spy2.call_count == 1
spy2.assert_called_once_with(middleware, client, req)

assert spy1.call_count == 1
assert spy1.spy_return == req
Loading