22import inspect
33import sys
44from collections .abc import Awaitable , Callable , Mapping
5+ from functools import partial
56from http import HTTPStatus
67from pathlib import Path
78from types import UnionType
8- from typing import Any , Iterable , Literal , TypedDict , TypeGuard , TypeVar , cast , get_args , get_origin
9+ from typing import (Any , Annotated , Concatenate , Generic , Iterable , Literal , ParamSpec ,
10+ Protocol , TypeGuard , TypeVar , cast , get_args , get_origin , get_type_hints )
911
1012from aiohttp import web
1113from aiohttp .hdrs import METH_ALL
1416
1517from aiohttp_apischema .response import APIResponse
1618
19+ if sys .version_info >= (3 , 12 ):
20+ from typing import TypedDict
21+ else :
22+ from typing_extensions import TypedDict
23+
1724if sys .version_info >= (3 , 11 ):
18- from typing import Required
25+ from typing import NotRequired , Required
1926else :
20- from typing_extensions import Required
27+ from typing_extensions import NotRequired , Required
2128
2229OPENAPI_METHODS = frozenset ({"get" , "put" , "post" , "delete" , "options" , "head" , "patch" , "trace" })
2330
2431_T = TypeVar ("_T" )
25- _Resp = TypeVar ("_Resp" , bound = APIResponse [Any , Any ])
32+ _U = TypeVar ("_U" )
33+ _P = ParamSpec ("_P" )
34+ _Resp = TypeVar ("_Resp" , bound = APIResponse [Any , Any ], covariant = True )
2635_View = TypeVar ("_View" , bound = web .View )
36+ OpenAPIMethod = Literal ["get" , "put" , "post" , "delete" , "options" , "head" , "patch" , "trace" ]
37+ __ModelKey = tuple [str , OpenAPIMethod , _T , _U ]
38+ _ModelKey = (
39+ __ModelKey [Literal ["requestBody" ], None ]
40+ | __ModelKey [Literal ["parameter" ], tuple [str , bool ]]
41+ | __ModelKey [Literal ["response" ], int ]
42+ )
43+
44+ class _APIHandler (Protocol , Generic [_Resp ]):
45+ def __call__ (self , request : web .Request , * , query : Any ) -> Awaitable [_Resp ]:
46+ ...
47+
48+
2749APIHandler = (
2850 Callable [[web .Request ], Awaitable [_Resp ]]
2951 | Callable [[web .Request , Any ], Awaitable [_Resp ]]
52+ | _APIHandler [_Resp ]
3053)
31- OpenAPIMethod = Literal [ "get" , "put" , "post" , "delete" , "options" , "head" , "patch" , "trace" ]
54+
3255
3356class Contact (TypedDict , total = False ):
3457 name : str
@@ -57,7 +80,9 @@ class Info(TypedDict, total=False):
5780class _EndpointData (TypedDict , total = False ):
5881 body : TypeAdapter [object ]
5982 desc : str
60- resps : dict [int , TypeAdapter [Any ]]
83+ query : TypeAdapter [dict [str , object ]]
84+ query_raw : dict [str , object ]
85+ resps : dict [int , TypeAdapter [object ]]
6186 summary : str
6287 tags : list [str ]
6388
@@ -72,6 +97,16 @@ class _Components(TypedDict, total=False):
7297class _MediaTypeObject (TypedDict , total = False ):
7398 schema : object
7499
100+ # in is a reserved keyword.
101+ _ParameterObject = TypedDict ("_ParameterObject" , {
102+ "deprecated" : bool ,
103+ "description" : str ,
104+ "name" : Required [str ],
105+ "in" : Required [Literal ["query" , "header" , "path" , "cookie" ]],
106+ "required" : bool ,
107+ "schema" : object
108+ }, total = False )
109+
75110class _RequestBodyObject (TypedDict , total = False ):
76111 content : Required [dict [str , _MediaTypeObject ]]
77112
@@ -82,6 +117,7 @@ class _ResponseObject(TypedDict, total=False):
82117class _OperationObject (TypedDict , total = False ):
83118 description : str
84119 operationId : str
120+ parameters : list [_ParameterObject ]
85121 requestBody : _RequestBodyObject
86122 responses : dict [str , _ResponseObject ]
87123 summary : str
@@ -125,18 +161,43 @@ class _OpenApi(TypedDict, total=False):
125161</html>"""
126162SWAGGER_PATH = Path (__file__ ).parent / "swagger-ui"
127163
164+ _Wrapper = Callable [[APIHandler [_Resp ], web .Request ], Awaitable [_Resp ]]
165+
128166def is_openapi_method (method : str ) -> TypeGuard [OpenAPIMethod ]:
129167 return method in OPENAPI_METHODS
130168
131- def create_view_wrapper (handler : Callable [[_View , _T ], Awaitable [_Resp ]], ta : TypeAdapter [_T ]) -> Callable [[_View ], Awaitable [_Resp ]]:
132- @functools .wraps (handler )
133- async def wrapper (self : _View ) -> _Resp : # type: ignore[misc]
134- try :
135- request_body = ta .validate_python (await self .request .read ())
136- except ValidationError as e :
137- raise web .HTTPBadRequest (text = e .json (), content_type = "application/json" )
138- return await handler (self , request_body )
139- return wrapper
169+ def make_wrapper (ep_data : _EndpointData , wrapped : APIHandler [_Resp ], handler : Callable [Concatenate [_Wrapper [_Resp ], APIHandler [_Resp ], _P ], Awaitable [_Resp ]]) -> Callable [_P , Awaitable [_Resp ]] | None :
170+ # Only these keys need a wrapper created.
171+ if not {"body" , "query_raw" } & ep_data .keys ():
172+ return None
173+
174+ async def _wrapper (handler : APIHandler [_Resp ], request : web .Request ) -> _Resp :
175+ inner_handler : Callable [..., Awaitable [_Resp ]] = handler
176+
177+ if body_ta := ep_data .get ("body" ):
178+ try :
179+ request_body = body_ta .validate_python (await request .read ())
180+ except ValidationError as e :
181+ raise web .HTTPBadRequest (text = e .json (), content_type = "application/json" )
182+ inner_handler = partial (inner_handler , request_body )
183+
184+ if query_ta := ep_data .get ("query" ):
185+ try :
186+ query = query_ta .validate_python (request .query )
187+ except ValidationError as e :
188+ raise web .HTTPBadRequest (text = e .json (), content_type = "application/json" )
189+ inner_handler = partial (inner_handler , query = query )
190+
191+ return await inner_handler ()
192+
193+ # To handle both web.View methods and regular handlers (with different ways to get the
194+ # request object), this outer_wrapper() is needed with a custom handler lambda.
195+
196+ @functools .wraps (wrapped )
197+ async def outer_wrapper (* args : _P .args , ** kwargs : _P .kwargs ) -> _Resp : # type: ignore[misc]
198+ return await handler (_wrapper , wrapped , * args , ** kwargs )
199+
200+ return outer_wrapper
140201
141202class SchemaGenerator :
142203 def __init__ (self , info : Info | None = None ):
@@ -173,6 +234,10 @@ def _save_handler(self, handler: APIHandler[APIResponse[object, int]], tags: lis
173234 if body .kind in {body .POSITIONAL_ONLY , body .POSITIONAL_OR_KEYWORD }:
174235 ep_data ["body" ] = TypeAdapter (Json [body .annotation ]) # type: ignore[misc,name-defined]
175236
237+ query_param = sig .parameters .get ("query" )
238+ if query_param and query_param .kind is query_param .KEYWORD_ONLY :
239+ ep_data ["query_raw" ] = query_param .annotation
240+
176241 ep_data ["resps" ] = {}
177242 if get_origin (sig .return_annotation ) is UnionType :
178243 resps = get_args (sig .return_annotation )
@@ -206,9 +271,9 @@ def decorator(view: type[_View]) -> type[_View]:
206271 for func , method in methods :
207272 ep_data = self ._save_handler (func , tags = list (tags ))
208273 self ._endpoints [view ]["meths" ][method ] = ep_data
209- ta = ep_data . get ( "body" )
210- if ta :
211- setattr (view , method , create_view_wrapper ( func , ta ) )
274+ wrapper = make_wrapper ( ep_data , func , lambda w , f , self : w ( partial ( f , self ), self . request ) )
275+ if wrapper is not None :
276+ setattr (view , method , wrapper )
212277
213278 return view
214279
@@ -217,18 +282,8 @@ def decorator(view: type[_View]) -> type[_View]:
217282 def api (self , tags : Iterable [str ] = ()) -> Callable [[APIHandler [_Resp ]], Callable [[web .Request ], Awaitable [_Resp ]]]:
218283 def decorator (handler : APIHandler [_Resp ]) -> Callable [[web .Request ], Awaitable [_Resp ]]:
219284 ep_data = self ._save_handler (handler , tags = list (tags ))
220- ta = ep_data .get ("body" )
221- if ta :
222- @functools .wraps (handler )
223- async def wrapper (request : web .Request ) -> _Resp : # type: ignore[misc]
224- nonlocal handler
225- try :
226- request_body = ta .validate_python (await request .read ())
227- except ValidationError as e :
228- raise web .HTTPBadRequest (text = e .json (), content_type = "application/json" )
229- handler = cast (Callable [[web .Request , Any ], Awaitable [_Resp ]], handler )
230- return await handler (request , request_body )
231-
285+ wrapper = make_wrapper (ep_data , handler , lambda w , f , r : w (partial (f , r ), r ))
286+ if wrapper is not None :
232287 self ._endpoints [wrapper ] = {"meths" : {None : ep_data }}
233288 return wrapper
234289
@@ -240,7 +295,7 @@ async def wrapper(request: web.Request) -> _Resp: # type: ignore[misc]
240295
241296 async def _on_startup (self , app : web .Application ) -> None :
242297 #assert app.router.frozen
243- models : list [tuple [tuple [ str , OpenAPIMethod , int | Literal [ "requestBody" ]] , Literal ["serialization" , "validation" ], TypeAdapter [object ]]] = []
298+ models : list [tuple [_ModelKey , Literal ["serialization" , "validation" ], TypeAdapter [object ]]] = []
244299 paths : dict [str , _PathObject ] = {}
245300 for route in app .router .routes ():
246301 ep_data = self ._endpoints .get (route .handler )
@@ -281,30 +336,59 @@ async def _on_startup(self, app: web.Application) -> None:
281336 path_data [method ] = operation
282337
283338 body = endpoints .get ("body" )
284- key : tuple [ str , OpenAPIMethod , int | Literal [ "requestBody" ]]
339+ key : _ModelKey
285340 if body :
286- key = (path , method , "requestBody" )
341+ key = (path , method , "requestBody" , None )
287342 models .append ((key , "validation" , body ))
343+ if query := endpoints .get ("query_raw" ):
344+ # We need separate schemas for each key of the TypedDict.
345+ td = {}
346+ for param_name , param_type in get_type_hints (query ).items ():
347+ required = param_name in query .__required_keys__ # type: ignore[attr-defined]
348+ key = (path , method , "parameter" , (param_name , required ))
349+
350+ extracted_type = param_type
351+ while get_origin (extracted_type ) in {Annotated , Literal , Required , NotRequired }:
352+ extracted_type = get_args (param_type )[0 ]
353+ try :
354+ is_str = issubclass (extracted_type , str )
355+ except TypeError :
356+ is_str = isinstance (extracted_type , str ) # Literal
357+
358+ # We also need to convert values to Json for runtime checking.
359+ ann_type = param_type if is_str else Json [param_type ] # type: ignore[misc,valid-type]
360+ models .append ((key , "validation" , TypeAdapter (ann_type )))
361+ td [param_name ] = Required [ann_type ] if required else NotRequired [ann_type ]
362+ endpoints ["query" ] = TypeAdapter (TypedDict (query .__name__ , td )) # type: ignore[attr-defined,operator]
288363 for code , model in endpoints ["resps" ].items ():
289- key = (path , method , code )
364+ key = (path , method , "response" , code )
290365 models .append ((key , "serialization" , model ))
291366
292367 elems , defs = TypeAdapter .json_schemas (models , ref_template = "#/components/schemas/{model}" )
293368 if defs :
294369 self ._openapi ["components" ] = {"schemas" : defs ["$defs" ]}
295370
296371 # TODO: default response
297- for ((path , method , code_or_key ), mode ), schema in elems .items ():
298- if code_or_key == "requestBody" :
372+ key_type : str
373+ for (key , mode ), schema in elems .items ():
374+ if key [2 ] == "requestBody" :
375+ path , method , key_type , _ = key
376+ assert mode == "validation"
377+ paths [path ][method ]["requestBody" ] = {"content" : {"application/json" : {"schema" : schema }}}
378+ elif key [2 ] == "parameter" :
379+ path , method , key_type , (param_name , required ) = key
299380 assert mode == "validation"
300- paths [path ][method ][code_or_key ] = {"content" : {"application/json" : {"schema" : schema }}}
381+ parameter : _ParameterObject = {
382+ "name" : param_name , "in" : "query" , "required" : required , "schema" : schema }
383+ paths [path ][method ].setdefault ("parameters" , []).append (parameter )
301384 else :
302- assert isinstance (code_or_key , int )
385+ path , method , key_type , code = key
386+ assert key_type == "response"
303387 assert mode == "serialization"
304388 responses = paths [path ][method ].setdefault ("responses" , {})
305389 content : dict [str , _MediaTypeObject ] = {"application/json" : {"schema" : schema }}
306- reason = HTTPStatus (code_or_key ).phrase
307- responses [str (code_or_key )] = {"description" : reason , "content" : content }
390+ reason = HTTPStatus (code ).phrase
391+ responses [str (code )] = {"description" : reason , "content" : content }
308392 if paths :
309393 self ._openapi ["paths" ] = paths
310394
0 commit comments