Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions ninja_extra/security/session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from django.conf import settings
from django.contrib.auth.middleware import get_user
from django.http import HttpRequest
from ninja.signature import is_async

from ninja_extra.security.api_key import AsyncAPIKeyCookie

Expand All @@ -17,10 +17,12 @@ class AsyncSessionAuth(AsyncAPIKeyCookie):
async def authenticate(
self, request: HttpRequest, key: Optional[str]
) -> Optional[Any]:
if hasattr(request, "auser") and is_async(request.auser):
from asgiref.sync import sync_to_async

if hasattr(request, "auser"):
current_user = await request.auser()
else:
current_user = request.user
current_user = await sync_to_async(get_user)(request)

if current_user.is_authenticated:
return current_user
Expand Down
4 changes: 0 additions & 4 deletions tests/test_async_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
AsyncAPIKeyQuery,
AsyncHttpBasicAuth,
AsyncHttpBearer,
async_django_auth,
)

user_secret = base64.b64encode("admin:secret".encode("utf-8")).decode()
Expand Down Expand Up @@ -145,7 +144,6 @@ async def test_csrf_on():
api = NinjaExtraAPI(csrf=True, urls_namespace="async_auth")

for path, auth in [
("django_auth", async_django_auth),
("callable", callable_auth),
("apikeyquery", KeyQuery()),
("apikeyheader", KeyHeader()),
Expand All @@ -163,8 +161,6 @@ async def test_csrf_on():
@pytest.mark.parametrize(
"path,kwargs,expected_code,expected_body",
[
("/django_auth", {}, 401, BODY_UNAUTHORIZED_DEFAULT),
("/django_auth", {"user": MockUser("admin")}, 200, {"auth": "admin"}),
("/callable", {}, 401, BODY_UNAUTHORIZED_DEFAULT),
("/callable?auth=demo", {}, 200, {"auth": "demo"}),
("/apikeyquery", {}, 401, BODY_UNAUTHORIZED_DEFAULT),
Expand Down
21 changes: 14 additions & 7 deletions tests/test_security/test_session.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,47 @@
from unittest.mock import AsyncMock, Mock

import pytest
from asgiref.sync import sync_to_async
from django.contrib.sessions.middleware import SessionMiddleware
from django.http import HttpRequest

from ninja_extra.security.session import AsyncSessionAuth
from ninja_extra.security import async_django_auth


@pytest.mark.asyncio
@pytest.mark.django_db
async def test_async_session_auth():
auth = AsyncSessionAuth()
request = HttpRequest()

# Add session to request
middleware = SessionMiddleware(lambda x: x)
await sync_to_async(middleware.process_request)(request)
await sync_to_async(request.session.save)()

# Test async authenticated user
async_user = AsyncMock()
async_user.is_authenticated = True
request.auser = AsyncMock(return_value=async_user)

authenticated_user = await auth.authenticate(request, None)
authenticated_user = await async_django_auth.authenticate(request, None)
assert authenticated_user == async_user
request.auser.assert_called_once()

# Test async non-authenticated user
async_user.is_authenticated = False
authenticated_user = await auth.authenticate(request, None)
authenticated_user = await async_django_auth.authenticate(request, None)
assert authenticated_user is None

# Test non-async authenticated user
delattr(request, "auser")
sync_user = Mock()
sync_user.is_authenticated = True
request.user = sync_user
request._cached_user = sync_user

authenticated_user = await auth.authenticate(request, None)
authenticated_user = await async_django_auth.authenticate(request, None)
assert authenticated_user == sync_user

# Test non-async non-authenticated user
sync_user.is_authenticated = False
authenticated_user = await auth.authenticate(request, None)
authenticated_user = await async_django_auth.authenticate(request, None)
assert authenticated_user is None
Loading