@@ -113,6 +113,7 @@ class _PropertyInfo:
113113 type_ : str
114114
115115 root : Any = None
116+ anno : pydantic .Field = None
116117 config : dict [str , Any ] = dataclasses .field (default_factory = dict )
117118 properties : dict [str , _PropertyInfo ] = dataclasses .field (
118119 default_factory = lambda : collections .defaultdict (lambda : _ClassInfo ._PropertyInfo ())
@@ -257,10 +258,17 @@ def collapse(cls, type_name, items: list["_ClassInfo"]) -> type[BaseModel]:
257258 r = [i .model () for i in items ]
258259
259260 if len (r ) > 1 :
260- ru : object = Union [tuple (r )]
261+ # Annotations are collected, last element has all of them
262+ # ru: object = Annotated[Union[tuple(r)], items[-1].anno]
263+ v = list ()
264+ for i in range (len (items )):
265+ v .append (Annotated [r [i ], items [i ].anno ])
266+ ru = Annotated [Union [tuple (v )], Field (default = None )]
261267 m : type [RootModel ] = create_model (type_name , __base__ = (ConfiguredRootModel [ru ],), __module__ = me .__name__ )
262268 elif len (r ) == 1 :
263269 m : type [BaseModel ] = cast (type [BaseModel ], r [0 ])
270+ if items [0 ].anno :
271+ m = Annotated [m , items [0 ].anno ]
264272 if not is_basemodel (m ):
265273 m = create_model (type_name , __base__ = (ConfiguredRootModel [m ],), __module__ = me .__name__ )
266274 else : # == 0
@@ -299,7 +307,8 @@ def from_schema(
299307 r : list [_ClassInfo ] = list ()
300308
301309 for _type in Model .types (schema ):
302- r .append (Model .createClassInfo (schema , _type , schemanames , discriminators , extra ))
310+ args = dict ()
311+ r .append (Model .createClassInfo (schema , _type , schemanames , discriminators , extra , args ))
303312
304313 m = _ClassInfo .collapse (schema ._get_identity ("L8" ), r )
305314
@@ -313,6 +322,7 @@ def createClassInfo(
313322 schemanames : list [str ],
314323 discriminators : list ["DiscriminatorType" ],
315324 extra : list ["SchemaType" ] | None ,
325+ args : dict [str , Any ] = None ,
316326 ) -> _ClassInfo :
317327 from . import v20 , v30 , v31
318328
@@ -326,9 +336,8 @@ def createClassInfo(
326336 for primitive types the anyOf/oneOf is taken care of in Model.createAnnotation
327337 """
328338 if typing .get_origin (_t := Model .createAnnotation (schema , _type = _type )) != Literal :
329- classinfo .root = Annotated [_t , Model .createField (schema , _type = _type , args = None )]
330- else :
331- classinfo .root = _t
339+ classinfo .anno = Model .createField (schema , _type = _type , args = args )
340+ classinfo .root = _t
332341 elif _type == "array" :
333342 """anyOf/oneOf is taken care in in createAnnotation"""
334343 classinfo .root = Model .createAnnotation (schema , _type = "array" )
@@ -355,9 +364,8 @@ def createClassInfo(
355364 if _type in Model .types (i )
356365 )
357366 if schema .discriminator and schema .discriminator .mapping :
358- classinfo .root = Annotated [
359- Union [t ], Field (discriminator = Model .nameof (schema .discriminator .propertyName ))
360- ]
367+ classinfo .root = Union [t ]
368+ classinfo .anno = Field (discriminator = Model .nameof (schema .discriminator .propertyName ))
361369 else :
362370 if len (t ):
363371 classinfo .root = Union [t ]
@@ -373,9 +381,8 @@ def createClassInfo(
373381 if _type in Model .types (i )
374382 )
375383 if schema .discriminator and schema .discriminator .mapping :
376- classinfo .root = Annotated [
377- Union [t ], Field (discriminator = Model .nameof (schema .discriminator .propertyName ))
378- ]
384+ classinfo .root = Union [t ]
385+ classinfo .anno = Field (discriminator = Model .nameof (schema .discriminator .propertyName ))
379386 else :
380387 if len (t ):
381388 classinfo .root = Union [t ]
@@ -451,8 +458,8 @@ def validate_patternProperties(self_):
451458 """
452459 assert isinstance (schema , v20 .Schema )
453460 schema_ = v20 .Schema (type = "string" , format = "binary" )
454- _t = Model .createAnnotation (schema_ , _type = "string" )
455- classinfo .root = Annotated [ _t , Model .createField (schema_ , _type = "string" , args = None )]
461+ classinfo . root = Model .createAnnotation (schema_ , _type = "string" )
462+ classinfo .anno = Model .createField (schema_ , _type = "string" , args = None )
456463 else :
457464 raise ValueError (_type )
458465
@@ -716,7 +723,7 @@ def booleanFalse(schema: Optional["SchemaType"]) -> bool:
716723 raise ValueError (schema )
717724
718725 @staticmethod
719- def createField (schema : "SchemaType" , _type = None , args = None ):
726+ def createField (schema : "SchemaType" , _type = None , args = None ) -> Field :
720727 if args is None :
721728 args = dict (default = getattr (schema , "default" , None ))
722729
0 commit comments