1+ import inspect
12import time
23from contextlib import contextmanager
34from typing import (
89 List ,
910 Optional ,
1011 Sequence ,
12+ Type ,
1113 Union ,
14+ cast ,
1215)
1316
1417from django .http import HttpRequest
2124 PathView as NinjaPathView ,
2225)
2326from ninja .signature import is_async
27+ from ninja .types import TCallable
28+ from ninja .utils import check_csrf
2429
2530from ninja_extra .exceptions import APIException
31+ from ninja_extra .helper import get_function_name
2632from ninja_extra .logger import request_logger
2733from ninja_extra .signals import route_context_finished , route_context_started
2834
3541
3642class Operation (NinjaOperation ):
3743 def __init__ (
38- self , * args : Any , url_name : Optional [str ] = None , ** kwargs : Any
44+ self ,
45+ path : str ,
46+ methods : List [str ],
47+ view_func : Callable ,
48+ * ,
49+ url_name : Optional [str ] = None ,
50+ ** kwargs : Any ,
3951 ) -> None :
40- super (). __init__ ( * args , ** kwargs )
52+ self . is_coroutine = is_async ( view_func )
4153 self .url_name = url_name
54+ super ().__init__ (path , methods , view_func , ** kwargs )
4255 self .signature = ViewSignature (self .path , self .view_func )
4356
57+ def _set_auth (
58+ self , auth : Optional [Union [Sequence [Callable ], Callable , object ]]
59+ ) -> None :
60+ if auth is not None and auth is not NOT_SET :
61+ self .auth_callbacks = isinstance (auth , Sequence ) and auth or [auth ]
62+ for callback in self .auth_callbacks :
63+ _call_back = (
64+ callback if inspect .isfunction (callback ) else callback .__call__ # type: ignore
65+ )
66+
67+ if not getattr (callback , "is_coroutine" , None ):
68+ setattr (callback , "is_coroutine" , is_async (_call_back ))
69+
70+ if is_async (_call_back ) and not self .is_coroutine :
71+ raise Exception (
72+ f"Could apply auth=`{ get_function_name (callback )} ` "
73+ f"to view_func=`{ get_function_name (self .view_func )} `.\n "
74+ f"N:B - { get_function_name (callback )} can only be used on Asynchronous view functions"
75+ )
76+
77+
78+ class ControllerOperation (Operation ):
4479 def _log_action (
4580 self ,
4681 logger : Callable [..., Any ],
@@ -90,10 +125,8 @@ def _prep_run(self, request: HttpRequest, **kw: Any) -> Iterator:
90125 context = self .get_execution_context (request , ** kw )
91126 # send route_context_started signal
92127 route_context_started .send (RouteContext , route_context = context )
93- values = self ._get_values (request , kw )
94- context .kwargs = values
95128
96- yield values , context
129+ yield context
97130 self ._log_action (
98131 request_logger .info ,
99132 request = request ,
@@ -115,15 +148,16 @@ def _prep_run(self, request: HttpRequest, **kw: Any) -> Iterator:
115148 route_context_finished .send (RouteContext , route_context = None )
116149
117150 def run (self , request : HttpRequest , ** kw : Any ) -> HttpResponseBase :
118- error = super ( Operation , self ) ._run_checks (request )
151+ error = self ._run_checks (request )
119152 if error :
120153 return error
121154 try :
122155 with self ._prep_run (request , ** kw ) as ctx :
123- values , context = ctx
124- result = self .view_func (context = context , ** values )
156+ values = self ._get_values (request , kw )
157+ ctx .kwargs = values
158+ result = self .view_func (context = ctx , ** values )
125159 _processed_results = self ._result_to_response (request , result )
126- return _processed_results
160+ return _processed_results
127161 except Exception as e :
128162 if isinstance (e , TypeError ) and "required positional argument" in str (e ):
129163 msg = "Did you fail to use functools.wraps() in a decorator?"
@@ -133,16 +167,73 @@ def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase:
133167
134168
135169class AsyncOperation (Operation , NinjaAsyncOperation ):
170+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
171+ super ().__init__ (* args , ** kwargs )
172+ from asgiref .sync import sync_to_async
173+
174+ self ._get_values = cast (Callable , sync_to_async (super ()._get_values )) # type: ignore
175+ self ._result_to_response = cast ( # type: ignore
176+ Callable ,
177+ sync_to_async (super ()._result_to_response ),
178+ )
179+
180+ async def _run_checks (self , request : HttpRequest ) -> Optional [HttpResponse ]: # type: ignore
181+ """Runs security checks for each operation"""
182+ # auth:
183+ if self .auth_callbacks :
184+ error = await self ._run_authentication (request )
185+ if error :
186+ return error
187+
188+ # csrf:
189+ if self .api .csrf :
190+ error = check_csrf (request , self .view_func )
191+ if error :
192+ return error
193+
194+ return None
195+
196+ async def _run_authentication (self , request : HttpRequest ) -> Optional [HttpResponse ]: # type: ignore
197+ for callback in self .auth_callbacks :
198+ try :
199+ is_coroutine = getattr (callback , "is_coroutine" , False )
200+ if is_coroutine :
201+ result = await callback (request )
202+ else :
203+ result = callback (request )
204+ except Exception as exc :
205+ return self .api .on_exception (request , exc )
206+
207+ if result :
208+ request .auth = result # type: ignore
209+ return None
210+ return self .api .create_response (request , {"detail" : "Unauthorized" }, status = 401 )
211+
136212 async def run (self , request : HttpRequest , ** kw : Any ) -> HttpResponseBase : # type: ignore
137- error = self ._run_checks (request )
213+ error = await self ._run_checks (request )
214+ if error :
215+ return error
216+ try :
217+ values = await self ._get_values (request , kw ) # type: ignore
218+ result = await self .view_func (request , ** values )
219+ _processed_results = await self ._result_to_response (request , result ) # type: ignore
220+ return cast (HttpResponseBase , _processed_results )
221+ except Exception as e :
222+ return self .api .on_exception (request , e )
223+
224+
225+ class AsyncControllerOperation (AsyncOperation , ControllerOperation ):
226+ async def run (self , request : HttpRequest , ** kw : Any ) -> HttpResponseBase : # type: ignore
227+ error = await self ._run_checks (request )
138228 if error :
139229 return error
140230 try :
141231 with self ._prep_run (request , ** kw ) as ctx :
142- values , context = ctx
143- result = await self .view_func (context = context , ** values )
144- _processed_results = self ._result_to_response (request , result )
145- return _processed_results
232+ values = await self ._get_values (request , kw ) # type: ignore
233+ ctx .kwargs = values
234+ result = await self .view_func (context = ctx , ** values )
235+ _processed_results = await self ._result_to_response (request , result ) # type: ignore
236+ return cast (HttpResponseBase , _processed_results )
146237 except Exception as e :
147238 return self .api .on_exception (request , e )
148239
@@ -176,12 +267,7 @@ def add_operation(
176267 ) -> Operation :
177268 if url_name :
178269 self .url_name = url_name
179-
180- operation_class = Operation
181- if is_async (view_func ):
182- self .is_async = True
183- operation_class = AsyncOperation
184-
270+ operation_class = self .get_operation_class (view_func )
185271 operation = operation_class (
186272 path ,
187273 methods ,
@@ -203,3 +289,23 @@ def add_operation(
203289
204290 self .operations .append (operation )
205291 return operation
292+
293+ def get_operation_class (
294+ self , view_func : TCallable
295+ ) -> Type [Union [Operation , AsyncOperation ]]:
296+ operation_class = Operation
297+ if is_async (view_func ):
298+ self .is_async = True
299+ operation_class = AsyncOperation
300+ return operation_class
301+
302+
303+ class ControllerPathView (PathView ):
304+ def get_operation_class (
305+ self , view_func : TCallable
306+ ) -> Type [Union [Operation , AsyncOperation ]]:
307+ operation_class = ControllerOperation
308+ if is_async (view_func ):
309+ self .is_async = True
310+ operation_class = AsyncControllerOperation
311+ return operation_class
0 commit comments