Skip to content

Commit 8ae0e27

Browse files
committed
fixed permission check before schema validation as described in #116, #192 issues
1 parent 4e31ab4 commit 8ae0e27

File tree

3 files changed

+84
-19
lines changed

3 files changed

+84
-19
lines changed

ninja_extra/controllers/route/route_functions.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66

77
from django.http import HttpRequest, HttpResponse
88

9-
from ...dependency_resolver import get_injector, service_resolver
9+
from ninja_extra.dependency_resolver import get_injector, service_resolver
10+
1011
from .context import RouteContext, get_route_execution_context
1112

1213
if TYPE_CHECKING: # pragma: no cover
14+
from ninja_extra.controllers.base import APIController, ControllerBase
15+
from ninja_extra.controllers.route import Route
1316
from ninja_extra.operation import Operation
1417

15-
from ...controllers.base import APIController, ControllerBase
16-
from ...controllers.route import Route
17-
1818

1919
class RouteFunctionContext:
2020
def __init__(
@@ -74,6 +74,13 @@ def _resolve_api_func_signature_(self, context_func: Callable) -> Callable:
7474
context_func.__signature__ = sig_replaced # type: ignore
7575
return context_func
7676

77+
def run_permission_check(self, route_context: RouteContext) -> None:
78+
_route_context = route_context or cast(
79+
RouteContext, service_resolver(RouteContext)
80+
)
81+
with self._prep_controller_route_execution(_route_context) as ctx:
82+
ctx.controller_instance.check_permissions()
83+
7784
def get_view_function(self) -> Callable:
7885
def as_view(
7986
request: HttpRequest,
@@ -85,23 +92,30 @@ def as_view(
8592
RouteContext, service_resolver(RouteContext)
8693
)
8794
with self._prep_controller_route_execution(_route_context, **kwargs) as ctx:
88-
ctx.controller_instance.check_permissions()
95+
# ctx.controller_instance.check_permissions()
8996
result = self.route.view_func(
9097
ctx.controller_instance, *args, **ctx.view_func_kwargs
9198
)
92-
return self._process_view_function_result(result)
99+
return result
93100

94101
as_view.get_route_function = lambda: self # type:ignore
95102
return as_view
96103

97104
def _process_view_function_result(self, result: Any) -> Any:
98105
"""
99-
This process any an returned value from view_func
100-
and creates an api response if result is ControllerResponseSchema
101-
"""
106+
This process any a returned value from view_func
107+
and creates an api response if a result is ControllerResponseSchema
102108
103-
# if result and isinstance(result, ControllerResponse):
104-
# return result.status_code, result.convert_to_schema()
109+
deprecated:: 0.21.5
110+
This method is deprecated and will be removed in a future version.
111+
The result processing should be handled by the response handlers.
112+
"""
113+
warnings.warn(
114+
"_process_view_function_result() is deprecated and will be removed in a future version. "
115+
"The result processing should be handled by the response handlers.",
116+
DeprecationWarning,
117+
stacklevel=2,
118+
)
105119
return result
106120

107121
def _get_controller_instance(self) -> "ControllerBase":
@@ -163,24 +177,27 @@ def __repr__(self) -> str: # pragma: no cover
163177

164178

165179
class AsyncRouteFunction(RouteFunction):
180+
async def async_run_check_permissions(self, route_context: RouteContext) -> None:
181+
from asgiref.sync import sync_to_async
182+
183+
await sync_to_async(self.run_permission_check)(route_context)
184+
166185
def get_view_function(self) -> Callable:
167186
async def as_view(
168187
request: HttpRequest,
169188
route_context: Optional[RouteContext] = None,
170189
*args: Any,
171190
**kwargs: Any,
172191
) -> Any:
173-
from asgiref.sync import sync_to_async
174-
175192
_route_context = route_context or cast(
176193
RouteContext, service_resolver(RouteContext)
177194
)
178195
with self._prep_controller_route_execution(_route_context, **kwargs) as ctx:
179-
await sync_to_async(ctx.controller_instance.check_permissions)()
196+
# await sync_to_async(ctx.controller_instance.check_permissions)()
180197
result = await self.route.view_func(
181198
ctx.controller_instance, *args, **ctx.view_func_kwargs
182199
)
183-
return self._process_view_function_result(result)
200+
return result
184201

185202
as_view.get_route_function = lambda: self # type:ignore
186203
return as_view

ninja_extra/operation.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@
4848
from .details import ViewSignature
4949

5050
if TYPE_CHECKING: # pragma: no cover
51-
from .controllers.route.route_functions import RouteFunction
51+
from .controllers.route.route_functions import AsyncRouteFunction, RouteFunction
5252

5353

5454
class Operation(NinjaOperation):
55+
view_func: Callable
56+
5557
def __init__(
5658
self,
5759
path: str,
@@ -88,6 +90,16 @@ def _set_auth(
8890
f"N:B - {get_function_name(callback)} can only be used on Asynchronous view functions"
8991
)
9092

93+
def _get_route_function(
94+
self,
95+
) -> Optional[Union["RouteFunction", "AsyncRouteFunction"]]:
96+
if hasattr(self.view_func, "get_route_function"):
97+
return cast(
98+
Union["RouteFunction", "AsyncRouteFunction"],
99+
self.view_func.get_route_function(),
100+
)
101+
return None
102+
91103
def _log_action(
92104
self,
93105
logger: Callable[..., Any],
@@ -102,8 +114,8 @@ def _log_action(
102114
f'{self.view_func.__name__} {request.path}" '
103115
f"{duration if duration else ''}"
104116
)
105-
if hasattr(self.view_func, "get_route_function"):
106-
route_function: "RouteFunction" = self.view_func.get_route_function()
117+
route_function = self._get_route_function()
118+
if route_function:
107119
api_controller = route_function.get_api_controller()
108120

109121
msg = (
@@ -185,6 +197,10 @@ def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase:
185197
with self._prep_run(
186198
request, temporal_response=temporal_response, **kw
187199
) as ctx:
200+
route_function = self._get_route_function()
201+
if route_function:
202+
route_function.run_permission_check(ctx)
203+
188204
error = self._run_checks(request)
189205
if error:
190206
return error
@@ -309,6 +325,10 @@ async def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase: # typ
309325
async with self._prep_run(
310326
request, temporal_response=temporal_response, **kw
311327
) as ctx:
328+
route_function = self._get_route_function()
329+
if route_function:
330+
await route_function.async_run_check_permissions(ctx) # type: ignore[attr-defined]
331+
312332
error = await self._run_checks(request)
313333
if error:
314334
return error

tests/test_permissions.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
from unittest.mock import Mock
33

44
import pytest
5+
from asgiref.sync import sync_to_async
56
from django.contrib.auth.models import AnonymousUser, User
67

78
from ninja_extra import ControllerBase, api_controller, http_get, permissions
8-
from ninja_extra.testing import TestClient
9+
from ninja_extra.testing import TestAsyncClient, TestClient
910

1011
anonymous_request = Mock()
1112
anonymous_request.user = AnonymousUser()
@@ -250,6 +251,13 @@ def index(self):
250251
def permission_accept_type_and_instance(self):
251252
return {"success": True}
252253

254+
@http_get(
255+
"permission/async/",
256+
permissions=[permissions.IsAdminUser() & permissions.IsAuthenticatedOrReadOnly],
257+
)
258+
async def permission_accept_type_and_instance_async(self):
259+
return {"success": True}
260+
253261

254262
@pytest.mark.django_db
255263
@pytest.mark.parametrize("route", ["permission/", "index/"])
@@ -269,3 +277,23 @@ def test_permission_controller_instance(route):
269277
res = client.get(route, user=user)
270278
assert res.status_code == 200
271279
assert res.json() == {"success": True}
280+
281+
282+
@pytest.mark.django_db
283+
@pytest.mark.asyncio
284+
async def test_permission_controller_instance_async():
285+
user = await sync_to_async(User.objects.create_user)(
286+
username="eadwin",
287+
288+
password="password",
289+
is_staff=True,
290+
is_superuser=True,
291+
)
292+
293+
client = TestAsyncClient(Some2Controller)
294+
res = await client.get("/permission/async/", user=AnonymousUser())
295+
assert res.status_code == 403
296+
297+
res = await client.get("/permission/async/", user=user)
298+
assert res.status_code == 200
299+
assert res.json() == {"success": True}

0 commit comments

Comments
 (0)