Skip to content

Commit d693541

Browse files
committed
Adapter paginator and pagination operations to both controllers route functions and functional routes
1 parent bf48a4f commit d693541

File tree

5 files changed

+121
-55
lines changed

5 files changed

+121
-55
lines changed

ninja_extra/ordering.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
from asgiref.sync import sync_to_async
2020
from django.db.models import QuerySet
21-
from ninja import Field, Query, Schema
21+
from django.http import HttpRequest
22+
from ninja import Field, P, Query, Schema
2223
from ninja.constants import NOT_SET
2324
from ninja.signature import is_async
2425
from pydantic import BaseModel
@@ -65,14 +66,24 @@ def __init__(
6566
) -> None:
6667
super().__init__(pass_parameter=pass_parameter)
6768
self.ordering_fields = ordering_fields or "__all__"
69+
self.Input = self.create_input(ordering_fields) # type:ignore
70+
71+
def create_input(self, ordering_fields: Optional[List[str]]) -> Type[Input]:
72+
if ordering_fields:
73+
74+
class DynamicInput(Ordering.Input):
75+
ordering: Query[Optional[str], P(default=",".join(ordering_fields))] # type:ignore[type-arg,valid-type]
76+
77+
return DynamicInput
78+
return Ordering.Input
6879

6980
def ordering_queryset(
7081
self, items: Union[QuerySet, List], ordering_input: Input
7182
) -> Union[QuerySet, List]:
72-
ordering = self.get_ordering(items, ordering_input.ordering)
73-
if ordering:
83+
ordering_ = self.get_ordering(items, ordering_input.ordering)
84+
if ordering_:
7485
if isinstance(items, QuerySet): # type:ignore
75-
return items.order_by(*ordering)
86+
return items.order_by(*ordering_)
7687
elif isinstance(items, list) and items:
7788

7889
def multisort(xs: List, specs: List[Tuple[str, bool]]) -> List:
@@ -85,7 +96,7 @@ def multisort(xs: List, specs: List[Tuple[str, bool]]) -> List:
8596
items,
8697
[
8798
(o[int(o.startswith("-")) :], o.startswith("-"))
88-
for o in ordering
99+
for o in ordering_
89100
],
90101
)
91102
return items
@@ -201,49 +212,51 @@ def __init__(
201212
self.view_func = view_func
202213

203214
orderator_view = self.get_view_function()
204-
_ninja_contribute_args: List[Tuple] = getattr(
205-
self.view_func, "_ninja_contribute_args", []
206-
)
207-
orderator_view._ninja_contribute_args = ( # type:ignore[attr-defined]
208-
_ninja_contribute_args
209-
)
215+
self.as_view = wraps(view_func)(orderator_view)
210216
add_ninja_contribute_args(
211-
orderator_view,
217+
self.as_view,
212218
(
213219
self.orderator_kwargs_name,
214220
self.orderator.Input,
215221
self.orderator.InputSource,
216222
),
217223
)
218224
orderator_view.orderator_operation = self # type:ignore[attr-defined]
219-
self.as_view = wraps(view_func)(orderator_view)
220225

221226
@property
222227
def view_func_has_kwargs(self) -> bool: # pragma: no cover
223228
return self.orderator.pass_parameter is not None
224229

225230
def get_view_function(self) -> Callable:
226-
def as_view(controller: "ControllerBase", *args: Any, **kw: Any) -> Any:
231+
def as_view(
232+
request_or_controller: Union["ControllerBase", HttpRequest],
233+
*args: Any,
234+
**kw: Any,
235+
) -> Any:
227236
func_kwargs = dict(**kw)
228237
ordering_params = func_kwargs.pop(self.orderator_kwargs_name)
229238
if self.orderator.pass_parameter:
230239
func_kwargs[self.orderator.pass_parameter] = ordering_params
231240

232-
items = self.view_func(controller, *args, **func_kwargs)
241+
items = self.view_func(request_or_controller, *args, **func_kwargs)
233242
return self.orderator.ordering_queryset(items, ordering_params)
234243

235244
return as_view
236245

237246

238247
class AsyncOrderatorOperation(OrderatorOperation):
239248
def get_view_function(self) -> Callable:
240-
async def as_view(controller: "ControllerBase", *args: Any, **kw: Any) -> Any:
249+
async def as_view(
250+
request_or_controller: Union["ControllerBase", HttpRequest],
251+
*args: Any,
252+
**kw: Any,
253+
) -> Any:
241254
func_kwargs = dict(**kw)
242255
ordering_params = func_kwargs.pop(self.orderator_kwargs_name)
243256
if self.orderator.pass_parameter:
244257
func_kwargs[self.orderator.pass_parameter] = ordering_params
245258

246-
items = await self.view_func(controller, *args, **func_kwargs)
259+
items = await self.view_func(request_or_controller, *args, **func_kwargs)
247260
ordering_queryset = cast(
248261
Callable, sync_to_async(self.orderator.ordering_queryset)
249262
)

ninja_extra/pagination/operations.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
TYPE_CHECKING,
44
Any,
55
Callable,
6-
List,
7-
Tuple,
6+
Union,
87
cast,
98
)
109

1110
from asgiref.sync import sync_to_async
11+
from django.http import HttpRequest
1212
from ninja.pagination import PaginationBase
1313

14+
from ninja_extra.controllers.route.context import RouteContext
1415
from ninja_extra.shortcuts import add_ninja_contribute_args
1516

1617
if TYPE_CHECKING: # pragma: no cover
@@ -30,60 +31,69 @@ def __init__(
3031
self.view_func = view_func
3132

3233
paginator_view = self.get_view_function()
33-
_ninja_contribute_args: List[Tuple] = getattr(
34-
self.view_func, "_ninja_contribute_args", []
35-
)
36-
paginator_view._ninja_contribute_args = ( # type:ignore[attr-defined]
37-
_ninja_contribute_args
38-
)
34+
self.as_view = wraps(view_func)(paginator_view)
3935
add_ninja_contribute_args(
40-
paginator_view,
36+
self.as_view,
4137
(
4238
self.paginator_kwargs_name,
4339
self.paginator.Input,
4440
self.paginator.InputSource,
4541
),
4642
)
4743
paginator_view.paginator_operation = self # type:ignore[attr-defined]
48-
self.as_view = wraps(view_func)(paginator_view)
4944

5045
@property
5146
def view_func_has_kwargs(self) -> bool: # pragma: no cover
5247
return self.paginator.pass_parameter is not None
5348

5449
def get_view_function(self) -> Callable:
55-
def as_view(controller: "ControllerBase", *args: Any, **kw: Any) -> Any:
50+
def as_view(
51+
request_or_controller: Union["ControllerBase", HttpRequest],
52+
*args: Any,
53+
**kw: Any,
54+
) -> Any:
5655
func_kwargs = dict(**kw)
5756
pagination_params = func_kwargs.pop(self.paginator_kwargs_name)
5857
if self.paginator.pass_parameter:
5958
func_kwargs[self.paginator.pass_parameter] = pagination_params
6059

61-
items = self.view_func(controller, *args, **func_kwargs)
62-
assert (
63-
controller.context and controller.context.request
64-
), "Request object is None"
60+
items = self.view_func(request_or_controller, *args, **func_kwargs)
61+
if hasattr(request_or_controller, "context") and isinstance(
62+
request_or_controller.context, RouteContext
63+
):
64+
request = request_or_controller.context.request
65+
assert request, "Request object is None"
66+
else:
67+
request = request_or_controller
6568
params = dict(kw)
66-
params["request"] = controller.context.request
69+
params["request"] = request
6770
return self.paginator.paginate_queryset(items, **params)
6871

6972
return as_view
7073

7174

7275
class AsyncPaginatorOperation(PaginatorOperation):
7376
def get_view_function(self) -> Callable:
74-
async def as_view(controller: "ControllerBase", *args: Any, **kw: Any) -> Any:
77+
async def as_view(
78+
request_or_controller: Union["ControllerBase", HttpRequest],
79+
*args: Any,
80+
**kw: Any,
81+
) -> Any:
7582
func_kwargs = dict(**kw)
7683
pagination_params = func_kwargs.pop(self.paginator_kwargs_name)
7784
if self.paginator.pass_parameter:
7885
func_kwargs[self.paginator.pass_parameter] = pagination_params
7986

80-
items = await self.view_func(controller, *args, **func_kwargs)
81-
assert (
82-
controller.context and controller.context.request
83-
), "Request object is None"
84-
87+
items = await self.view_func(request_or_controller, *args, **func_kwargs)
88+
if hasattr(request_or_controller, "context") and isinstance(
89+
request_or_controller.context, RouteContext
90+
):
91+
request = request_or_controller.context.request
92+
assert request, "Request object is None"
93+
else:
94+
request = request_or_controller
8595
params = dict(kw)
86-
params["request"] = controller.context.request
96+
params["request"] = request
8797
paginate_queryset = cast(
8898
Callable, sync_to_async(self.paginator.paginate_queryset)
8999
)

ninja_extra/searching.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ def __init__(
9191
search_fields: Optional[List[str]] = None,
9292
pass_parameter: Optional[str] = None,
9393
) -> None:
94+
super().__init__(pass_parameter=pass_parameter)
9495
self.search_fields = search_fields or []
95-
self.pass_parameter = pass_parameter
9696

9797
def searching_queryset(
9898
self, items: Union[QuerySet, List], searching_input: Input
@@ -240,22 +240,16 @@ def __init__(
240240
self.view_func = view_func
241241

242242
searcherator_view = self.get_view_function()
243-
_ninja_contribute_args: List[Tuple] = getattr(
244-
self.view_func, "_ninja_contribute_args", []
245-
)
246-
searcherator_view._ninja_contribute_args = ( # type:ignore[attr-defined]
247-
_ninja_contribute_args
248-
)
243+
self.as_view = wraps(view_func)(searcherator_view)
249244
add_ninja_contribute_args(
250-
searcherator_view,
245+
self.as_view,
251246
(
252247
self.searcherator_kwargs_name,
253248
self.searcherator.Input,
254249
self.searcherator.InputSource,
255250
),
256251
)
257252
searcherator_view.searcherator_operation = self # type:ignore[attr-defined]
258-
self.as_view = wraps(view_func)(searcherator_view)
259253

260254
@property
261255
def view_func_has_kwargs(self) -> bool: # pragma: no cover

tests/test_ordering.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,16 +189,17 @@ def test_case4(self):
189189
assert response[0]["title"] == "title_2"
190190

191191
schema = api.get_openapi_schema()["paths"]["/api/items_4"]["get"]
192-
# print(schema)
192+
print(schema["parameters"])
193193
assert schema["parameters"] == [
194194
{
195195
"in": "query",
196196
"name": "ordering",
197-
"required": False,
198197
"schema": {
199198
"anyOf": [{"type": "string"}, {"type": "null"}],
199+
"default": "title",
200200
"title": "Ordering",
201201
},
202+
"required": False,
202203
}
203204
]
204205

@@ -382,11 +383,12 @@ async def test_case4(self):
382383
{
383384
"in": "query",
384385
"name": "ordering",
385-
"required": False,
386386
"schema": {
387387
"anyOf": [{"type": "string"}, {"type": "null"}],
388+
"default": "title",
388389
"title": "Ordering",
389390
},
391+
"required": False,
390392
}
391393
]
392394

tests/test_pagination.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import django
55
import pytest
6-
from ninja import Schema
6+
from ninja import NinjaAPI, Schema
77

88
from ninja_extra import NinjaExtraAPI, api_controller, route
99
from ninja_extra.controllers import RouteFunction
@@ -385,5 +385,52 @@ async def test_case5(self):
385385
assert data["items"] == ITEMS[10:20]
386386

387387

388-
def test_pagination_extra_get_schema():
389-
pass
388+
def test_pagination_extra_with_ninja_api():
389+
app = NinjaAPI()
390+
391+
@app.get("/items_2", response=NinjaPaginationResponseSchema[int])
392+
@paginate() # with brackets (should use default pagination)
393+
def items_2(request, someparam: int = 0):
394+
# also having custom param `someparam` - that should not be lost
395+
return FakeQuerySet()
396+
397+
@app.get("/items_3")
398+
@paginate(CustomPagination, pass_parameter="pass_kwargs")
399+
def items_3(request, **kwargs):
400+
return ITEMS
401+
402+
_client = TestClient(app)
403+
404+
response = _client.get("/items_3?skip=5")
405+
assert response.json() == ITEMS[5:10]
406+
407+
response = _client.get("/items_2?limit=10").json()
408+
assert response.get("items")
409+
assert response["items"] == ITEMS[:10]
410+
411+
412+
@pytest.mark.skipif(django.VERSION < (3, 1), reason="requires django 3.1 or higher")
413+
@pytest.mark.asyncio
414+
async def test_pagination_extra_with_ninja_api_async():
415+
app = NinjaAPI()
416+
417+
@app.get("/items_2", response=NinjaPaginationResponseSchema[int])
418+
@paginate() # with brackets (should use default pagination)
419+
async def items_2(request, someparam: int = 0):
420+
# also having custom param `someparam` - that should not be lost
421+
return FakeQuerySet()
422+
423+
@app.get("/items_3")
424+
@paginate(CustomPagination, pass_parameter="pass_kwargs")
425+
async def items_3(request, **kwargs):
426+
return ITEMS
427+
428+
_client = TestAsyncClient(app)
429+
430+
response = await _client.get("/items_3?skip=5")
431+
assert response.json() == ITEMS[5:10]
432+
433+
response = await _client.get("/items_2?limit=10")
434+
result = response.json()
435+
assert result.get("items")
436+
assert result["items"] == ITEMS[:10]

0 commit comments

Comments
 (0)