Skip to content

Commit cbd3d06

Browse files
committed
Reworked resolve_endpoint_params, added tests
1 parent f08cf64 commit cbd3d06

File tree

19 files changed

+356
-169
lines changed

19 files changed

+356
-169
lines changed

benchmarks/tornado/without_fastopenapi/run.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,10 @@ def write_error(self, status_code, **kwargs):
2525

2626
class RecordsHandler(BaseRecordHandler):
2727
def get(self):
28-
"""Получить все записи (эквивалент @app.get("/records"))"""
2928
records = store.get_all()
3029
self.write(json.dumps([record.model_dump() for record in records]))
3130

3231
def post(self):
33-
"""Создать новую запись (эквивалент @app.post("/records"))"""
3432
try:
3533
data = json.loads(self.request.body)
3634
record_data = RecordCreate(**data)
@@ -44,7 +42,6 @@ def post(self):
4442

4543
class RecordHandler(BaseRecordHandler):
4644
def get(self, record_id):
47-
"""Получить конкретную запись (эквивалент @app.get("/records/<record_id>"))"""
4845
record = store.get_by_id(record_id)
4946
if record:
5047
self.write(json.dumps(record.model_dump()))
@@ -53,7 +50,6 @@ def get(self, record_id):
5350
self.write(json.dumps({"error": "Record not found"}))
5451

5552
def put(self, record_id):
56-
"""Заменить запись (эквивалент @app.put("/records/<record_id>"))"""
5753
try:
5854
if not store.get_by_id(record_id):
5955
self.set_status(404)
@@ -71,7 +67,6 @@ def put(self, record_id):
7167
self.write(json.dumps({"error": str(e)}))
7268

7369
def patch(self, record_id):
74-
"""Обновить запись частично (эквивалент @app.patch("/records/<record_id>"))"""
7570
try:
7671
data = json.loads(self.request.body)
7772
record_data = RecordUpdate(**data)
@@ -86,7 +81,6 @@ def patch(self, record_id):
8681
self.write(json.dumps({"error": str(e)}))
8782

8883
def delete(self, record_id):
89-
"""Удалить запись (эквивалент @app.delete("/records/<record_id>"))"""
9084
if store.delete(record_id):
9185
self.set_status(204)
9286
self.finish()
@@ -98,9 +92,8 @@ def delete(self, record_id):
9892
def make_app():
9993
return tornado.web.Application(
10094
[
101-
(r"/records", RecordsHandler), # Обрабатывает GET и POST для /records
95+
(r"/records", RecordsHandler),
10296
(r"/records/([^/]+)", RecordHandler),
103-
# Обрабатывает GET, PUT, PATCH, DELETE для /records/<id>
10497
]
10598
)
10699

fastopenapi/base_router.py

Lines changed: 119 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
from typing import Any, ClassVar
3232

3333
from pydantic import BaseModel
34+
from pydantic import ValidationError as PydanticValidationError
35+
from pydantic import create_model
3436

3537
from fastopenapi.error_handler import (
3638
BadRequestError,
@@ -77,13 +79,14 @@ class BaseRouter:
7779

7880
# Class-level cache for model schemas to avoid redundant processing
7981
_model_schema_cache: ClassVar[dict[str, dict]] = {}
82+
_param_model_cache: ClassVar[dict[frozenset, type[BaseModel]]] = {}
8083

8184
def __init__(
8285
self,
8386
app: Any = None,
84-
docs_url: str = "/docs",
85-
redoc_url: str = "/redoc",
86-
openapi_url: str = "/openapi.json",
87+
docs_url: str | None = "/docs",
88+
redoc_url: str | None = "/redoc",
89+
openapi_url: str | None = "/openapi.json",
8790
openapi_version: str = "3.0.0",
8891
title: str = "My App",
8992
version: str = "0.1.0",
@@ -166,12 +169,7 @@ def generate_openapi(self) -> dict:
166169
"description": self.description,
167170
}
168171

169-
schema = {
170-
"openapi": self.openapi_version,
171-
"info": info,
172-
"paths": {},
173-
"components": {"schemas": {}},
174-
}
172+
paths = {}
175173
definitions = {}
176174

177175
# Add standard error responses to components schema
@@ -183,8 +181,19 @@ def generate_openapi(self) -> dict:
183181
operation = self._build_operation(
184182
endpoint, definitions, openapi_path, method
185183
)
186-
schema["paths"].setdefault(openapi_path, {})[method.lower()] = operation
187-
schema["components"]["schemas"].update(definitions)
184+
185+
if openapi_path not in paths:
186+
paths[openapi_path] = {}
187+
188+
paths[openapi_path][method.lower()] = operation
189+
190+
schema = {
191+
"openapi": self.openapi_version,
192+
"info": info,
193+
"paths": paths,
194+
"components": {"schemas": definitions},
195+
}
196+
188197
return schema
189198

190199
def _generate_error_schema(self) -> dict[str, Any]:
@@ -209,7 +218,7 @@ def _generate_error_schema(self) -> dict[str, Any]:
209218
}
210219

211220
def _build_operation(
212-
self, endpoint, definitions: dict, route_path: str, http_method: str
221+
self, endpoint: Callable, definitions: dict, route_path: str, http_method: str
213222
) -> dict:
214223
parameters, request_body = self._build_parameters_and_body(
215224
endpoint, definitions, route_path, http_method
@@ -501,73 +510,121 @@ def _resolve_pydantic_model(model_class, params, param_name):
501510
f"Validation error for parameter '{param_name}'", str(e)
502511
)
503512

504-
@staticmethod
505-
def _resolve_list_param(param_name, value, annotation):
506-
"""Resolving a list-type parameter"""
507-
args = typing.get_args(annotation)
508-
try:
509-
if args:
510-
return [args[0](value)]
511-
else:
512-
return [value]
513-
except Exception as e:
514-
type_name = args[0].__name__ if args else "value"
515-
raise BadRequestError(
516-
f"Error parsing parameter '{param_name}' as list item. "
517-
f"Must be a valid {type_name}",
518-
str(e),
519-
)
520-
521-
@staticmethod
522-
def _resolve_scalar_param(param_name, value, annotation):
523-
"""Resolving a scalar parameter"""
524-
try:
525-
return annotation(value)
526-
except Exception as e:
527-
type_name = getattr(annotation, "__name__", str(annotation))
528-
raise BadRequestError(
529-
f"Error parsing parameter '{param_name}'. "
530-
f"Must be a valid {type_name}",
531-
str(e),
532-
)
533-
534-
@staticmethod
513+
@classmethod
535514
def resolve_endpoint_params(
536-
endpoint: Callable, all_params: dict, body: dict
537-
) -> dict:
538-
"""Main method for resolving endpoint parameters"""
515+
cls, endpoint: Callable, all_params: dict[str, Any], body: dict[str, Any]
516+
) -> dict[str, Any]:
517+
"""Resolves endpoint parameters using Pydantic validation with caching"""
539518
sig = inspect.signature(endpoint)
540519
kwargs = {}
520+
model_fields = {}
521+
model_values = {}
522+
param_types = cls._extract_param_types(sig)
541523

524+
# Process each parameter from the endpoint signature
542525
for name, param in sig.parameters.items():
543526
annotation = param.annotation
544-
is_required = param.default is inspect.Parameter.empty
545527

546-
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
547-
kwargs[name] = BaseRouter._resolve_pydantic_model(
548-
annotation, body if body else all_params, name
528+
# Handle Pydantic model parameters
529+
if cls._is_pydantic_model(annotation):
530+
kwargs[name] = cls._process_pydantic_param(
531+
name, annotation, body if body else all_params
549532
)
550533
continue
551534

535+
# Handle missing parameters
552536
if name not in all_params:
553-
if is_required:
554-
raise BadRequestError(f"Missing required parameter: '{name}'")
555-
kwargs[name] = param.default
537+
kwargs[name] = cls._handle_missing_param(name, param)
556538
continue
557539

558-
origin = typing.get_origin(annotation)
540+
# Collect fields for dynamic model validation
541+
model_fields[name] = (
542+
annotation,
543+
param.default if param.default is not inspect.Parameter.empty else ...,
544+
)
545+
model_values[name] = all_params[name]
559546

560-
if origin is list and not isinstance(all_params[name], list):
561-
kwargs[name] = BaseRouter._resolve_list_param(
562-
name, all_params[name], annotation
563-
)
564-
else:
565-
kwargs[name] = BaseRouter._resolve_scalar_param(
566-
name, all_params[name], annotation
567-
)
547+
# Validate collected parameters using dynamic model
548+
if model_fields:
549+
validated_params = cls._validate_with_dynamic_model(
550+
endpoint, model_fields, model_values, param_types
551+
)
552+
kwargs.update(validated_params)
568553

569554
return kwargs
570555

556+
@classmethod
557+
def _extract_param_types(cls, sig: inspect.Signature) -> dict[str, Any]:
558+
"""Extract parameter types from signature"""
559+
return {name: param.annotation for name, param in sig.parameters.items()}
560+
561+
@classmethod
562+
def _process_pydantic_param(
563+
cls, name: str, model_class: type[BaseModel], params: dict[str, Any]
564+
) -> BaseModel:
565+
"""Process a parameter that's a Pydantic model"""
566+
try:
567+
return cls._resolve_pydantic_model(model_class, params, name)
568+
except Exception as e:
569+
raise ValidationError(f"Validation error for parameter '{name}'", str(e))
570+
571+
@staticmethod
572+
def _handle_missing_param(name: str, param: inspect.Parameter) -> Any:
573+
"""Handle parameters not provided in the request"""
574+
if param.default is inspect.Parameter.empty:
575+
raise BadRequestError(f"Missing required parameter: '{name}'")
576+
return param.default
577+
578+
@classmethod
579+
def _validate_with_dynamic_model(
580+
cls,
581+
endpoint: Callable,
582+
model_fields: dict,
583+
model_values: dict,
584+
param_types: dict[str, Any],
585+
) -> dict[str, Any] | None:
586+
"""Validate parameters using a dynamically created Pydantic model"""
587+
# Create cache key for the dynamic model
588+
cache_key = frozenset(
589+
(endpoint.__module__, endpoint.__name__, name, str(ann))
590+
for name, (ann, _) in model_fields.items()
591+
)
592+
593+
# Get or create the model class
594+
if cache_key not in cls._param_model_cache:
595+
cls._param_model_cache[cache_key] = create_model(
596+
"ParamsModel", **model_fields
597+
)
598+
599+
try:
600+
# Validate parameters against the model
601+
validated = cls._param_model_cache[cache_key](**model_values)
602+
return validated.model_dump()
603+
except PydanticValidationError as e:
604+
raise cls._handle_validation_error(e, param_types)
605+
606+
@staticmethod
607+
def _handle_validation_error(
608+
error: PydanticValidationError, param_types: dict[str, Any]
609+
) -> BadRequestError:
610+
"""Handle validation errors with detailed messages"""
611+
exc = BadRequestError("Parameter validation failed", str(error))
612+
errors = error.errors()
613+
if errors:
614+
error_info = errors[0]
615+
loc = error_info.get("loc", [])
616+
if loc and len(loc) > 0:
617+
param_name = str(loc[0])
618+
if param_name in param_types:
619+
type_name = getattr(param_types[param_name], "__name__", "value")
620+
exc = BadRequestError(
621+
f"Error parsing parameter '{param_name}'. "
622+
f"Must be a valid {type_name}",
623+
str(error_info.get("msg", "")),
624+
)
625+
626+
return exc
627+
571628
@property
572629
def openapi(self) -> dict:
573630
if self._openapi_schema is None:

fastopenapi/routers/falcon.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,14 +115,6 @@ async def _read_body(self, req):
115115
pass
116116
return {}
117117

118-
# def _handle_request_error(self, resp, error_message: str):
119-
# resp.status = falcon.HTTP_422
120-
# resp.media = {"detail": error_message}
121-
#
122-
# def _handle_response_error(self, resp, error_message: str):
123-
# resp.status = falcon.HTTP_500
124-
# resp.media = {"detail": error_message}
125-
126118
def _register_docs_endpoints(self):
127119
outer = self
128120

fastopenapi/routers/sanic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def add_route(self, path: str, method: str, endpoint: Callable):
1818
async def view_func(request, **path_params):
1919
query_params = {}
2020
for k, v in request.args.items():
21-
values = request.args.getall(k)
21+
values = request.args.getlist(k)
2222
query_params[k] = values[0] if len(values) == 1 else values
2323
json_body = request.json or {}
2424
all_params = {**query_params, **path_params}

tests/aiohttp/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def app(items_db): # noqa: C901
4545
version="0.1.0",
4646
)
4747

48+
@router.get("/list-test")
49+
def list_endpoint(param1: str, param2: list[str] = None):
50+
"""Test endpoint that returns the parameters it receives"""
51+
return {"received_param1": param1, "received_param2": param2}
52+
4853
@router.get("/items", response_model=list[ItemResponse], tags=["items"])
4954
async def get_items():
5055
return [Item(**item) for item in items_db]

tests/aiohttp/test_aiohttp_integration.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,22 @@ async def test_redoc_ui_endpoint(self, client):
138138
text = await resp.text()
139139
assert "text/html" in resp.headers["Content-Type"]
140140
assert "redoc" in text
141+
142+
@pytest.mark.asyncio
143+
async def test_query_parameters_handling(self, client):
144+
"""Test handling of query parameters"""
145+
# Test with a single value parameter
146+
response = await client.get("/list-test?param1=single_value")
147+
assert response.status == 200
148+
data = await response.json()
149+
assert data["received_param1"] == "single_value"
150+
151+
# Test with a parameter that has multiple values
152+
response = await client.get(
153+
"/list-test?param1=first_value&param2=value1&param2=value2"
154+
)
155+
assert response.status == 200
156+
data = await response.json()
157+
assert data["received_param1"] == "first_value"
158+
assert isinstance(data["received_param2"], list)
159+
assert data["received_param2"] == ["value1", "value2"]

0 commit comments

Comments
 (0)