Skip to content

Commit 412aa44

Browse files
Tapan Chughfelixweinberger
authored andcommitted
cleanup changes a bit
1 parent c8684d8 commit 412aa44

File tree

1 file changed

+12
-19
lines changed

1 file changed

+12
-19
lines changed

src/mcp/server/elicitation.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,15 @@ class CancelledElicitation(BaseModel):
4444
def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
4545
"""Validate that a Pydantic model only contains primitive field types."""
4646
for field_name, field_info in schema.model_fields.items():
47-
if not _is_primitive_field(field_info):
47+
annotation = field_info.annotation
48+
49+
if annotation is types.NoneType:
50+
return True
51+
elif _is_primitive_field(annotation):
52+
continue
53+
elif _is_string_sequence(annotation):
54+
continue
55+
else:
4856
raise TypeError(
4957
f"Elicitation schema field '{field_name}' must be a primitive type "
5058
f"{_ELICITATION_PRIMITIVE_TYPES}, a sequence of strings (list[str], etc.), "
@@ -63,33 +71,18 @@ def _is_string_sequence(annotation: type) -> bool:
6371
return False
6472

6573

66-
def _is_primitive_field(field_info: FieldInfo) -> bool:
74+
def _is_primitive_field(annotation: type) -> bool:
6775
"""Check if a field is a primitive type allowed in elicitation schemas."""
68-
annotation = field_info.annotation
69-
70-
# Handle None type
71-
if annotation is types.NoneType:
72-
return True
73-
7476
# Handle basic primitive types
7577
if annotation in _ELICITATION_PRIMITIVE_TYPES:
7678
return True
7779

78-
# Handle string sequences for multi-select enums
79-
if annotation is not None and _is_string_sequence(annotation):
80-
return True
81-
8280
# Handle Union types
8381
origin = get_origin(annotation)
8482
if origin is Union or origin is types.UnionType:
8583
args = get_args(annotation)
86-
# All args must be primitive types, None, or string sequences
87-
return all(
88-
arg is types.NoneType
89-
or arg in _ELICITATION_PRIMITIVE_TYPES
90-
or (arg is not None and _is_string_sequence(arg))
91-
for arg in args
92-
)
84+
# All args must be primitive types or None
85+
return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args)
9386

9487
return False
9588

0 commit comments

Comments
 (0)