Skip to content

Commit ef804bd

Browse files
committed
update type comparisons
1 parent 34dd096 commit ef804bd

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

guardrails/schema/pydantic_schema.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,22 +88,22 @@ def get_base_model(
8888
type_origin = get_origin(pydantic_class)
8989
key_type_origin = None
9090

91-
if type_origin == list:
91+
if type_origin is list:
9292
item_types = get_args(pydantic_class)
9393
if len(item_types) > 1:
9494
raise ValueError("List data type must have exactly one child.")
9595
item_type = safe_get(item_types, 0)
9696
if not item_type or not issubclass(item_type, BaseModel):
9797
raise ValueError("List item type must be a Pydantic model.")
9898
schema_model = item_type
99-
elif type_origin == dict:
99+
elif type_origin is dict:
100100
key_value_types = get_args(pydantic_class)
101101
value_type = safe_get(key_value_types, 1)
102102
key_type_origin = safe_get(key_value_types, 0)
103103
if not value_type or not issubclass(value_type, BaseModel):
104104
raise ValueError("Dict value type must be a Pydantic model.")
105105
schema_model = value_type
106-
elif type_origin == Union:
106+
elif type_origin is Union:
107107
union_members = get_args(pydantic_class)
108108
model_members = list(filter(is_base_model_type, union_members))
109109
if len(model_members) > 0:
@@ -141,7 +141,7 @@ def extract_union_member(
141141
field_model, field_type_origin, key_type_origin = try_get_base_model(member)
142142
if not field_model:
143143
return member
144-
if field_type_origin == Union:
144+
if field_type_origin is Union:
145145
union_members = get_args(field_model)
146146
extracted_union_members = []
147147
for m in union_members:
@@ -157,9 +157,9 @@ def extract_union_member(
157157
json_path=json_path,
158158
aliases=aliases,
159159
)
160-
if field_type_origin == list:
160+
if field_type_origin is list:
161161
return List[extracted_field_model]
162-
elif field_type_origin == dict:
162+
elif field_type_origin is dict:
163163
return Dict[key_type_origin, extracted_field_model] # type: ignore
164164
return extracted_field_model
165165

@@ -231,7 +231,7 @@ def extract_validators(
231231
field.annotation
232232
)
233233
if field_model:
234-
if field_type_origin == Union:
234+
if field_type_origin is Union:
235235
union_members = list(get_args(field_model))
236236
extracted_union_members = []
237237
for m in union_members:
@@ -254,11 +254,11 @@ def extract_validators(
254254
json_path=field_path,
255255
aliases=alias_paths,
256256
)
257-
if field_type_origin == list:
257+
if field_type_origin is list:
258258
model.model_fields[field_name].annotation = List[
259259
extracted_field_model
260260
]
261-
elif field_type_origin == dict:
261+
elif field_type_origin is dict:
262262
model.model_fields[field_name].annotation = Dict[
263263
key_type_origin, extracted_field_model # type: ignore
264264
]
@@ -276,7 +276,7 @@ def pydantic_to_json_schema(
276276
json_schema = pydantic_class.model_json_schema()
277277
json_schema["title"] = pydantic_class.__name__
278278

279-
if type_origin == list:
279+
if type_origin is list:
280280
json_schema = {
281281
"title": f"Array<{json_schema.get('title')}>",
282282
"type": "array",
@@ -294,7 +294,7 @@ def pydantic_model_to_schema(
294294
schema_model, type_origin, _key_type_origin = get_base_model(pydantic_class)
295295

296296
processed_schema.output_type = (
297-
OutputTypes.LIST if type_origin == list else OutputTypes.DICT
297+
OutputTypes.LIST if type_origin is list else OutputTypes.DICT
298298
)
299299

300300
model = extract_validators(schema_model, processed_schema, "$")

guardrails/utils/pydantic_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def convert_pydantic_model_to_openai_fn(
2828
schema_model = model
2929

3030
type_origin = get_origin(model)
31-
if type_origin == list:
31+
if type_origin is list:
3232
item_types = get_args(model)
3333
if len(item_types) > 1:
3434
raise ValueError("List data type must have exactly one child.")
@@ -41,7 +41,7 @@ def convert_pydantic_model_to_openai_fn(
4141
json_schema = schema_model.model_json_schema()
4242
json_schema["title"] = schema_model.__name__
4343

44-
if type_origin == list:
44+
if type_origin is list:
4545
json_schema = {
4646
"title": f"Array<{json_schema.get('title')}>",
4747
"type": "array",

0 commit comments

Comments
 (0)