Skip to content

Commit 119a214

Browse files
Add support for query arguments (#107)
1 parent 6ae4c43 commit 119a214

File tree

5 files changed

+257
-44
lines changed

5 files changed

+257
-44
lines changed

aiohttp_apischema/generator.py

Lines changed: 124 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
import inspect
33
import sys
44
from collections.abc import Awaitable, Callable, Mapping
5+
from functools import partial
56
from http import HTTPStatus
67
from pathlib import Path
78
from types import UnionType
8-
from typing import Any, Iterable, Literal, TypedDict, TypeGuard, TypeVar, cast, get_args, get_origin
9+
from typing import (Any, Annotated, Concatenate, Generic, Iterable, Literal, ParamSpec,
10+
Protocol, TypeGuard, TypeVar, cast, get_args, get_origin, get_type_hints)
911

1012
from aiohttp import web
1113
from aiohttp.hdrs import METH_ALL
@@ -14,21 +16,42 @@
1416

1517
from aiohttp_apischema.response import APIResponse
1618

19+
if sys.version_info >= (3, 12):
20+
from typing import TypedDict
21+
else:
22+
from typing_extensions import TypedDict
23+
1724
if sys.version_info >= (3, 11):
18-
from typing import Required
25+
from typing import NotRequired, Required
1926
else:
20-
from typing_extensions import Required
27+
from typing_extensions import NotRequired, Required
2128

2229
OPENAPI_METHODS = frozenset({"get", "put", "post", "delete", "options", "head", "patch", "trace"})
2330

2431
_T = TypeVar("_T")
25-
_Resp = TypeVar("_Resp", bound=APIResponse[Any, Any])
32+
_U = TypeVar("_U")
33+
_P = ParamSpec("_P")
34+
_Resp = TypeVar("_Resp", bound=APIResponse[Any, Any], covariant=True)
2635
_View = TypeVar("_View", bound=web.View)
36+
OpenAPIMethod = Literal["get", "put", "post", "delete", "options", "head", "patch", "trace"]
37+
__ModelKey = tuple[str, OpenAPIMethod, _T, _U]
38+
_ModelKey = (
39+
__ModelKey[Literal["requestBody"], None]
40+
| __ModelKey[Literal["parameter"], tuple[str, bool]]
41+
| __ModelKey[Literal["response"], int]
42+
)
43+
44+
class _APIHandler(Protocol, Generic[_Resp]):
45+
def __call__(self, request: web.Request, *, query: Any) -> Awaitable[_Resp]:
46+
...
47+
48+
2749
APIHandler = (
2850
Callable[[web.Request], Awaitable[_Resp]]
2951
| Callable[[web.Request, Any], Awaitable[_Resp]]
52+
| _APIHandler[_Resp]
3053
)
31-
OpenAPIMethod = Literal["get", "put", "post", "delete", "options", "head", "patch", "trace"]
54+
3255

3356
class Contact(TypedDict, total=False):
3457
name: str
@@ -57,7 +80,9 @@ class Info(TypedDict, total=False):
5780
class _EndpointData(TypedDict, total=False):
5881
body: TypeAdapter[object]
5982
desc: str
60-
resps: dict[int, TypeAdapter[Any]]
83+
query: TypeAdapter[dict[str, object]]
84+
query_raw: dict[str, object]
85+
resps: dict[int, TypeAdapter[object]]
6186
summary: str
6287
tags: list[str]
6388

@@ -72,6 +97,16 @@ class _Components(TypedDict, total=False):
7297
class _MediaTypeObject(TypedDict, total=False):
7398
schema: object
7499

100+
# in is a reserved keyword.
101+
_ParameterObject = TypedDict("_ParameterObject", {
102+
"deprecated": bool,
103+
"description": str,
104+
"name": Required[str],
105+
"in": Required[Literal["query", "header", "path", "cookie"]],
106+
"required": bool,
107+
"schema": object
108+
}, total=False)
109+
75110
class _RequestBodyObject(TypedDict, total=False):
76111
content: Required[dict[str, _MediaTypeObject]]
77112

@@ -82,6 +117,7 @@ class _ResponseObject(TypedDict, total=False):
82117
class _OperationObject(TypedDict, total=False):
83118
description: str
84119
operationId: str
120+
parameters: list[_ParameterObject]
85121
requestBody: _RequestBodyObject
86122
responses: dict[str, _ResponseObject]
87123
summary: str
@@ -125,18 +161,43 @@ class _OpenApi(TypedDict, total=False):
125161
</html>"""
126162
SWAGGER_PATH = Path(__file__).parent / "swagger-ui"
127163

164+
_Wrapper = Callable[[APIHandler[_Resp], web.Request], Awaitable[_Resp]]
165+
128166
def is_openapi_method(method: str) -> TypeGuard[OpenAPIMethod]:
129167
return method in OPENAPI_METHODS
130168

131-
def create_view_wrapper(handler: Callable[[_View, _T], Awaitable[_Resp]], ta: TypeAdapter[_T]) -> Callable[[_View], Awaitable[_Resp]]:
132-
@functools.wraps(handler)
133-
async def wrapper(self: _View) -> _Resp: # type: ignore[misc]
134-
try:
135-
request_body = ta.validate_python(await self.request.read())
136-
except ValidationError as e:
137-
raise web.HTTPBadRequest(text=e.json(), content_type="application/json")
138-
return await handler(self, request_body)
139-
return wrapper
169+
def make_wrapper(ep_data: _EndpointData, wrapped: APIHandler[_Resp], handler: Callable[Concatenate[_Wrapper[_Resp], APIHandler[_Resp], _P], Awaitable[_Resp]]) -> Callable[_P, Awaitable[_Resp]] | None:
170+
# Only these keys need a wrapper created.
171+
if not {"body", "query_raw"} & ep_data.keys():
172+
return None
173+
174+
async def _wrapper(handler: APIHandler[_Resp], request: web.Request) -> _Resp:
175+
inner_handler: Callable[..., Awaitable[_Resp]] = handler
176+
177+
if body_ta := ep_data.get("body"):
178+
try:
179+
request_body = body_ta.validate_python(await request.read())
180+
except ValidationError as e:
181+
raise web.HTTPBadRequest(text=e.json(), content_type="application/json")
182+
inner_handler = partial(inner_handler, request_body)
183+
184+
if query_ta := ep_data.get("query"):
185+
try:
186+
query = query_ta.validate_python(request.query)
187+
except ValidationError as e:
188+
raise web.HTTPBadRequest(text=e.json(), content_type="application/json")
189+
inner_handler = partial(inner_handler, query=query)
190+
191+
return await inner_handler()
192+
193+
# To handle both web.View methods and regular handlers (with different ways to get the
194+
# request object), this outer_wrapper() is needed with a custom handler lambda.
195+
196+
@functools.wraps(wrapped)
197+
async def outer_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _Resp: # type: ignore[misc]
198+
return await handler(_wrapper, wrapped, *args, **kwargs)
199+
200+
return outer_wrapper
140201

141202
class SchemaGenerator:
142203
def __init__(self, info: Info | None = None):
@@ -173,6 +234,10 @@ def _save_handler(self, handler: APIHandler[APIResponse[object, int]], tags: lis
173234
if body.kind in {body.POSITIONAL_ONLY, body.POSITIONAL_OR_KEYWORD}:
174235
ep_data["body"] = TypeAdapter(Json[body.annotation]) # type: ignore[misc,name-defined]
175236

237+
query_param = sig.parameters.get("query")
238+
if query_param and query_param.kind is query_param.KEYWORD_ONLY:
239+
ep_data["query_raw"] = query_param.annotation
240+
176241
ep_data["resps"] = {}
177242
if get_origin(sig.return_annotation) is UnionType:
178243
resps = get_args(sig.return_annotation)
@@ -206,9 +271,9 @@ def decorator(view: type[_View]) -> type[_View]:
206271
for func, method in methods:
207272
ep_data = self._save_handler(func, tags=list(tags))
208273
self._endpoints[view]["meths"][method] = ep_data
209-
ta = ep_data.get("body")
210-
if ta:
211-
setattr(view, method, create_view_wrapper(func, ta))
274+
wrapper = make_wrapper(ep_data, func, lambda w, f, self: w(partial(f, self), self.request))
275+
if wrapper is not None:
276+
setattr(view, method, wrapper)
212277

213278
return view
214279

@@ -217,18 +282,8 @@ def decorator(view: type[_View]) -> type[_View]:
217282
def api(self, tags: Iterable[str] = ()) -> Callable[[APIHandler[_Resp]], Callable[[web.Request], Awaitable[_Resp]]]:
218283
def decorator(handler: APIHandler[_Resp]) -> Callable[[web.Request], Awaitable[_Resp]]:
219284
ep_data = self._save_handler(handler, tags=list(tags))
220-
ta = ep_data.get("body")
221-
if ta:
222-
@functools.wraps(handler)
223-
async def wrapper(request: web.Request) -> _Resp: # type: ignore[misc]
224-
nonlocal handler
225-
try:
226-
request_body = ta.validate_python(await request.read())
227-
except ValidationError as e:
228-
raise web.HTTPBadRequest(text=e.json(), content_type="application/json")
229-
handler = cast(Callable[[web.Request, Any], Awaitable[_Resp]], handler)
230-
return await handler(request, request_body)
231-
285+
wrapper = make_wrapper(ep_data, handler, lambda w, f, r: w(partial(f, r), r))
286+
if wrapper is not None:
232287
self._endpoints[wrapper] = {"meths": {None: ep_data}}
233288
return wrapper
234289

@@ -240,7 +295,7 @@ async def wrapper(request: web.Request) -> _Resp: # type: ignore[misc]
240295

241296
async def _on_startup(self, app: web.Application) -> None:
242297
#assert app.router.frozen
243-
models: list[tuple[tuple[str, OpenAPIMethod, int | Literal["requestBody"]], Literal["serialization", "validation"], TypeAdapter[object]]] = []
298+
models: list[tuple[_ModelKey, Literal["serialization", "validation"], TypeAdapter[object]]] = []
244299
paths: dict[str, _PathObject] = {}
245300
for route in app.router.routes():
246301
ep_data = self._endpoints.get(route.handler)
@@ -281,30 +336,59 @@ async def _on_startup(self, app: web.Application) -> None:
281336
path_data[method] = operation
282337

283338
body = endpoints.get("body")
284-
key: tuple[str, OpenAPIMethod, int | Literal["requestBody"]]
339+
key: _ModelKey
285340
if body:
286-
key = (path, method, "requestBody")
341+
key = (path, method, "requestBody", None)
287342
models.append((key, "validation", body))
343+
if query := endpoints.get("query_raw"):
344+
# We need separate schemas for each key of the TypedDict.
345+
td = {}
346+
for param_name, param_type in get_type_hints(query).items():
347+
required = param_name in query.__required_keys__ # type: ignore[attr-defined]
348+
key = (path, method, "parameter", (param_name, required))
349+
350+
extracted_type = param_type
351+
while get_origin(extracted_type) in {Annotated, Literal, Required, NotRequired}:
352+
extracted_type = get_args(param_type)[0]
353+
try:
354+
is_str = issubclass(extracted_type, str)
355+
except TypeError:
356+
is_str = isinstance(extracted_type, str) # Literal
357+
358+
# We also need to convert values to Json for runtime checking.
359+
ann_type = param_type if is_str else Json[param_type] # type: ignore[misc,valid-type]
360+
models.append((key, "validation", TypeAdapter(ann_type)))
361+
td[param_name] = Required[ann_type] if required else NotRequired[ann_type]
362+
endpoints["query"] = TypeAdapter(TypedDict(query.__name__, td)) # type: ignore[attr-defined,operator]
288363
for code, model in endpoints["resps"].items():
289-
key = (path, method, code)
364+
key = (path, method, "response", code)
290365
models.append((key, "serialization", model))
291366

292367
elems, defs = TypeAdapter.json_schemas(models, ref_template="#/components/schemas/{model}")
293368
if defs:
294369
self._openapi["components"] = {"schemas": defs["$defs"]}
295370

296371
# TODO: default response
297-
for ((path, method, code_or_key), mode), schema in elems.items():
298-
if code_or_key == "requestBody":
372+
key_type: str
373+
for (key, mode), schema in elems.items():
374+
if key[2] == "requestBody":
375+
path, method, key_type, _ = key
376+
assert mode == "validation"
377+
paths[path][method]["requestBody"] = {"content": {"application/json": {"schema": schema}}}
378+
elif key[2] == "parameter":
379+
path, method, key_type, (param_name, required) = key
299380
assert mode == "validation"
300-
paths[path][method][code_or_key] = {"content": {"application/json": {"schema": schema}}}
381+
parameter: _ParameterObject = {
382+
"name": param_name, "in": "query", "required": required, "schema": schema}
383+
paths[path][method].setdefault("parameters", []).append(parameter)
301384
else:
302-
assert isinstance(code_or_key, int)
385+
path, method, key_type, code = key
386+
assert key_type == "response"
303387
assert mode == "serialization"
304388
responses = paths[path][method].setdefault("responses", {})
305389
content: dict[str, _MediaTypeObject] = {"application/json": {"schema": schema}}
306-
reason = HTTPStatus(code_or_key).phrase
307-
responses[str(code_or_key)] = {"description": reason, "content": content}
390+
reason = HTTPStatus(code).phrase
391+
responses[str(code)] = {"description": reason, "content": content}
308392
if paths:
309393
self._openapi["paths"] = paths
310394

docs/api.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ SchemaGenerator
4343
in the schema. When the handler is executed, the request body will be read and
4444
validated against that type.
4545

46+
The handler function can define a `query` keyword-only parameter whose type
47+
annotation must be a form of :class:`typing.TypedDict`. When the handler is
48+
executed, the query parameters will be validated against that type.
49+
4650
:param tags: Sequence of strings used to specify tags to group endpoints.
4751

4852
.. method:: api_view(tags=())

docs/index.rst

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,37 @@ Validation of the request body is achieved by adding a positional parameter:
6262

6363
.. code-block:: python
6464
65-
async def handler(request: web.Request, body: dict[int, str]) -> APIResponse[int, Literal[200]]:
65+
async def handler(request: web.Request, body: dict[int, str]) -> APIResponse[int]:
6666
# body has been validated, so we can be sure the keys are int now.
6767
return APIResponse(sum(body.keys()))
6868
6969
This will include the information in the schema's requestBody, plus it will validate
7070
the input from the user. If validation fails it will return a 400 response with
7171
information about what was incorrect.
7272

73+
Keyword-only parameters can be defined for ``query`` arguments:
74+
75+
.. code-block:: python
76+
77+
class QueryArgs(TypedDict, total=False):
78+
sort: Literal["asc", "desc"]
79+
80+
async def handler(request: web.Request, *, query: QueryArgs) -> APIResponse[int]:
81+
return sorted(results, reverse=query.get("sort", "asc") == "desc")
82+
83+
Pydantic options
84+
----------------
85+
86+
You can add custom Pydantic options using :class:`typing.Annotated`:
87+
88+
.. code-block:: python
89+
90+
class QueryArgs(TypedDict):
91+
sort: Annotated[Literal["asc", "desc"], pydantic.Field(default="asc")]
92+
93+
async def handler(request: web.Request, *, query: QueryArgs) -> APIResponse[int]:
94+
return sorted(results, reverse=query["sort"] == "desc")
95+
7396
Customising schema generation
7497
-----------------------------
7598

example.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class NewPoll(TypedDict):
3737

3838

3939
SCHEMA = SchemaGenerator()
40+
NotFound = APIResponse[None, Literal[404]]
4041

4142
POLLS = {1: POLL1}
4243
CHOICES = {1: list(CHOICES1)}
@@ -52,7 +53,7 @@ async def list_polls(request: web.Request) -> APIResponse[tuple[Poll, ...], Lite
5253

5354

5455
@SCHEMA.api()
55-
async def add_choice(request: web.Request, message: str) -> APIResponse[int, Literal[201]] | APIResponse[None, Literal[404]]:
56+
async def add_choice(request: web.Request, message: str) -> APIResponse[int, Literal[201]] | NotFound:
5657
"""Edit a choice.
5758
5859
Return the ID of the new choice.
@@ -65,16 +66,28 @@ async def add_choice(request: web.Request, message: str) -> APIResponse[int, Lit
6566
return APIResponse[None, Literal[404]](None, status=404)
6667

6768

69+
class GetQuery(TypedDict):
70+
"""Define our query arguments for the get endpoint."""
71+
results: Annotated[bool, Field(default=True)]
72+
73+
74+
class GetPollResult(Poll, total=False):
75+
results: list[Choice]
76+
77+
6878
@SCHEMA.api_view()
6979
class PollView(web.View):
7080
"""Endpoints for individual polls."""
7181

72-
async def get(self) -> APIResponse[Poll, Literal[200]] | APIResponse[None, Literal[404]]:
82+
async def get(self, *, query: GetQuery) -> APIResponse[GetPollResult, Literal[200]] | NotFound:
7383
"""Fetch a poll by ID."""
7484
poll_id = int(self.request.match_info["id"])
7585
poll = POLLS.get(poll_id)
7686
if poll:
77-
return APIResponse(poll)
87+
poll_result: GetPollResult = poll.copy() # type: ignore[assignment]
88+
if query["results"]:
89+
poll_result["results"] = CHOICES[poll_id]
90+
return APIResponse(poll_result)
7891
return APIResponse[None, Literal[404]](None, status=404)
7992

8093
async def put(self, body: NewPoll) -> APIResponse[int]:

0 commit comments

Comments
 (0)