3131from typing import Any , ClassVar
3232
3333from pydantic import BaseModel
34+ from pydantic import ValidationError as PydanticValidationError
35+ from pydantic import create_model
3436
3537from fastopenapi .error_handler import (
3638 BadRequestError ,
@@ -77,13 +79,14 @@ class BaseRouter:
7779
7880 # Class-level cache for model schemas to avoid redundant processing
7981 _model_schema_cache : ClassVar [dict [str , dict ]] = {}
82+ _param_model_cache : ClassVar [dict [frozenset , type [BaseModel ]]] = {}
8083
8184 def __init__ (
8285 self ,
8386 app : Any = None ,
84- docs_url : str = "/docs" ,
85- redoc_url : str = "/redoc" ,
86- openapi_url : str = "/openapi.json" ,
87+ docs_url : str | None = "/docs" ,
88+ redoc_url : str | None = "/redoc" ,
89+ openapi_url : str | None = "/openapi.json" ,
8790 openapi_version : str = "3.0.0" ,
8891 title : str = "My App" ,
8992 version : str = "0.1.0" ,
@@ -166,12 +169,7 @@ def generate_openapi(self) -> dict:
166169 "description" : self .description ,
167170 }
168171
169- schema = {
170- "openapi" : self .openapi_version ,
171- "info" : info ,
172- "paths" : {},
173- "components" : {"schemas" : {}},
174- }
172+ paths = {}
175173 definitions = {}
176174
177175 # Add standard error responses to components schema
@@ -183,8 +181,19 @@ def generate_openapi(self) -> dict:
183181 operation = self ._build_operation (
184182 endpoint , definitions , openapi_path , method
185183 )
186- schema ["paths" ].setdefault (openapi_path , {})[method .lower ()] = operation
187- schema ["components" ]["schemas" ].update (definitions )
184+
185+ if openapi_path not in paths :
186+ paths [openapi_path ] = {}
187+
188+ paths [openapi_path ][method .lower ()] = operation
189+
190+ schema = {
191+ "openapi" : self .openapi_version ,
192+ "info" : info ,
193+ "paths" : paths ,
194+ "components" : {"schemas" : definitions },
195+ }
196+
188197 return schema
189198
190199 def _generate_error_schema (self ) -> dict [str , Any ]:
@@ -209,7 +218,7 @@ def _generate_error_schema(self) -> dict[str, Any]:
209218 }
210219
211220 def _build_operation (
212- self , endpoint , definitions : dict , route_path : str , http_method : str
221+ self , endpoint : Callable , definitions : dict , route_path : str , http_method : str
213222 ) -> dict :
214223 parameters , request_body = self ._build_parameters_and_body (
215224 endpoint , definitions , route_path , http_method
@@ -501,73 +510,121 @@ def _resolve_pydantic_model(model_class, params, param_name):
501510 f"Validation error for parameter '{ param_name } '" , str (e )
502511 )
503512
504- @staticmethod
505- def _resolve_list_param (param_name , value , annotation ):
506- """Resolving a list-type parameter"""
507- args = typing .get_args (annotation )
508- try :
509- if args :
510- return [args [0 ](value )]
511- else :
512- return [value ]
513- except Exception as e :
514- type_name = args [0 ].__name__ if args else "value"
515- raise BadRequestError (
516- f"Error parsing parameter '{ param_name } ' as list item. "
517- f"Must be a valid { type_name } " ,
518- str (e ),
519- )
520-
521- @staticmethod
522- def _resolve_scalar_param (param_name , value , annotation ):
523- """Resolving a scalar parameter"""
524- try :
525- return annotation (value )
526- except Exception as e :
527- type_name = getattr (annotation , "__name__" , str (annotation ))
528- raise BadRequestError (
529- f"Error parsing parameter '{ param_name } '. "
530- f"Must be a valid { type_name } " ,
531- str (e ),
532- )
533-
534- @staticmethod
513+ @classmethod
535514 def resolve_endpoint_params (
536- endpoint : Callable , all_params : dict , body : dict
537- ) -> dict :
538- """Main method for resolving endpoint parameters """
515+ cls , endpoint : Callable , all_params : dict [ str , Any ], body : dict [ str , Any ]
516+ ) -> dict [ str , Any ] :
517+ """Resolves endpoint parameters using Pydantic validation with caching """
539518 sig = inspect .signature (endpoint )
540519 kwargs = {}
520+ model_fields = {}
521+ model_values = {}
522+ param_types = cls ._extract_param_types (sig )
541523
524+ # Process each parameter from the endpoint signature
542525 for name , param in sig .parameters .items ():
543526 annotation = param .annotation
544- is_required = param .default is inspect .Parameter .empty
545527
546- if isinstance (annotation , type ) and issubclass (annotation , BaseModel ):
547- kwargs [name ] = BaseRouter ._resolve_pydantic_model (
548- annotation , body if body else all_params , name
528+ # Handle Pydantic model parameters
529+ if cls ._is_pydantic_model (annotation ):
530+ kwargs [name ] = cls ._process_pydantic_param (
531+ name , annotation , body if body else all_params
549532 )
550533 continue
551534
535+ # Handle missing parameters
552536 if name not in all_params :
553- if is_required :
554- raise BadRequestError (f"Missing required parameter: '{ name } '" )
555- kwargs [name ] = param .default
537+ kwargs [name ] = cls ._handle_missing_param (name , param )
556538 continue
557539
558- origin = typing .get_origin (annotation )
540+ # Collect fields for dynamic model validation
541+ model_fields [name ] = (
542+ annotation ,
543+ param .default if param .default is not inspect .Parameter .empty else ...,
544+ )
545+ model_values [name ] = all_params [name ]
559546
560- if origin is list and not isinstance (all_params [name ], list ):
561- kwargs [name ] = BaseRouter ._resolve_list_param (
562- name , all_params [name ], annotation
563- )
564- else :
565- kwargs [name ] = BaseRouter ._resolve_scalar_param (
566- name , all_params [name ], annotation
567- )
547+ # Validate collected parameters using dynamic model
548+ if model_fields :
549+ validated_params = cls ._validate_with_dynamic_model (
550+ endpoint , model_fields , model_values , param_types
551+ )
552+ kwargs .update (validated_params )
568553
569554 return kwargs
570555
556+ @classmethod
557+ def _extract_param_types (cls , sig : inspect .Signature ) -> dict [str , Any ]:
558+ """Extract parameter types from signature"""
559+ return {name : param .annotation for name , param in sig .parameters .items ()}
560+
561+ @classmethod
562+ def _process_pydantic_param (
563+ cls , name : str , model_class : type [BaseModel ], params : dict [str , Any ]
564+ ) -> BaseModel :
565+ """Process a parameter that's a Pydantic model"""
566+ try :
567+ return cls ._resolve_pydantic_model (model_class , params , name )
568+ except Exception as e :
569+ raise ValidationError (f"Validation error for parameter '{ name } '" , str (e ))
570+
571+ @staticmethod
572+ def _handle_missing_param (name : str , param : inspect .Parameter ) -> Any :
573+ """Handle parameters not provided in the request"""
574+ if param .default is inspect .Parameter .empty :
575+ raise BadRequestError (f"Missing required parameter: '{ name } '" )
576+ return param .default
577+
578+ @classmethod
579+ def _validate_with_dynamic_model (
580+ cls ,
581+ endpoint : Callable ,
582+ model_fields : dict ,
583+ model_values : dict ,
584+ param_types : dict [str , Any ],
585+ ) -> dict [str , Any ] | None :
586+ """Validate parameters using a dynamically created Pydantic model"""
587+ # Create cache key for the dynamic model
588+ cache_key = frozenset (
589+ (endpoint .__module__ , endpoint .__name__ , name , str (ann ))
590+ for name , (ann , _ ) in model_fields .items ()
591+ )
592+
593+ # Get or create the model class
594+ if cache_key not in cls ._param_model_cache :
595+ cls ._param_model_cache [cache_key ] = create_model (
596+ "ParamsModel" , ** model_fields
597+ )
598+
599+ try :
600+ # Validate parameters against the model
601+ validated = cls ._param_model_cache [cache_key ](** model_values )
602+ return validated .model_dump ()
603+ except PydanticValidationError as e :
604+ raise cls ._handle_validation_error (e , param_types )
605+
606+ @staticmethod
607+ def _handle_validation_error (
608+ error : PydanticValidationError , param_types : dict [str , Any ]
609+ ) -> BadRequestError :
610+ """Handle validation errors with detailed messages"""
611+ exc = BadRequestError ("Parameter validation failed" , str (error ))
612+ errors = error .errors ()
613+ if errors :
614+ error_info = errors [0 ]
615+ loc = error_info .get ("loc" , [])
616+ if loc and len (loc ) > 0 :
617+ param_name = str (loc [0 ])
618+ if param_name in param_types :
619+ type_name = getattr (param_types [param_name ], "__name__" , "value" )
620+ exc = BadRequestError (
621+ f"Error parsing parameter '{ param_name } '. "
622+ f"Must be a valid { type_name } " ,
623+ str (error_info .get ("msg" , "" )),
624+ )
625+
626+ return exc
627+
571628 @property
572629 def openapi (self ) -> dict :
573630 if self ._openapi_schema is None :
0 commit comments