Skip to content

Commit 1660c85

Browse files
authored
Merge pull request #5 from eadwinCode/feat/Async-Support
Async support on Pagination and Route Authentication
2 parents 16246c4 + 3464779 commit 1660c85

22 files changed

+935
-71
lines changed

.github/workflows/test_full.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
- name: Install core
2424
run: pip install "Django${{ matrix.django-version }}" pydantic
2525
- name: Install tests
26-
run: pip install pytest pytest-asyncio pytest-django injector django-ninja
26+
run: pip install pytest pytest-asyncio pytest-django injector django-ninja asgiref
2727
- name: Test
2828
run: pytest
2929
codestyle:

ninja_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ninja_extra.controllers.route import route
1616
from ninja_extra.dependency_resolver import get_injector, service_resolver
1717
from ninja_extra.main import NinjaExtraAPI
18+
from ninja_extra.router import Router
1819

1920
default_app_config = "ninja_extra.apps.NinjaExtraConfig"
2021

@@ -36,4 +37,5 @@
3637
"get_injector",
3738
"service_resolver",
3839
"lazy",
40+
"Router",
3941
]

ninja_extra/controllers/base.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Iterator,
1212
List,
1313
Optional,
14+
Sequence,
1415
Tuple,
1516
Type,
1617
Union,
@@ -25,10 +26,12 @@
2526
from ninja import NinjaAPI, Router
2627
from ninja.constants import NOT_SET
2728
from ninja.security.base import AuthBase
29+
from ninja.signature import is_async
2830
from ninja.utils import normalize_path
2931

3032
from ninja_extra.exceptions import APIException, NotFound, PermissionDenied, bad_request
31-
from ninja_extra.operation import Operation, PathView
33+
from ninja_extra.helper import get_function_name
34+
from ninja_extra.operation import ControllerPathView, Operation
3235
from ninja_extra.permissions import AllowAny, BasePermission
3336
from ninja_extra.shortcuts import (
3437
fail_silently,
@@ -39,7 +42,7 @@
3942

4043
from .registry import ControllerRegistry
4144
from .response import Detail, Id, Ok
42-
from .route.route_functions import RouteFunction
45+
from .route.route_functions import AsyncRouteFunction, RouteFunction
4346

4447
if TYPE_CHECKING:
4548
from ninja_extra import NinjaExtraAPI # pragma: no cover
@@ -256,7 +259,7 @@ def __init__(
256259
# `controller_class` target class that the APIController wraps
257260
self._controller_class: Optional[Type["ControllerBase"]] = None
258261
# `_path_operations` a converted dict of APIController route function used by Django-Ninja library
259-
self._path_operations: Dict[str, PathView] = dict()
262+
self._path_operations: Dict[str, ControllerPathView] = dict()
260263
# `permission_classes` a collection of BasePermission Types
261264
# a fallback if route functions has no permissions definition
262265
self.permission_classes: PermissionType = permissions or [AllowAny] # type: ignore
@@ -268,6 +271,15 @@ def __init__(
268271
if re.search(self._PATH_PARAMETER_COMPONENT_RE, prefix):
269272
self._prefix_has_route_param = True
270273

274+
self.has_auth_async = False
275+
if auth is not NOT_SET:
276+
auth_callbacks = isinstance(auth, Sequence) and auth or [auth]
277+
for _auth in auth_callbacks:
278+
_call_back = _auth if inspect.isfunction(_auth) else _auth.__call__
279+
if is_async(_call_back):
280+
self.has_auth_async = True
281+
break
282+
271283
@property
272284
def controller_class(self) -> Type["ControllerBase"]:
273285
assert self._controller_class, "Controller Class is not available"
@@ -310,7 +322,7 @@ def __call__(self, cls: Type) -> Type["ControllerBase"]:
310322
return cls
311323

312324
@property
313-
def path_operations(self) -> Dict[str, PathView]:
325+
def path_operations(self) -> Dict[str, ControllerPathView]:
314326
return self._path_operations
315327

316328
def set_api_instance(self, api: "NinjaExtraAPI") -> None:
@@ -346,7 +358,20 @@ def __str__(self) -> str: # pragma: no cover
346358
def add_operation_from_route_function(self, route_function: RouteFunction) -> None:
347359
# converts route functions to Operation model
348360
route_function.route.route_params.operation_id = f"{str(uuid.uuid4())[:8]}_controller_{route_function.route.view_func.__name__}"
349-
self.add_api_operation(
361+
362+
if (
363+
self.auth
364+
and self.has_auth_async
365+
and not isinstance(route_function, AsyncRouteFunction)
366+
):
367+
raise Exception(
368+
f"You are using a Controller level Asynchronous Authentication Class, "
369+
f"All controller endpoint must be `async`.\n"
370+
f"Controller={self.controller_class.__name__}, "
371+
f"endpoint={get_function_name(route_function.route.view_func)}"
372+
)
373+
374+
route_function.operation = self.add_api_operation( # type: ignore
350375
view_func=route_function.as_view, **route_function.route.route_params.dict()
351376
)
352377

@@ -375,7 +400,7 @@ def add_api_operation(
375400
if self._prefix_has_route_param:
376401
path = normalize_path("/".join([i for i in (self.prefix, path) if i]))
377402
if path not in self._path_operations:
378-
path_view = PathView()
403+
path_view = ControllerPathView()
379404
self._path_operations[path] = path_view
380405
else:
381406
path_view = self._path_operations[path]

ninja_extra/controllers/route/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,7 @@ def __init__(
459459

460460
_response = response
461461
if (
462-
inspect.isclass(response)
463-
and issubclass(response, ControllerResponse) # type:ignore
462+
inspect.isclass(response) and issubclass(response, ControllerResponse)
464463
) or isinstance(response, ControllerResponse):
465464
response = cast(ControllerResponse, response)
466465
_response = {response.status_code: response.get_schema()}

ninja_extra/controllers/route/route_functions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from .context import RouteContext
1212

1313
if TYPE_CHECKING:
14+
from ninja_extra.operation import ControllerOperation
15+
1416
from ...controllers.base import APIController, ControllerBase
1517
from ...controllers.route import Route
1618

@@ -28,6 +30,7 @@ def __init__(
2830
self, route: "Route", api_controller: Optional["APIController"] = None
2931
):
3032
self.route = route
33+
self.operation: Optional["ControllerOperation"] = None
3134
self.has_request_param = False
3235
self.api_controller = api_controller
3336
self.as_view = wraps(route.view_func)(self.get_view_function())

ninja_extra/helper.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import inspect
2+
from typing import Any
3+
4+
5+
def get_function_name(func_class: Any) -> str:
6+
if inspect.isfunction(func_class) or inspect.isclass(func_class):
7+
return str(func_class.__name__)
8+
return str(func_class.__class__.__name__)

ninja_extra/main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ninja.parser import Parser
1111
from ninja.renderers import BaseRenderer
1212

13-
from ninja_extra import exceptions
13+
from ninja_extra import exceptions, router
1414
from ninja_extra.controllers.base import APIController, ControllerBase
1515
from ninja_extra.controllers.registry import ControllerRegistry
1616

@@ -49,6 +49,9 @@ def __init__(
4949
)
5050
self.app_name = app_name
5151
self.exception_handler(exceptions.APIException)(self.api_exception_handler)
52+
self._routers: List[Tuple[str, router.Router]] = [] # type: ignore
53+
self.default_router = router.Router()
54+
self.add_router("", self.default_router)
5255

5356
def api_exception_handler(
5457
self, request: HttpRequest, exc: exceptions.APIException

ninja_extra/operation.py

Lines changed: 126 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import time
23
from contextlib import contextmanager
34
from typing import (
@@ -8,7 +9,9 @@
89
List,
910
Optional,
1011
Sequence,
12+
Type,
1113
Union,
14+
cast,
1215
)
1316

1417
from django.http import HttpRequest
@@ -21,8 +24,11 @@
2124
PathView as NinjaPathView,
2225
)
2326
from ninja.signature import is_async
27+
from ninja.types import TCallable
28+
from ninja.utils import check_csrf
2429

2530
from ninja_extra.exceptions import APIException
31+
from ninja_extra.helper import get_function_name
2632
from ninja_extra.logger import request_logger
2733
from ninja_extra.signals import route_context_finished, route_context_started
2834

@@ -35,12 +41,41 @@
3541

3642
class Operation(NinjaOperation):
3743
def __init__(
38-
self, *args: Any, url_name: Optional[str] = None, **kwargs: Any
44+
self,
45+
path: str,
46+
methods: List[str],
47+
view_func: Callable,
48+
*,
49+
url_name: Optional[str] = None,
50+
**kwargs: Any,
3951
) -> None:
40-
super().__init__(*args, **kwargs)
52+
self.is_coroutine = is_async(view_func)
4153
self.url_name = url_name
54+
super().__init__(path, methods, view_func, **kwargs)
4255
self.signature = ViewSignature(self.path, self.view_func)
4356

57+
def _set_auth(
58+
self, auth: Optional[Union[Sequence[Callable], Callable, object]]
59+
) -> None:
60+
if auth is not None and auth is not NOT_SET:
61+
self.auth_callbacks = isinstance(auth, Sequence) and auth or [auth]
62+
for callback in self.auth_callbacks:
63+
_call_back = (
64+
callback if inspect.isfunction(callback) else callback.__call__ # type: ignore
65+
)
66+
67+
if not getattr(callback, "is_coroutine", None):
68+
setattr(callback, "is_coroutine", is_async(_call_back))
69+
70+
if is_async(_call_back) and not self.is_coroutine:
71+
raise Exception(
72+
f"Could apply auth=`{get_function_name(callback)}` "
73+
f"to view_func=`{get_function_name(self.view_func)}`.\n"
74+
f"N:B - {get_function_name(callback)} can only be used on Asynchronous view functions"
75+
)
76+
77+
78+
class ControllerOperation(Operation):
4479
def _log_action(
4580
self,
4681
logger: Callable[..., Any],
@@ -90,10 +125,8 @@ def _prep_run(self, request: HttpRequest, **kw: Any) -> Iterator:
90125
context = self.get_execution_context(request, **kw)
91126
# send route_context_started signal
92127
route_context_started.send(RouteContext, route_context=context)
93-
values = self._get_values(request, kw)
94-
context.kwargs = values
95128

96-
yield values, context
129+
yield context
97130
self._log_action(
98131
request_logger.info,
99132
request=request,
@@ -115,15 +148,16 @@ def _prep_run(self, request: HttpRequest, **kw: Any) -> Iterator:
115148
route_context_finished.send(RouteContext, route_context=None)
116149

117150
def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase:
118-
error = super(Operation, self)._run_checks(request)
151+
error = self._run_checks(request)
119152
if error:
120153
return error
121154
try:
122155
with self._prep_run(request, **kw) as ctx:
123-
values, context = ctx
124-
result = self.view_func(context=context, **values)
156+
values = self._get_values(request, kw)
157+
ctx.kwargs = values
158+
result = self.view_func(context=ctx, **values)
125159
_processed_results = self._result_to_response(request, result)
126-
return _processed_results
160+
return _processed_results
127161
except Exception as e:
128162
if isinstance(e, TypeError) and "required positional argument" in str(e):
129163
msg = "Did you fail to use functools.wraps() in a decorator?"
@@ -133,16 +167,73 @@ def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase:
133167

134168

135169
class AsyncOperation(Operation, NinjaAsyncOperation):
170+
def __init__(self, *args: Any, **kwargs: Any) -> None:
171+
super().__init__(*args, **kwargs)
172+
from asgiref.sync import sync_to_async
173+
174+
self._get_values = cast(Callable, sync_to_async(super()._get_values)) # type: ignore
175+
self._result_to_response = cast( # type: ignore
176+
Callable,
177+
sync_to_async(super()._result_to_response),
178+
)
179+
180+
async def _run_checks(self, request: HttpRequest) -> Optional[HttpResponse]: # type: ignore
181+
"""Runs security checks for each operation"""
182+
# auth:
183+
if self.auth_callbacks:
184+
error = await self._run_authentication(request)
185+
if error:
186+
return error
187+
188+
# csrf:
189+
if self.api.csrf:
190+
error = check_csrf(request, self.view_func)
191+
if error:
192+
return error
193+
194+
return None
195+
196+
async def _run_authentication(self, request: HttpRequest) -> Optional[HttpResponse]: # type: ignore
197+
for callback in self.auth_callbacks:
198+
try:
199+
is_coroutine = getattr(callback, "is_coroutine", False)
200+
if is_coroutine:
201+
result = await callback(request)
202+
else:
203+
result = callback(request)
204+
except Exception as exc:
205+
return self.api.on_exception(request, exc)
206+
207+
if result:
208+
request.auth = result # type: ignore
209+
return None
210+
return self.api.create_response(request, {"detail": "Unauthorized"}, status=401)
211+
136212
async def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase: # type: ignore
137-
error = self._run_checks(request)
213+
error = await self._run_checks(request)
214+
if error:
215+
return error
216+
try:
217+
values = await self._get_values(request, kw) # type: ignore
218+
result = await self.view_func(request, **values)
219+
_processed_results = await self._result_to_response(request, result) # type: ignore
220+
return cast(HttpResponseBase, _processed_results)
221+
except Exception as e:
222+
return self.api.on_exception(request, e)
223+
224+
225+
class AsyncControllerOperation(AsyncOperation, ControllerOperation):
226+
async def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase: # type: ignore
227+
error = await self._run_checks(request)
138228
if error:
139229
return error
140230
try:
141231
with self._prep_run(request, **kw) as ctx:
142-
values, context = ctx
143-
result = await self.view_func(context=context, **values)
144-
_processed_results = self._result_to_response(request, result)
145-
return _processed_results
232+
values = await self._get_values(request, kw) # type: ignore
233+
ctx.kwargs = values
234+
result = await self.view_func(context=ctx, **values)
235+
_processed_results = await self._result_to_response(request, result) # type: ignore
236+
return cast(HttpResponseBase, _processed_results)
146237
except Exception as e:
147238
return self.api.on_exception(request, e)
148239

@@ -176,12 +267,7 @@ def add_operation(
176267
) -> Operation:
177268
if url_name:
178269
self.url_name = url_name
179-
180-
operation_class = Operation
181-
if is_async(view_func):
182-
self.is_async = True
183-
operation_class = AsyncOperation
184-
270+
operation_class = self.get_operation_class(view_func)
185271
operation = operation_class(
186272
path,
187273
methods,
@@ -203,3 +289,23 @@ def add_operation(
203289

204290
self.operations.append(operation)
205291
return operation
292+
293+
def get_operation_class(
294+
self, view_func: TCallable
295+
) -> Type[Union[Operation, AsyncOperation]]:
296+
operation_class = Operation
297+
if is_async(view_func):
298+
self.is_async = True
299+
operation_class = AsyncOperation
300+
return operation_class
301+
302+
303+
class ControllerPathView(PathView):
304+
def get_operation_class(
305+
self, view_func: TCallable
306+
) -> Type[Union[Operation, AsyncOperation]]:
307+
operation_class = ControllerOperation
308+
if is_async(view_func):
309+
self.is_async = True
310+
operation_class = AsyncControllerOperation
311+
return operation_class

0 commit comments

Comments
 (0)