@@ -875,6 +875,10 @@ def _convert_to_content(
875875 return [TextContent (type = "text" , text = result )]
876876
877877
878+ # Primitive types allowed in elicitation schemas
879+ _ELICITATION_PRIMITIVE_TYPES = (str , int , float , bool )
880+
881+
878882def _validate_elicitation_schema (schema : type [BaseModel ]) -> None :
879883 """Validate that a Pydantic model only contains primitive field types."""
880884 for field_name , field_info in schema .model_fields .items ():
@@ -886,28 +890,24 @@ def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
886890 )
887891
888892
889- # Primitive types allowed in elicitation schemas
890- _ELICITATION_PRIMITIVE_TYPES = (str , int , float , bool )
891-
892-
893893def _is_primitive_field (field_info : FieldInfo ) -> bool :
894894 """Check if a field is a primitive type allowed in elicitation schemas."""
895895 annotation = field_info .annotation
896896
897897 # Handle None type
898- if annotation is type ( None ) :
898+ if annotation is types . NoneType :
899899 return True
900900
901901 # Handle basic primitive types
902902 if annotation in _ELICITATION_PRIMITIVE_TYPES :
903903 return True
904904
905- # Handle Union types (including Optional and Python 3.10+ union syntax)
905+ # Handle Union types
906906 origin = get_origin (annotation )
907- if origin is Union or ( hasattr ( types , 'UnionType' ) and isinstance ( annotation , types .UnionType )) :
907+ if origin is Union or origin is types .UnionType :
908908 args = get_args (annotation )
909909 # All args must be primitive types or None
910- return all (arg is type ( None ) or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args )
910+ return all (arg is types . NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args )
911911
912912 return False
913913
0 commit comments