Skip to content

Commit cd8a853

Browse files
committed
Throttle model refactor to support ninja v1.2.0
1 parent 30066d5 commit cd8a853

File tree

16 files changed

+404
-376
lines changed

16 files changed

+404
-376
lines changed

ninja_extra/conf/settings.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,12 @@ class Config:
4141
THROTTLE_RATES: Dict[str, Optional[str]] = Field(
4242
{"user": "1000/day", "anon": "100/day"}
4343
)
44-
THROTTLE_CLASSES: List[Any] = []
44+
THROTTLE_CLASSES: List[Any] = Field(
45+
[
46+
"ninja_extra.throttling.AnonRateThrottle",
47+
"ninja_extra.throttling.UserRateThrottle",
48+
]
49+
)
4550
NUM_PROXIES: Optional[int] = None
4651
INJECTOR_MODULES: List[Any] = []
4752
ORDERING_CLASS: Any = Field(

ninja_extra/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
TRACE = "TRACE"
1616
ROUTE_METHODS = [POST, PUT, PATCH, DELETE, GET, HEAD, OPTIONS, TRACE]
1717
THROTTLED_FUNCTION = "__throttled_endpoint__"
18+
THROTTLED_OBJECTS = "__throttled_objects__"
1819
ROUTE_FUNCTION = "__route_function__"
1920

2021
ROUTE_CONTEXT_VAR: contextvars.ContextVar[t.Optional["RouteContext"]] = (

ninja_extra/controllers/base.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@
2626
from django.urls import path as django_path
2727
from injector import inject, is_decorated_with_inject
2828
from ninja import NinjaAPI, Router
29-
from ninja.constants import NOT_SET
29+
from ninja.constants import NOT_SET, NOT_SET_TYPE
3030
from ninja.security.base import AuthBase
3131
from ninja.signature import is_async
32+
from ninja.throttling import BaseThrottle
3233
from ninja.utils import normalize_path
3334

34-
from ninja_extra.constants import ROUTE_FUNCTION, THROTTLED_FUNCTION
35+
from ninja_extra.constants import ROUTE_FUNCTION, THROTTLED_FUNCTION, THROTTLED_OBJECTS
3536
from ninja_extra.exceptions import APIException, NotFound, PermissionDenied, bad_request
3637
from ninja_extra.helper import get_function_name
3738
from ninja_extra.operation import Operation, PathView
@@ -51,7 +52,6 @@
5152

5253
if TYPE_CHECKING: # pragma: no cover
5354
from ninja_extra import NinjaExtraAPI
54-
from ninja_extra.throttling import BaseThrottle
5555

5656
from .route.context import RouteContext
5757

@@ -126,10 +126,10 @@ def some_method_name(self):
126126
throttling_classes: List[Type["BaseThrottle"]] = []
127127
throttling_init_kwargs: Optional[Dict[Any, Any]] = None
128128

129-
Ok = Ok
130-
Id = Id
131-
Detail = Detail
132-
bad_request = bad_request
129+
Ok = Ok # TODO: remove soonest
130+
Id = Id # TODO: remove soonest
131+
Detail = Detail # TODO: remove soonest
132+
bad_request = bad_request # TODO: remove soonest
133133

134134
@classmethod
135135
def get_api_controller(cls) -> "APIController":
@@ -294,6 +294,7 @@ def __init__(
294294
prefix: str,
295295
*,
296296
auth: Any = NOT_SET,
297+
throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET,
297298
tags: Union[Optional[List[str]], str] = None,
298299
permissions: Optional["PermissionType"] = None,
299300
auto_import: bool = True,
@@ -303,6 +304,7 @@ def __init__(
303304
self.auth: Optional[AuthBase] = auth
304305

305306
self.tags = tags # type: ignore
307+
self.throttle = throttle
306308

307309
self.auto_import: bool = auto_import # set to false and it would be ignored when api.auto_discover is called
308310
# `controller_class` target class that the APIController wraps
@@ -348,8 +350,6 @@ def tags(self, value: Union[str, List[str], None]) -> None:
348350
self._tags = tag
349351

350352
def __call__(self, cls: ControllerClassType) -> ControllerClassType:
351-
from ninja_extra.throttling import throttle
352-
353353
self.auto_import = getattr(cls, "auto_import", self.auto_import)
354354
if not issubclass(cls, ControllerBase):
355355
# We force the cls to inherit from `ControllerBase` by creating another type.
@@ -360,8 +360,15 @@ def __call__(self, cls: ControllerClassType) -> ControllerClassType:
360360
assert isinstance(
361361
cls.throttling_classes, (list, tuple)
362362
), f"Controller[{cls.__name__}].throttling_class must be a list or tuple"
363-
has_throttling_classes = len(cls.throttling_classes) > 0
364-
throttling_init_kwargs = cls.throttling_init_kwargs or {}
363+
364+
throttling_objects: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE]
365+
if cls.throttling_classes:
366+
throttling_init_kwargs = cls.throttling_init_kwargs or {}
367+
throttling_objects = [
368+
item(**throttling_init_kwargs) for item in cls.throttling_classes
369+
]
370+
else:
371+
throttling_objects = self.throttle
365372

366373
if not self.tags:
367374
tag = str(cls.__name__).lower().replace("controller", "")
@@ -386,10 +393,13 @@ def __call__(self, cls: ControllerClassType) -> ControllerClassType:
386393

387394
for _, v in self._controller_class_route_functions.items():
388395
throttled_endpoint = v.as_view.__dict__.get(THROTTLED_FUNCTION)
389-
if not throttled_endpoint and has_throttling_classes:
390-
v.route.view_func = throttle(
391-
*cls.throttling_classes, **throttling_init_kwargs
392-
)(v.route.view_func)
396+
397+
if throttled_endpoint or throttling_objects is not NOT_SET:
398+
v.route.route_params.throttle = v.as_view.__dict__.get(
399+
THROTTLED_OBJECTS, lambda: throttling_objects
400+
)()
401+
setattr(v.route.view_func, THROTTLED_FUNCTION, True)
402+
393403
self._add_operation_from_route_function(v)
394404

395405
if not is_decorated_with_inject(cls.__init__):
@@ -463,6 +473,7 @@ def add_api_operation(
463473
view_func: Callable,
464474
*,
465475
auth: Any = NOT_SET,
476+
throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET,
466477
response: Any = NOT_SET,
467478
operation_id: Optional[str] = None,
468479
summary: Optional[str] = None,
@@ -504,13 +515,14 @@ def add_api_operation(
504515
url_name=url_name,
505516
include_in_schema=include_in_schema,
506517
openapi_extra=openapi_extra,
518+
throttle=throttle,
507519
)
508520
return operation
509521

510522

511523
@overload
512524
def api_controller(
513-
prefix_or_class: Type[T],
525+
prefix_or_class: Union[ControllerClassType, Type[T]],
514526
) -> Union[Type[ControllerBase], Type[T]]: # pragma: no cover
515527
...
516528

@@ -522,12 +534,14 @@ def api_controller(
522534
tags: Union[Optional[List[str]], str] = None,
523535
permissions: Optional["PermissionType"] = None,
524536
auto_import: bool = True,
525-
) -> Callable[[Type[T]], Union[Type[ControllerBase], Type[T]]]: # pragma: no cover
537+
) -> Callable[
538+
[Union[Type, Type[T]]], Union[Type[ControllerBase], Type[T]]
539+
]: # pragma: no cover
526540
...
527541

528542

529543
def api_controller(
530-
prefix_or_class: Union[str, ControllerClassType] = "",
544+
prefix_or_class: Union[str, Union[ControllerClassType, Type]] = "",
531545
auth: Any = NOT_SET,
532546
tags: Union[Optional[List[str]], str] = None,
533547
permissions: Optional["PermissionType"] = None,

ninja_extra/controllers/route/__init__.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import typing as t
22

3-
from ninja.constants import NOT_SET
3+
from ninja.constants import NOT_SET, NOT_SET_TYPE
44
from ninja.signature import is_async
5+
from ninja.throttling import BaseThrottle
56
from ninja.types import TCallable
67

78
from ninja_extra.constants import (
@@ -52,6 +53,7 @@ def __init__(
5253
path: str,
5354
methods: t.List[str],
5455
auth: t.Any = NOT_SET,
56+
throttle: t.Union[BaseThrottle, t.List[BaseThrottle], NOT_SET_TYPE] = NOT_SET,
5557
response: t.Union[t.Any, t.List[t.Any]] = NOT_SET,
5658
operation_id: t.Optional[str] = None,
5759
summary: t.Optional[str] = None,
@@ -97,6 +99,7 @@ def __init__(
9799
path=path,
98100
methods=methods,
99101
auth=auth,
102+
throttle=throttle,
100103
response=_response,
101104
operation_id=operation_id,
102105
summary=summary,
@@ -124,6 +127,7 @@ def _create_route_function(
124127
path: str,
125128
methods: t.List[str],
126129
auth: t.Any = NOT_SET,
130+
throttle: t.Union[BaseThrottle, t.List[BaseThrottle], NOT_SET_TYPE] = NOT_SET,
127131
response: t.Union[t.Any, t.List[t.Any]] = NOT_SET,
128132
operation_id: t.Optional[str] = None,
129133
summary: t.Optional[str] = None,
@@ -164,6 +168,7 @@ def _create_route_function(
164168
include_in_schema=include_in_schema,
165169
permissions=permissions,
166170
openapi_extra=openapi_extra,
171+
throttle=throttle,
167172
)
168173
route_function_class = RouteFunction
169174
if route_obj.is_async:
@@ -178,6 +183,7 @@ def get(
178183
path: str = "",
179184
*,
180185
auth: t.Any = NOT_SET,
186+
throttle: t.Union[BaseThrottle, t.List[BaseThrottle], NOT_SET_TYPE] = NOT_SET,
181187
response: t.Union[t.Any, t.List[t.Any]] = NOT_SET,
182188
operation_id: t.Optional[str] = None,
183189
summary: t.Optional[str] = None,
@@ -242,6 +248,7 @@ def decorator(view_func: TCallable) -> TCallable:
242248
include_in_schema=include_in_schema,
243249
permissions=permissions,
244250
openapi_extra=openapi_extra,
251+
throttle=throttle,
245252
)
246253

247254
return decorator
@@ -252,6 +259,7 @@ def post(
252259
path: str = "",
253260
*,
254261
auth: t.Any = NOT_SET,
262+
throttle: t.Union[BaseThrottle, t.List[BaseThrottle], NOT_SET_TYPE] = NOT_SET,
255263
response: t.Union[t.Any, t.List[t.Any]] = NOT_SET,
256264
operation_id: t.Optional[str] = None,
257265
summary: t.Optional[str] = None,
@@ -316,6 +324,7 @@ def decorator(view_func: TCallable) -> TCallable:
316324
include_in_schema=include_in_schema,
317325
permissions=permissions,
318326
openapi_extra=openapi_extra,
327+
throttle=throttle,
319328
)
320329

321330
return decorator
@@ -326,6 +335,7 @@ def delete(
326335
path: str = "",
327336
*,
328337
auth: t.Any = NOT_SET,
338+
throttle: t.Union[BaseThrottle, t.List[BaseThrottle], NOT_SET_TYPE] = NOT_SET,
329339
response: t.Union[t.Any, t.List[t.Any]] = NOT_SET,
330340
operation_id: t.Optional[str] = None,
331341
summary: t.Optional[str] = None,
@@ -390,6 +400,7 @@ def decorator(view_func: TCallable) -> TCallable:
390400
include_in_schema=include_in_schema,
391401
permissions=permissions,
392402
openapi_extra=openapi_extra,
403+
throttle=throttle,
393404
)
394405

395406
return decorator
@@ -400,6 +411,7 @@ def patch(
400411
path: str = "",
401412
*,
402413
auth: t.Any = NOT_SET,
414+
throttle: t.Union[BaseThrottle, t.List[BaseThrottle], NOT_SET_TYPE] = NOT_SET,
403415
response: t.Union[t.Any, t.List[t.Any]] = NOT_SET,
404416
operation_id: t.Optional[str] = None,
405417
summary: t.Optional[str] = None,
@@ -465,6 +477,7 @@ def decorator(view_func: TCallable) -> TCallable:
465477
include_in_schema=include_in_schema,
466478
permissions=permissions,
467479
openapi_extra=openapi_extra,
480+
throttle=throttle,
468481
)
469482

470483
return decorator
@@ -475,6 +488,7 @@ def put(
475488
path: str = "",
476489
*,
477490
auth: t.Any = NOT_SET,
491+
throttle: t.Union[BaseThrottle, t.List[BaseThrottle], NOT_SET_TYPE] = NOT_SET,
478492
response: t.Union[t.Any, t.List[t.Any]] = NOT_SET,
479493
operation_id: t.Optional[str] = None,
480494
summary: t.Optional[str] = None,
@@ -540,6 +554,7 @@ def decorator(view_func: TCallable) -> TCallable:
540554
include_in_schema=include_in_schema,
541555
permissions=permissions,
542556
openapi_extra=openapi_extra,
557+
throttle=throttle,
543558
)
544559

545560
return decorator
@@ -551,6 +566,7 @@ def generic(
551566
*,
552567
methods: t.List[str],
553568
auth: t.Any = NOT_SET,
569+
throttle: t.Union[BaseThrottle, t.List[BaseThrottle], NOT_SET_TYPE] = NOT_SET,
554570
response: t.Union[t.Any, t.List[t.Any]] = NOT_SET,
555571
operation_id: t.Optional[str] = None,
556572
summary: t.Optional[str] = None,
@@ -617,6 +633,7 @@ def decorator(view_func: TCallable) -> TCallable:
617633
include_in_schema=include_in_schema,
618634
permissions=permissions,
619635
openapi_extra=openapi_extra,
636+
throttle=throttle,
620637
)
621638

622639
return decorator

ninja_extra/operation.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from django.http import HttpRequest
2020
from django.http.response import HttpResponse, HttpResponseBase
2121
from django.utils.encoding import force_str
22-
from ninja.constants import NOT_SET
22+
from ninja.constants import NOT_SET, NOT_SET_TYPE
2323
from ninja.errors import AuthenticationError
2424
from ninja.operation import (
2525
AsyncOperation as NinjaAsyncOperation,
@@ -31,12 +31,13 @@
3131
PathView as NinjaPathView,
3232
)
3333
from ninja.signature import is_async
34+
from ninja.throttling import BaseThrottle
3435
from ninja.types import TCallable
3536
from ninja.utils import check_csrf
3637

3738
from ninja_extra.compatible import asynccontextmanager
3839
from ninja_extra.constants import ROUTE_CONTEXT_VAR
39-
from ninja_extra.exceptions import APIException
40+
from ninja_extra.exceptions import APIException, Throttled
4041
from ninja_extra.helper import get_function_name
4142
from ninja_extra.logger import request_logger
4243

@@ -204,6 +205,22 @@ def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase:
204205
e.args = (msg,) + e.args[1:]
205206
return self.api.on_exception(request, e)
206207

208+
def _check_throttles(self, request: HttpRequest) -> Optional[HttpResponse]:
209+
throttle_durations = []
210+
for throttle in self.throttle_objects:
211+
if not throttle.allow_request(request):
212+
throttle_durations.append(throttle.wait())
213+
214+
if throttle_durations:
215+
# Filter out `None` values which may happen in case of config / rate
216+
durations = [
217+
duration for duration in throttle_durations if duration is not None
218+
]
219+
220+
duration = max(durations, default=None)
221+
raise Throttled(wait=duration)
222+
return None
223+
207224

208225
class AsyncOperation(Operation, NinjaAsyncOperation):
209226
def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -218,15 +235,21 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
218235

219236
async def _run_checks(self, request: HttpRequest) -> Optional[HttpResponse]: # type: ignore
220237
"""Runs security checks for each operation"""
238+
# csrf:
239+
if self.api.csrf:
240+
error = check_csrf(request, self.view_func)
241+
if error:
242+
return error
243+
221244
# auth:
222245
if self.auth_callbacks:
223-
error = await self._run_authentication(request)
246+
error = await self._run_authentication(request) # type: ignore[assignment]
224247
if error:
225248
return error
226249

227-
# csrf:
228-
if self.api.csrf:
229-
error = check_csrf(request, self.view_func)
250+
# Throttling:
251+
if self.throttle_objects:
252+
error = self._check_throttles(request) # type: ignore
230253
if error:
231254
return error
232255

@@ -295,7 +318,7 @@ async def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase: # typ
295318
result = await self.view_func(request, **values)
296319
_processed_results = await self._result_to_response(
297320
request, result, temporal_response
298-
) # type: ignore
321+
)
299322
return cast(HttpResponseBase, _processed_results)
300323
except Exception as e:
301324
return self.api.on_exception(request, e)
@@ -317,6 +340,7 @@ def add_operation(
317340
view_func: Callable,
318341
*,
319342
auth: Optional[Union[Sequence[Callable], Callable, object]] = NOT_SET,
343+
throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET,
320344
response: Any = NOT_SET,
321345
operation_id: Optional[str] = None,
322346
summary: Optional[str] = None,
@@ -352,6 +376,7 @@ def add_operation(
352376
include_in_schema=include_in_schema,
353377
url_name=url_name,
354378
openapi_extra=openapi_extra,
379+
throttle=throttle,
355380
)
356381

357382
self.operations.append(operation)

0 commit comments

Comments
 (0)