Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 34 additions & 15 deletions ninja_extra/controllers/route/route_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@

from django.http import HttpRequest, HttpResponse

from ...dependency_resolver import get_injector, service_resolver
from ninja_extra.dependency_resolver import get_injector, service_resolver

from .context import RouteContext, get_route_execution_context

if TYPE_CHECKING: # pragma: no cover
from ninja_extra.controllers.base import APIController, ControllerBase
from ninja_extra.controllers.route import Route
from ninja_extra.operation import Operation

from ...controllers.base import APIController, ControllerBase
from ...controllers.route import Route


class RouteFunctionContext:
def __init__(
Expand Down Expand Up @@ -50,6 +50,7 @@ def __call__(
*args,
**kwargs,
)
self.run_permission_check(context)
return self.as_view(request, *args, route_context=context, **kwargs)

def _get_required_api_func_signature(self) -> Tuple:
Expand All @@ -74,6 +75,13 @@ def _resolve_api_func_signature_(self, context_func: Callable) -> Callable:
context_func.__signature__ = sig_replaced # type: ignore
return context_func

def run_permission_check(self, route_context: RouteContext) -> None:
_route_context = route_context or cast(
RouteContext, service_resolver(RouteContext)
)
with self._prep_controller_route_execution(_route_context) as ctx:
ctx.controller_instance.check_permissions()

def get_view_function(self) -> Callable:
def as_view(
request: HttpRequest,
Expand All @@ -85,23 +93,30 @@ def as_view(
RouteContext, service_resolver(RouteContext)
)
with self._prep_controller_route_execution(_route_context, **kwargs) as ctx:
ctx.controller_instance.check_permissions()
# ctx.controller_instance.check_permissions()
result = self.route.view_func(
ctx.controller_instance, *args, **ctx.view_func_kwargs
)
return self._process_view_function_result(result)
return result

as_view.get_route_function = lambda: self # type:ignore
return as_view

def _process_view_function_result(self, result: Any) -> Any:
"""
This process any an returned value from view_func
and creates an api response if result is ControllerResponseSchema
"""
This process any a returned value from view_func
and creates an api response if a result is ControllerResponseSchema

# if result and isinstance(result, ControllerResponse):
# return result.status_code, result.convert_to_schema()
deprecated:: 0.21.5
This method is deprecated and will be removed in a future version.
The result processing should be handled by the response handlers.
"""
warnings.warn(
"_process_view_function_result() is deprecated and will be removed in a future version. "
"The result processing should be handled by the response handlers.",
DeprecationWarning,
stacklevel=2,
)
return result

def _get_controller_instance(self) -> "ControllerBase":
Expand Down Expand Up @@ -163,24 +178,27 @@ def __repr__(self) -> str: # pragma: no cover


class AsyncRouteFunction(RouteFunction):
async def async_run_check_permissions(self, route_context: RouteContext) -> None:
from asgiref.sync import sync_to_async

await sync_to_async(self.run_permission_check)(route_context)

def get_view_function(self) -> Callable:
async def as_view(
request: HttpRequest,
route_context: Optional[RouteContext] = None,
*args: Any,
**kwargs: Any,
) -> Any:
from asgiref.sync import sync_to_async

_route_context = route_context or cast(
RouteContext, service_resolver(RouteContext)
)
with self._prep_controller_route_execution(_route_context, **kwargs) as ctx:
await sync_to_async(ctx.controller_instance.check_permissions)()
# await sync_to_async(ctx.controller_instance.check_permissions)()
result = await self.route.view_func(
ctx.controller_instance, *args, **ctx.view_func_kwargs
)
return self._process_view_function_result(result)
return result

as_view.get_route_function = lambda: self # type:ignore
return as_view
Expand All @@ -205,4 +223,5 @@ async def __call__(
*args,
**kwargs,
)
await self.async_run_check_permissions(context)
return await self.as_view(request, *args, route_context=context, **kwargs)
26 changes: 23 additions & 3 deletions ninja_extra/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@
from .details import ViewSignature

if TYPE_CHECKING: # pragma: no cover
from .controllers.route.route_functions import RouteFunction
from .controllers.route.route_functions import AsyncRouteFunction, RouteFunction


class Operation(NinjaOperation):
view_func: Callable

def __init__(
self,
path: str,
Expand Down Expand Up @@ -88,6 +90,16 @@ def _set_auth(
f"N:B - {get_function_name(callback)} can only be used on Asynchronous view functions"
)

def _get_route_function(
self,
) -> Optional[Union["RouteFunction", "AsyncRouteFunction"]]:
if hasattr(self.view_func, "get_route_function"):
return cast(
Union["RouteFunction", "AsyncRouteFunction"],
self.view_func.get_route_function(),
)
return None

def _log_action(
self,
logger: Callable[..., Any],
Expand All @@ -102,8 +114,8 @@ def _log_action(
f'{self.view_func.__name__} {request.path}" '
f"{duration if duration else ''}"
)
if hasattr(self.view_func, "get_route_function"):
route_function: "RouteFunction" = self.view_func.get_route_function()
route_function = self._get_route_function()
if route_function:
api_controller = route_function.get_api_controller()

msg = (
Expand Down Expand Up @@ -185,6 +197,10 @@ def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase:
with self._prep_run(
request, temporal_response=temporal_response, **kw
) as ctx:
route_function = self._get_route_function()
if route_function:
route_function.run_permission_check(ctx)

error = self._run_checks(request)
if error:
return error
Expand Down Expand Up @@ -309,6 +325,10 @@ async def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase: # typ
async with self._prep_run(
request, temporal_response=temporal_response, **kw
) as ctx:
route_function = self._get_route_function()
if route_function:
await route_function.async_run_check_permissions(ctx) # type: ignore[attr-defined]

error = await self._run_checks(request)
if error:
return error
Expand Down
32 changes: 30 additions & 2 deletions tests/test_permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from unittest.mock import Mock

import pytest
from asgiref.sync import sync_to_async
from django.contrib.auth.models import AnonymousUser, User

from ninja_extra import ControllerBase, api_controller, http_get, permissions
from ninja_extra.testing import TestClient
from ninja_extra.testing import TestAsyncClient, TestClient

anonymous_request = Mock()
anonymous_request.user = AnonymousUser()
Expand Down Expand Up @@ -250,12 +251,19 @@ def index(self):
def permission_accept_type_and_instance(self):
return {"success": True}

@http_get(
"permission/async/",
permissions=[permissions.IsAdminUser() & permissions.IsAuthenticatedOrReadOnly],
)
async def permission_accept_type_and_instance_async(self):
return {"success": True}


@pytest.mark.django_db
@pytest.mark.parametrize("route", ["permission/", "index/"])
def test_permission_controller_instance(route):
user = User.objects.create_user(
username="eadwin",
username="eadwin3",
email="[email protected]",
password="password",
is_staff=True,
Expand All @@ -269,3 +277,23 @@ def test_permission_controller_instance(route):
res = client.get(route, user=user)
assert res.status_code == 200
assert res.json() == {"success": True}


@pytest.mark.django_db
@pytest.mark.asyncio
async def test_permission_controller_instance_async():
user = await sync_to_async(User.objects.create_user)(
username="eadwin2",
email="[email protected]",
password="password",
is_staff=True,
is_superuser=True,
)

client = TestAsyncClient(Some2Controller)
res = await client.get("/permission/async/", user=AnonymousUser())
assert res.status_code == 403

res = await client.get("/permission/async/", user=user)
assert res.status_code == 200
assert res.json() == {"success": True}
2 changes: 1 addition & 1 deletion tests/test_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def setup_method(self):
def get_real_user_request(cls):
_request = Mock()
user = User.objects.create_user(
username="eadwin",
username="eadwin1",
email="[email protected]",
password="password",
is_staff=True,
Expand Down
8 changes: 6 additions & 2 deletions tests/test_throthling/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ def setup_method(self):
def test_get_cache_key_returns_correct_value_for_authenticated_request(self):
user = User.objects.create(username="test")
self.request.user = user
assert self.throttle.get_cache_key(self.request) == "throttle_user_1"
assert self.throttle.get_cache_key(self.request) == "throttle_user_{}".format(
user.pk
)

def test_get_cache_key_defaults_to_none(self):
cache_key = self.throttle.get_cache_key(self.request)
Expand Down Expand Up @@ -165,7 +167,9 @@ def test_get_cache_key_returns_correct_value_for_authenticated_request(
throttle = DynamicRateThrottle(scope="some_scope")
user = User.objects.create(username="test")
self.request.user = user
assert throttle.get_cache_key(self.request) == "throttle_some_scope_1"
assert throttle.get_cache_key(
self.request
) == "throttle_some_scope_{}".format(user.pk)

def test_get_cache_key_defaults_to_none(self, monkeypatch):
with monkeypatch.context() as m:
Expand Down
Loading