Skip to content

Commit e4c64c8

Browse files
committed
better types parsing
1 parent 17f7eee commit e4c64c8

File tree

1 file changed

+34
-12
lines changed

1 file changed

+34
-12
lines changed

langchain_mcp_adapters/tools.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
tools, handle tool execution, and manage tool conversion between the two formats.
55
"""
66

7-
from typing import Any, cast, get_args, Annotated
7+
from typing import Any, cast, get_args, Annotated, List, Union, Literal
88

99
from langchain_core.tools import (
1010
BaseTool,
@@ -138,27 +138,49 @@ async def call_tool(
138138
return _convert_call_tool_result(call_tool_result)
139139

140140
type_map = {
141+
'null': None,
141142
'integer': int,
142143
'float': float,
143144
'string': str,
144145
'bool': bool,
145146
'object': dict,
146-
'bytes': bytes
147+
'bytes': bytes,
147148
}
148149

149-
args = tool.inputSchema
150-
151-
model_fields = {}
150+
def _parse_model_fields(args, injected_state):
151+
model_fields = {}
152+
153+
def _parse_field(props):
154+
if 'anyOf' in props.keys():
155+
types = tuple(_parse_field(p) for p in props['anyOf'])
156+
return Union[types]
157+
if 'enum' in props.keys():
158+
return Literal[tuple(props['enum'])]
159+
if props['type'] == 'array':
160+
items = props['items']
161+
return List[tuple(_parse_field(items))]
162+
else:
163+
return type_map.get(props['type'], dict)
164+
165+
for field, props in args['properties'].items():
166+
if field == injected_state:
167+
field_type = Annotated[dict, InjectedState]
168+
else:
169+
field_type = _parse_field(props)
170+
if 'default' in props.keys():
171+
default = props['default']
172+
model_fields[field] = (field_type, default)
173+
else:
174+
model_fields[field] = (field_type, ...)
175+
return model_fields
176+
177+
args = tool.inputSchema
152178
injected_state = tool.annotations.model_extra.get('injected_state')
153179
if injected_state:
154180
from langgraph.prebuilt import InjectedState
155-
for field, props in args['properties'].items():
156-
field_type = type_map.get(props['type'], dict)
157-
if field == injected_state:
158-
field_type = Annotated[dict, InjectedState]
159-
model_fields[field] = field_type
160-
161-
args_schema = create_model(tool.name, **{k: (v, ...) for k, v in model_fields.items()})
181+
model_fields = _parse_model_fields(args, injected_state)
182+
183+
args_schema = create_model(tool.name, **{k: v for k, v in model_fields.items()})
162184

163185
return StructuredTool(
164186
name=tool.name,

0 commit comments

Comments
 (0)