Skip to content

Commit e7b761a

Browse files
committed
added get_object with check_object_permission support
1 parent 3e99ca2 commit e7b761a

File tree

6 files changed

+123
-33
lines changed

6 files changed

+123
-33
lines changed

ninja_extra/controllers/base.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,35 @@
44
Any,
55
Callable,
66
Dict,
7-
Iterator,
7+
Iterable,
88
List,
99
Optional,
1010
Type,
11+
Union,
1112
cast,
1213
no_type_check,
1314
)
1415

16+
from django.db.models import Model, QuerySet
17+
from django.http import HttpResponse
1518
from injector import inject, is_decorated_with_inject
1619
from ninja import NinjaAPI
1720
from ninja.constants import NOT_SET
1821
from ninja.operation import Operation
1922
from ninja.security.base import AuthBase
2023
from ninja.types import DictStrAny
2124

22-
from ninja_extra.exceptions import PermissionDenied
25+
from ninja_extra.exceptions import APIException, NotFound, PermissionDenied
2326
from ninja_extra.operation import PathView
2427
from ninja_extra.permissions import AllowAny, BasePermission
25-
from ninja_extra.shortcuts import fail_silently
28+
from ninja_extra.shortcuts import (
29+
fail_silently,
30+
get_object_or_exception,
31+
get_object_or_none,
32+
)
2633
from ninja_extra.types import PermissionType
2734

28-
from .response import ControllerResponse, Detail, Id, Ok
35+
from .response import Detail, Id, Ok
2936
from .route.route_functions import RouteFunction
3037
from .router import ControllerRouter
3138

@@ -171,7 +178,7 @@ def add_api_operation(
171178
return operation
172179

173180
@classmethod
174-
def get_route_functions(cls) -> Iterator[RouteFunction]:
181+
def get_route_functions(cls) -> Iterable[RouteFunction]:
175182
for method in cls.__dict__.values():
176183
if isinstance(method, RouteFunction):
177184
yield method
@@ -181,7 +188,28 @@ def permission_denied(cls, permission: BasePermission) -> None:
181188
message = getattr(permission, "message", None)
182189
raise PermissionDenied(message)
183190

184-
def get_permissions(self) -> Iterator[BasePermission]:
191+
def get_object_or_exception(
192+
self,
193+
klass: Union[Type[Model], QuerySet],
194+
error_message: str = None,
195+
exception: Type[APIException] = NotFound,
196+
**kwargs: Any,
197+
) -> Any:
198+
obj = get_object_or_exception(
199+
klass=klass, error_message=error_message, exception=exception, **kwargs
200+
)
201+
self.check_object_permissions(obj)
202+
return obj
203+
204+
def get_object_or_none(
205+
self, klass: Union[Type[Model], QuerySet], **kwargs: Any
206+
) -> Optional[Any]:
207+
obj = get_object_or_none(klass=klass, **kwargs)
208+
if obj:
209+
self.check_object_permissions(obj)
210+
return obj
211+
212+
def _get_permissions(self) -> Iterable[BasePermission]:
185213
"""
186214
Instantiates and returns the list of permissions that this view requires.
187215
"""
@@ -197,7 +225,7 @@ def check_permissions(self) -> None:
197225
Check if the request should be permitted.
198226
Raises an appropriate exception if the request is not permitted.
199227
"""
200-
for permission in self.get_permissions():
228+
for permission in self._get_permissions():
201229
if (
202230
self.context
203231
and self.context.request
@@ -207,12 +235,12 @@ def check_permissions(self) -> None:
207235
):
208236
self.permission_denied(permission)
209237

210-
def check_object_permissions(self, obj: Any) -> None:
238+
def check_object_permissions(self, obj: Union[Any, Model]) -> None:
211239
"""
212240
Check if the request should be permitted for a given object.
213241
Raises an appropriate exception if the request is not permitted.
214242
"""
215-
for permission in self.get_permissions():
243+
for permission in self._get_permissions():
216244
if (
217245
self.context
218246
and self.context.request
@@ -222,7 +250,11 @@ def check_object_permissions(self, obj: Any) -> None:
222250
):
223251
self.permission_denied(permission)
224252

225-
def create_response(
226-
self, message: Any, status_code: int = 200
227-
) -> ControllerResponse:
228-
return self.Detail(message=message, status_code=status_code)
253+
def create_response(self, message: Any, status_code: int = 200) -> HttpResponse:
254+
response = self.Detail(message=message, status_code=status_code)
255+
assert self.context and self.context.request and self.api
256+
return self.api.create_response(
257+
request=self.context.request,
258+
data=response.convert_to_schema().dict(),
259+
status=response.status_code,
260+
)

ninja_extra/controllers/router.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,10 @@ def urls_paths(self, prefix: str) -> Iterator[URLPattern]:
129129
# to skip lot of checks we simply treat double slash as a mistake:
130130
route = normalize_path(route)
131131
route = route.lstrip("/")
132-
133-
yield django_path(
134-
route, path_view.get_view(), name=cast(str, path_view.url_name)
135-
)
132+
for op in path_view.operations:
133+
yield django_path(
134+
route, path_view.get_view(), name=cast(str, op.url_name)
135+
)
136136

137137
def __repr__(self) -> str:
138138
return f"<controller - {self._controller.__name__}>"

tests/controllers.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,27 @@ class EventSchemaOut(Schema):
2424

2525
@router("events")
2626
class EventController(APIController):
27-
@route.post(
28-
"/create", url_name="event-create-url-name", response={201: EventSchemaOut}
29-
)
27+
@route.post("", url_name="event-create-url-name", response={201: EventSchemaOut})
3028
def create_event(self, event: EventSchema):
3129
event = Event.objects.create(**event.dict())
3230
return 201, event
3331

34-
@route.get("", response=List[EventSchema])
32+
@route.get(
33+
"",
34+
response=List[EventSchema],
35+
url_name="event-list",
36+
)
3537
def list_events(self):
3638
return list(Event.objects.all())
3739

40+
@route.get(
41+
"/list",
42+
response=List[EventSchema],
43+
url_name="event-list-2",
44+
)
45+
def list_events_example_2(self):
46+
return list(Event.objects.all())
47+
3848
@route.get("/{int:id}", response=EventSchema)
3949
def get_event(self, id: int):
4050
event = get_object_or_404(Event, id=id)

tests/test_controller.py

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
from typing import Any, Type, Union
2-
from unittest.mock import Mock
1+
from unittest.mock import Mock, patch
32

43
import pytest
5-
from ninja import Schema
4+
from django.contrib.auth.models import Group
65

7-
from ninja_extra import APIController, NinjaExtraAPI, route, router
8-
from ninja_extra.controllers import RouteFunction
6+
from ninja_extra import APIController, NinjaExtraAPI, exceptions, route, router, testing
7+
from ninja_extra.controllers import RouteContext, RouteFunction
98
from ninja_extra.controllers.base import MissingRouterDecoratorException
109
from ninja_extra.controllers.response import Detail, Id, Ok
1110
from ninja_extra.controllers.router import ControllerRouter
@@ -98,6 +97,52 @@ def test_controller_route_definition_should_return_instance_route_definitions(se
9897
for route_definition in SomeControllerWithRoute.get_route_functions():
9998
assert isinstance(route_definition, RouteFunction)
10099

100+
@pytest.mark.django_db
101+
def test_controller_get_object_or_exception_works(self):
102+
group_instance = Group.objects.create(name="_groupowner")
103+
104+
controller_object = SomeController()
105+
context = RouteContext(request=Mock(), permission_classes=[AllowAny])
106+
controller_object.context = context
107+
with patch.object(
108+
AllowAny, "has_object_permission", return_value=True
109+
) as c_cop:
110+
group = controller_object.get_object_or_exception(
111+
Group, id=group_instance.id
112+
)
113+
c_cop.assert_called()
114+
assert group == group_instance
115+
116+
with pytest.raises(Exception) as ex:
117+
controller_object.get_object_or_exception(Group, id=1000)
118+
assert isinstance(ex, exceptions.NotFound)
119+
120+
with pytest.raises(Exception) as ex:
121+
with patch.object(AllowAny, "has_object_permission", return_value=False):
122+
controller_object.get_object_or_exception(Group, id=group_instance.id)
123+
assert isinstance(ex, exceptions.PermissionDenied)
124+
125+
@pytest.mark.django_db
126+
def test_controller_get_object_or_none_works(self):
127+
group_instance = Group.objects.create(name="_groupowner2")
128+
129+
controller_object = SomeController()
130+
context = RouteContext(request=Mock(), permission_classes=[AllowAny])
131+
controller_object.context = context
132+
with patch.object(
133+
AllowAny, "has_object_permission", return_value=True
134+
) as c_cop:
135+
group = controller_object.get_object_or_none(Group, id=group_instance.id)
136+
c_cop.assert_called()
137+
assert group == group_instance
138+
139+
assert controller_object.get_object_or_none(Group, id=1000) is None
140+
141+
with pytest.raises(Exception) as ex:
142+
with patch.object(AllowAny, "has_object_permission", return_value=False):
143+
controller_object.get_object_or_none(Group, id=group_instance.id)
144+
assert isinstance(ex, exceptions.PermissionDenied)
145+
101146

102147
class TestAPIControllerResponse:
103148
ok_response = Ok("OK")
@@ -122,10 +167,11 @@ def test_controller_response(self):
122167

123168
def test_controller_response_works(self):
124169
detail = Detail("5242", status_code=302)
125-
result = SomeControllerWithRouter.example2(request=Mock(), ex_id="5242")
126-
assert isinstance(result, tuple)
127-
assert result[1] == detail.convert_to_schema()
128-
assert result[0] == detail.status_code
170+
client = testing.TestClient(SomeControllerWithRouter)
171+
response = client.get("/example/5242")
172+
173+
assert response.status_code == 302
174+
assert detail.convert_to_schema().dict() == response.json()
129175

130176
ok_response = Ok("5242")
131177
result = SomeControllerWithRouter.example_with_ok_response(

tests/test_django_model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def test_with_client(client: Client):
1111

1212
test_item = {"start_date": "2020-01-01", "end_date": "2020-01-02", "title": "test"}
1313

14-
response = client.post("/api/events/create", **json_payload(test_item))
14+
response = client.post("/api/events", **json_payload(test_item))
1515
assert response.status_code == 201
1616
assert Event.objects.count() == 1
1717

@@ -26,7 +26,9 @@ def test_with_client(client: Client):
2626

2727
def test_reverse():
2828
# check that url reversing works
29-
assert reverse("api-1.0.0:event-create-url-name") == "/api/events/create"
29+
assert reverse("api-1.0.0:event-create-url-name") == "/api/events"
30+
assert reverse("api-1.0.0:event-list") == "/api/events"
31+
assert reverse("api-1.0.0:event-list-2") == "/api/events/list"
3032

3133

3234
def json_payload(data):

tests/test_event_controller.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class TestEventController:
2020
def test_create_event_works(self):
2121
client = TestClient(EventController)
2222
response = client.post(
23-
"/create",
23+
"",
2424
json=dict(
2525
title="TestEvent1Title",
2626
start_date=str(datetime.now().date()),

0 commit comments

Comments
 (0)