|
4 | 4 | tools, handle tool execution, and manage tool conversion between the two formats.
|
5 | 5 | """
|
6 | 6 |
|
7 |
| -from typing import Any, cast, get_args, Annotated |
| 7 | +from typing import Any, cast, get_args, Annotated, List, Union, Literal |
8 | 8 |
|
9 | 9 | from langchain_core.tools import (
|
10 | 10 | BaseTool,
|
@@ -138,27 +138,49 @@ async def call_tool(
|
138 | 138 | return _convert_call_tool_result(call_tool_result)
|
139 | 139 |
|
140 | 140 | type_map = {
|
| 141 | + 'null': None, |
141 | 142 | 'integer': int,
|
142 | 143 | 'float': float,
|
143 | 144 | 'string': str,
|
144 | 145 | 'bool': bool,
|
145 | 146 | 'object': dict,
|
146 |
| - 'bytes': bytes |
| 147 | + 'bytes': bytes, |
147 | 148 | }
|
148 | 149 |
|
149 |
| - args = tool.inputSchema |
150 |
| - |
151 |
| - model_fields = {} |
| 150 | + def _parse_model_fields(args, injected_state): |
| 151 | + model_fields = {} |
| 152 | + |
| 153 | + def _parse_field(props): |
| 154 | + if 'anyOf' in props.keys(): |
| 155 | + types = tuple(_parse_field(p) for p in props['anyOf']) |
| 156 | + return Union[types] |
| 157 | + if 'enum' in props.keys(): |
| 158 | + return Literal[tuple(props['enum'])] |
| 159 | + if props['type'] == 'array': |
| 160 | + items = props['items'] |
| 161 | + return List[tuple(_parse_field(items))] |
| 162 | + else: |
| 163 | + return type_map.get(props['type'], dict) |
| 164 | + |
| 165 | + for field, props in args['properties'].items(): |
| 166 | + if field == injected_state: |
| 167 | + field_type = Annotated[dict, InjectedState] |
| 168 | + else: |
| 169 | + field_type = _parse_field(props) |
| 170 | + if 'default' in props.keys(): |
| 171 | + default = props['default'] |
| 172 | + model_fields[field] = (field_type, default) |
| 173 | + else: |
| 174 | + model_fields[field] = (field_type, ...) |
| 175 | + return model_fields |
| 176 | + |
| 177 | + args = tool.inputSchema |
152 | 178 | injected_state = tool.annotations.model_extra.get('injected_state')
|
153 | 179 | if injected_state:
|
154 | 180 | from langgraph.prebuilt import InjectedState
|
155 |
| - for field, props in args['properties'].items(): |
156 |
| - field_type = type_map.get(props['type'], dict) |
157 |
| - if field == injected_state: |
158 |
| - field_type = Annotated[dict, InjectedState] |
159 |
| - model_fields[field] = field_type |
160 |
| - |
161 |
| - args_schema = create_model(tool.name, **{k: (v, ...) for k, v in model_fields.items()}) |
| 181 | + model_fields = _parse_model_fields(args, injected_state) |
| 182 | + |
| 183 | + args_schema = create_model(tool.name, **{k: v for k, v in model_fields.items()}) |
162 | 184 |
|
163 | 185 | return StructuredTool(
|
164 | 186 | name=tool.name,
|
|
0 commit comments