Skip to content

Commit 2312d0b

Browse files
committed
added support for endpoint return type as response type
1 parent b34c038 commit 2312d0b

File tree

8 files changed

+45
-15
lines changed

8 files changed

+45
-15
lines changed

.github/workflows/publish.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
- name: Set up Python
1414
uses: actions/setup-python@v4
1515
with:
16-
python-version: 3.6
16+
python-version: 3.8
1717
- name: Install Flit
1818
run: pip install flit
1919
- name: Install Dependencies
@@ -24,4 +24,4 @@ jobs:
2424
FLIT_PASSWORD: ${{ secrets.FLIT_PASSWORD }}
2525
run: flit publish
2626
- name: Deploy Documentation
27-
run: make doc-deploy
27+
run: make doc-deploy

ninja_extra/controllers/route/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import inspect
2-
from typing import Any, Callable, List, Optional, Type, Union, cast
2+
from typing import Any, Callable, List, Optional, Type, Union, cast, get_type_hints
33

44
from ninja.constants import NOT_SET
55
from ninja.signature import is_async
@@ -125,6 +125,8 @@ def _create_route_function(
125125
include_in_schema: bool = True,
126126
permissions: Optional[List[Type[BasePermission]]] = None,
127127
) -> RouteFunction:
128+
if response is NOT_SET:
129+
response = get_type_hints(view_func).get("return") or NOT_SET
128130
route_obj = cls(
129131
view_func,
130132
path=path,

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ classifiers = [
3535
"Framework :: Django :: 3.1",
3636
"Framework :: Django :: 3.2",
3737
"Framework :: Django :: 4.0",
38+
"Framework :: Django :: 4.1",
3839
"Framework :: AsyncIO",
3940
"Topic :: Internet :: WWW/HTTP :: HTTP Servers",
4041
"Topic :: Internet :: WWW/HTTP",
@@ -72,7 +73,8 @@ doc = [
7273
"mkdocs-material >=7.1.9,<8.0.0",
7374
"mdx-include >=1.4.1,<2.0.0",
7475
"mkdocs-markdownextradata-plugin >=0.1.7,<0.3.0",
75-
"markdown-include"
76+
"markdown-include",
77+
"mkdocstrings"
7678
]
7779

7880
dev = [

tests/controllers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,12 @@ def list_events(self):
4545
def list_events_example_2(self):
4646
return list(Event.objects.all())
4747

48-
@http_get("/{int:id}", response=EventSchema)
49-
def get_event(self, id: int):
48+
@http_get("/{int:id}")
49+
def get_event(self, id: int) -> EventSchema:
5050
event = get_object_or_404(Event, id=id)
5151
return event
52+
53+
@http_get("/{int:id}/from-orm")
54+
def get_event_from_orm(self, id: int) -> EventSchema:
55+
event = get_object_or_404(Event, id=id)
56+
return EventSchema.from_orm(event)

tests/schemas.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from pydantic import BaseModel
2+
3+
4+
class UserSchema(BaseModel):
5+
name: str
6+
age: int

tests/test_controller.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import django
55
import pytest
66
from django.contrib.auth.models import Group
7-
from pydantic import UUID4, BaseModel
87

98
from ninja_extra import (
109
NinjaExtraAPI,
@@ -19,14 +18,10 @@
1918
from ninja_extra.controllers.response import Detail, Id, Ok
2019
from ninja_extra.permissions.common import AllowAny
2120

21+
from .schemas import UserSchema
2222
from .utils import AsyncFakeAuth, FakeAuth
2323

2424

25-
class UserSchema(BaseModel):
26-
name: str
27-
age: int
28-
29-
3025
@api_controller
3126
class SomeController:
3227
pass

tests/test_event_controller.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,20 @@ def test_list_events_works(self):
4949
]
5050
assert event_schema == data
5151

52-
def test_get_event_works(self):
52+
@pytest.mark.parametrize(
53+
"path",
54+
[
55+
"/{event_id}",
56+
"/{event_id}/from-orm",
57+
],
58+
)
59+
def test_get_event_works(self, path):
5360
object_data = self.dummy_data.copy()
5461
object_data.update(title=f"{object_data['title']}_get")
5562

5663
event = Event.objects.create(**object_data)
5764
client = TestClient(EventController)
58-
response = client.get(f"/{event.id}")
65+
response = client.get(path.format(event_id=event.id))
5966
assert response.status_code == 200
6067
data = response.json()
6168
event_schema = json.loads(EventSchema.from_orm(event).json())

tests/test_route.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
import pytest
55
from django.contrib.auth.models import AnonymousUser, User
66
from ninja import Schema
7+
from ninja.constants import NOT_SET
78

89
from ninja_extra import api_controller, permissions, route
910
from ninja_extra.controllers import (
1011
AsyncRouteFunction,
1112
Detail,
1213
Id,
1314
Ok,
14-
Route,
1515
RouteFunction,
1616
RouteInvalidParameterException,
1717
)
@@ -20,6 +20,7 @@
2020
from ninja_extra.exceptions import PermissionDenied
2121
from ninja_extra.permissions import AllowAny
2222

23+
from .schemas import UserSchema
2324
from .utils import FakeAuth
2425

2526
anonymous_request = Mock()
@@ -71,6 +72,10 @@ def example_list_create(self, ex_id: str):
7172
def example_post_operation_id(self):
7273
pass
7374

75+
@route.get("/example/return-response-as-schema")
76+
def function_return_as_response_schema(self) -> UserSchema:
77+
pass
78+
7479

7580
class TestControllerRoute:
7681
@pytest.mark.parametrize(
@@ -114,6 +119,14 @@ def test_controller_route_should_right_view_func_type(self):
114119
== SomeTestController.example
115120
)
116121

122+
def test_controller_route_should_use_userschema_as_response(self):
123+
route_function: RouteFunction = SomeTestController.example
124+
assert route_function.route.route_params.response == NOT_SET
125+
route_function: RouteFunction = (
126+
SomeTestController.function_return_as_response_schema
127+
)
128+
assert route_function.route.route_params.response == UserSchema
129+
117130
def test_route_generic_invalid_parameters(self):
118131
with pytest.raises(RouteInvalidParameterException) as ex:
119132

0 commit comments

Comments
 (0)