@@ -44,7 +44,15 @@ class CancelledElicitation(BaseModel):
4444def _validate_elicitation_schema (schema : type [BaseModel ]) -> None :
4545 """Validate that a Pydantic model only contains primitive field types."""
4646 for field_name , field_info in schema .model_fields .items ():
47- if not _is_primitive_field (field_info ):
47+ annotation = field_info .annotation
48+
49+ if annotation is types .NoneType :
50+ return True
51+ elif _is_primitive_field (annotation ):
52+ continue
53+ elif _is_string_sequence (annotation ):
54+ continue
55+ else :
4856 raise TypeError (
4957 f"Elicitation schema field '{ field_name } ' must be a primitive type "
5058 f"{ _ELICITATION_PRIMITIVE_TYPES } , a sequence of strings (list[str], etc.), "
@@ -63,33 +71,18 @@ def _is_string_sequence(annotation: type) -> bool:
6371 return False
6472
6573
66- def _is_primitive_field (field_info : FieldInfo ) -> bool :
74+ def _is_primitive_field (annotation : type ) -> bool :
6775 """Check if a field is a primitive type allowed in elicitation schemas."""
68- annotation = field_info .annotation
69-
70- # Handle None type
71- if annotation is types .NoneType :
72- return True
73-
7476 # Handle basic primitive types
7577 if annotation in _ELICITATION_PRIMITIVE_TYPES :
7678 return True
7779
78- # Handle string sequences for multi-select enums
79- if annotation is not None and _is_string_sequence (annotation ):
80- return True
81-
8280 # Handle Union types
8381 origin = get_origin (annotation )
8482 if origin is Union or origin is types .UnionType :
8583 args = get_args (annotation )
86- # All args must be primitive types, None, or string sequences
87- return all (
88- arg is types .NoneType
89- or arg in _ELICITATION_PRIMITIVE_TYPES
90- or (arg is not None and _is_string_sequence (arg ))
91- for arg in args
92- )
84+ # All args must be primitive types or None
85+ return all (arg is types .NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args )
9386
9487 return False
9588
0 commit comments