|
4 | 4 | from typing import ( |
5 | 5 | TYPE_CHECKING, |
6 | 6 | Any, |
| 7 | + AsyncIterator, |
7 | 8 | Callable, |
8 | 9 | Iterator, |
9 | 10 | List, |
|
27 | 28 | from ninja.types import TCallable |
28 | 29 | from ninja.utils import check_csrf |
29 | 30 |
|
| 31 | +from ninja_extra.compatible import asynccontextmanager |
30 | 32 | from ninja_extra.exceptions import APIException |
31 | 33 | from ninja_extra.helper import get_function_name |
32 | 34 | from ninja_extra.logger import request_logger |
@@ -119,7 +121,7 @@ def get_execution_context( |
119 | 121 | return route_function.get_route_execution_context(request, *args, **kwargs) |
120 | 122 |
|
121 | 123 | @contextmanager |
122 | | - def _prep_run(self, request: HttpRequest, **kw: Any) -> Iterator: |
| 124 | + def _prep_run(self, request: HttpRequest, **kw: Any) -> Iterator[RouteContext]: |
123 | 125 | try: |
124 | 126 | start_time = time.time() |
125 | 127 | context = self.get_execution_context(request, **kw) |
@@ -223,12 +225,43 @@ async def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase: # typ |
223 | 225 |
|
224 | 226 |
|
225 | 227 | class AsyncControllerOperation(AsyncOperation, ControllerOperation): |
| 228 | + @asynccontextmanager |
| 229 | + async def _prep_run( # type:ignore |
| 230 | + self, request: HttpRequest, **kw: Any |
| 231 | + ) -> AsyncIterator[RouteContext]: |
| 232 | + try: |
| 233 | + start_time = time.time() |
| 234 | + context = self.get_execution_context(request, **kw) |
| 235 | + # send route_context_started signal |
| 236 | + route_context_started.send(RouteContext, route_context=context) |
| 237 | + |
| 238 | + yield context |
| 239 | + self._log_action( |
| 240 | + request_logger.info, |
| 241 | + request=request, |
| 242 | + duration=time.time() - start_time, |
| 243 | + extra=dict(request=request), |
| 244 | + exc_info=None, |
| 245 | + ) |
| 246 | + except Exception as e: |
| 247 | + self._log_action( |
| 248 | + request_logger.error, |
| 249 | + request=request, |
| 250 | + ex=e, |
| 251 | + extra=dict(request=request), |
| 252 | + exc_info=None, |
| 253 | + ) |
| 254 | + raise e |
| 255 | + finally: |
| 256 | + # send route_context_finished signal |
| 257 | + route_context_finished.send(RouteContext, route_context=None) |
| 258 | + |
226 | 259 | async def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase: # type: ignore |
227 | 260 | error = await self._run_checks(request) |
228 | 261 | if error: |
229 | 262 | return error |
230 | 263 | try: |
231 | | - with self._prep_run(request, **kw) as ctx: |
| 264 | + async with self._prep_run(request, **kw) as ctx: |
232 | 265 | values = await self._get_values(request, kw) # type: ignore |
233 | 266 | ctx.kwargs = values |
234 | 267 | result = await self.view_func(context=ctx, **values) |
|
0 commit comments