diff --git a/ninja_extra/controllers/route/route_functions.py b/ninja_extra/controllers/route/route_functions.py index 49a3eb89..639524b4 100644 --- a/ninja_extra/controllers/route/route_functions.py +++ b/ninja_extra/controllers/route/route_functions.py @@ -6,15 +6,15 @@ from django.http import HttpRequest, HttpResponse -from ...dependency_resolver import get_injector, service_resolver +from ninja_extra.dependency_resolver import get_injector, service_resolver + from .context import RouteContext, get_route_execution_context if TYPE_CHECKING: # pragma: no cover + from ninja_extra.controllers.base import APIController, ControllerBase + from ninja_extra.controllers.route import Route from ninja_extra.operation import Operation - from ...controllers.base import APIController, ControllerBase - from ...controllers.route import Route - class RouteFunctionContext: def __init__( @@ -50,6 +50,7 @@ def __call__( *args, **kwargs, ) + self.run_permission_check(context) return self.as_view(request, *args, route_context=context, **kwargs) def _get_required_api_func_signature(self) -> Tuple: @@ -74,6 +75,13 @@ def _resolve_api_func_signature_(self, context_func: Callable) -> Callable: context_func.__signature__ = sig_replaced # type: ignore return context_func + def run_permission_check(self, route_context: RouteContext) -> None: + _route_context = route_context or cast( + RouteContext, service_resolver(RouteContext) + ) + with self._prep_controller_route_execution(_route_context) as ctx: + ctx.controller_instance.check_permissions() + def get_view_function(self) -> Callable: def as_view( request: HttpRequest, @@ -85,23 +93,30 @@ def as_view( RouteContext, service_resolver(RouteContext) ) with self._prep_controller_route_execution(_route_context, **kwargs) as ctx: - ctx.controller_instance.check_permissions() + # ctx.controller_instance.check_permissions() result = self.route.view_func( ctx.controller_instance, *args, **ctx.view_func_kwargs ) - return self._process_view_function_result(result) + return result as_view.get_route_function = lambda: self # type:ignore return as_view def _process_view_function_result(self, result: Any) -> Any: """ - This process any an returned value from view_func - and creates an api response if result is ControllerResponseSchema - """ + This process any a returned value from view_func + and creates an api response if a result is ControllerResponseSchema - # if result and isinstance(result, ControllerResponse): - # return result.status_code, result.convert_to_schema() + deprecated:: 0.21.5 + This method is deprecated and will be removed in a future version. + The result processing should be handled by the response handlers. + """ + warnings.warn( + "_process_view_function_result() is deprecated and will be removed in a future version. " + "The result processing should be handled by the response handlers.", + DeprecationWarning, + stacklevel=2, + ) return result def _get_controller_instance(self) -> "ControllerBase": @@ -163,6 +178,11 @@ def __repr__(self) -> str: # pragma: no cover class AsyncRouteFunction(RouteFunction): + async def async_run_check_permissions(self, route_context: RouteContext) -> None: + from asgiref.sync import sync_to_async + + await sync_to_async(self.run_permission_check)(route_context) + def get_view_function(self) -> Callable: async def as_view( request: HttpRequest, @@ -170,17 +190,15 @@ async def as_view( *args: Any, **kwargs: Any, ) -> Any: - from asgiref.sync import sync_to_async - _route_context = route_context or cast( RouteContext, service_resolver(RouteContext) ) with self._prep_controller_route_execution(_route_context, **kwargs) as ctx: - await sync_to_async(ctx.controller_instance.check_permissions)() + # await sync_to_async(ctx.controller_instance.check_permissions)() result = await self.route.view_func( ctx.controller_instance, *args, **ctx.view_func_kwargs ) - return self._process_view_function_result(result) + return result as_view.get_route_function = lambda: self # type:ignore return as_view @@ -205,4 +223,5 @@ async def __call__( *args, **kwargs, ) + await self.async_run_check_permissions(context) return await self.as_view(request, *args, route_context=context, **kwargs) diff --git a/ninja_extra/operation.py b/ninja_extra/operation.py index 56ea15ee..f7811504 100644 --- a/ninja_extra/operation.py +++ b/ninja_extra/operation.py @@ -48,10 +48,12 @@ from .details import ViewSignature if TYPE_CHECKING: # pragma: no cover - from .controllers.route.route_functions import RouteFunction + from .controllers.route.route_functions import AsyncRouteFunction, RouteFunction class Operation(NinjaOperation): + view_func: Callable + def __init__( self, path: str, @@ -88,6 +90,16 @@ def _set_auth( f"N:B - {get_function_name(callback)} can only be used on Asynchronous view functions" ) + def _get_route_function( + self, + ) -> Optional[Union["RouteFunction", "AsyncRouteFunction"]]: + if hasattr(self.view_func, "get_route_function"): + return cast( + Union["RouteFunction", "AsyncRouteFunction"], + self.view_func.get_route_function(), + ) + return None + def _log_action( self, logger: Callable[..., Any], @@ -102,8 +114,8 @@ def _log_action( f'{self.view_func.__name__} {request.path}" ' f"{duration if duration else ''}" ) - if hasattr(self.view_func, "get_route_function"): - route_function: "RouteFunction" = self.view_func.get_route_function() + route_function = self._get_route_function() + if route_function: api_controller = route_function.get_api_controller() msg = ( @@ -185,6 +197,10 @@ def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase: with self._prep_run( request, temporal_response=temporal_response, **kw ) as ctx: + route_function = self._get_route_function() + if route_function: + route_function.run_permission_check(ctx) + error = self._run_checks(request) if error: return error @@ -309,6 +325,10 @@ async def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase: # typ async with self._prep_run( request, temporal_response=temporal_response, **kw ) as ctx: + route_function = self._get_route_function() + if route_function: + await route_function.async_run_check_permissions(ctx) # type: ignore[attr-defined] + error = await self._run_checks(request) if error: return error diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 21f47b98..9e350176 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -2,10 +2,11 @@ from unittest.mock import Mock import pytest +from asgiref.sync import sync_to_async from django.contrib.auth.models import AnonymousUser, User from ninja_extra import ControllerBase, api_controller, http_get, permissions -from ninja_extra.testing import TestClient +from ninja_extra.testing import TestAsyncClient, TestClient anonymous_request = Mock() anonymous_request.user = AnonymousUser() @@ -250,12 +251,19 @@ def index(self): def permission_accept_type_and_instance(self): return {"success": True} + @http_get( + "permission/async/", + permissions=[permissions.IsAdminUser() & permissions.IsAuthenticatedOrReadOnly], + ) + async def permission_accept_type_and_instance_async(self): + return {"success": True} + @pytest.mark.django_db @pytest.mark.parametrize("route", ["permission/", "index/"]) def test_permission_controller_instance(route): user = User.objects.create_user( - username="eadwin", + username="eadwin3", email="eadwin@example.com", password="password", is_staff=True, @@ -269,3 +277,23 @@ def test_permission_controller_instance(route): res = client.get(route, user=user) assert res.status_code == 200 assert res.json() == {"success": True} + + +@pytest.mark.django_db +@pytest.mark.asyncio +async def test_permission_controller_instance_async(): + user = await sync_to_async(User.objects.create_user)( + username="eadwin2", + email="eadwin@example.com", + password="password", + is_staff=True, + is_superuser=True, + ) + + client = TestAsyncClient(Some2Controller) + res = await client.get("/permission/async/", user=AnonymousUser()) + assert res.status_code == 403 + + res = await client.get("/permission/async/", user=user) + assert res.status_code == 200 + assert res.json() == {"success": True} diff --git a/tests/test_route.py b/tests/test_route.py index b6148c6c..2eda8177 100644 --- a/tests/test_route.py +++ b/tests/test_route.py @@ -351,7 +351,7 @@ def setup_method(self): def get_real_user_request(cls): _request = Mock() user = User.objects.create_user( - username="eadwin", + username="eadwin1", email="eadwin@example.com", password="password", is_staff=True, diff --git a/tests/test_throthling/test_models.py b/tests/test_throthling/test_models.py index 48f2db1b..22cc87d8 100644 --- a/tests/test_throthling/test_models.py +++ b/tests/test_throthling/test_models.py @@ -125,7 +125,9 @@ def setup_method(self): def test_get_cache_key_returns_correct_value_for_authenticated_request(self): user = User.objects.create(username="test") self.request.user = user - assert self.throttle.get_cache_key(self.request) == "throttle_user_1" + assert self.throttle.get_cache_key(self.request) == "throttle_user_{}".format( + user.pk + ) def test_get_cache_key_defaults_to_none(self): cache_key = self.throttle.get_cache_key(self.request) @@ -165,7 +167,9 @@ def test_get_cache_key_returns_correct_value_for_authenticated_request( throttle = DynamicRateThrottle(scope="some_scope") user = User.objects.create(username="test") self.request.user = user - assert throttle.get_cache_key(self.request) == "throttle_some_scope_1" + assert throttle.get_cache_key( + self.request + ) == "throttle_some_scope_{}".format(user.pk) def test_get_cache_key_defaults_to_none(self, monkeypatch): with monkeypatch.context() as m: