@@ -88,22 +88,22 @@ def get_base_model(
8888 type_origin = get_origin (pydantic_class )
8989 key_type_origin = None
9090
91- if type_origin == list :
91+ if type_origin is list :
9292 item_types = get_args (pydantic_class )
9393 if len (item_types ) > 1 :
9494 raise ValueError ("List data type must have exactly one child." )
9595 item_type = safe_get (item_types , 0 )
9696 if not item_type or not issubclass (item_type , BaseModel ):
9797 raise ValueError ("List item type must be a Pydantic model." )
9898 schema_model = item_type
99- elif type_origin == dict :
99+ elif type_origin is dict :
100100 key_value_types = get_args (pydantic_class )
101101 value_type = safe_get (key_value_types , 1 )
102102 key_type_origin = safe_get (key_value_types , 0 )
103103 if not value_type or not issubclass (value_type , BaseModel ):
104104 raise ValueError ("Dict value type must be a Pydantic model." )
105105 schema_model = value_type
106- elif type_origin == Union :
106+ elif type_origin is Union :
107107 union_members = get_args (pydantic_class )
108108 model_members = list (filter (is_base_model_type , union_members ))
109109 if len (model_members ) > 0 :
@@ -141,7 +141,7 @@ def extract_union_member(
141141 field_model , field_type_origin , key_type_origin = try_get_base_model (member )
142142 if not field_model :
143143 return member
144- if field_type_origin == Union :
144+ if field_type_origin is Union :
145145 union_members = get_args (field_model )
146146 extracted_union_members = []
147147 for m in union_members :
@@ -157,9 +157,9 @@ def extract_union_member(
157157 json_path = json_path ,
158158 aliases = aliases ,
159159 )
160- if field_type_origin == list :
160+ if field_type_origin is list :
161161 return List [extracted_field_model ]
162- elif field_type_origin == dict :
162+ elif field_type_origin is dict :
163163 return Dict [key_type_origin , extracted_field_model ] # type: ignore
164164 return extracted_field_model
165165
@@ -231,7 +231,7 @@ def extract_validators(
231231 field .annotation
232232 )
233233 if field_model :
234- if field_type_origin == Union :
234+ if field_type_origin is Union :
235235 union_members = list (get_args (field_model ))
236236 extracted_union_members = []
237237 for m in union_members :
@@ -254,11 +254,11 @@ def extract_validators(
254254 json_path = field_path ,
255255 aliases = alias_paths ,
256256 )
257- if field_type_origin == list :
257+ if field_type_origin is list :
258258 model .model_fields [field_name ].annotation = List [
259259 extracted_field_model
260260 ]
261- elif field_type_origin == dict :
261+ elif field_type_origin is dict :
262262 model .model_fields [field_name ].annotation = Dict [
263263 key_type_origin , extracted_field_model # type: ignore
264264 ]
@@ -276,7 +276,7 @@ def pydantic_to_json_schema(
276276 json_schema = pydantic_class .model_json_schema ()
277277 json_schema ["title" ] = pydantic_class .__name__
278278
279- if type_origin == list :
279+ if type_origin is list :
280280 json_schema = {
281281 "title" : f"Array<{ json_schema .get ('title' )} >" ,
282282 "type" : "array" ,
@@ -294,7 +294,7 @@ def pydantic_model_to_schema(
294294 schema_model , type_origin , _key_type_origin = get_base_model (pydantic_class )
295295
296296 processed_schema .output_type = (
297- OutputTypes .LIST if type_origin == list else OutputTypes .DICT
297+ OutputTypes .LIST if type_origin is list else OutputTypes .DICT
298298 )
299299
300300 model = extract_validators (schema_model , processed_schema , "$" )
0 commit comments