|
4 | 4 | tools, handle tool execution, and manage tool conversion between the two formats.
|
5 | 5 | """
|
6 | 6 |
|
7 |
| -from typing import Any, cast, get_args, Annotated, List, Union, Literal |
| 7 | +from typing import Annotated, Any, Literal, Union, cast, get_args |
8 | 8 |
|
9 | 9 | from langchain_core.tools import (
|
10 | 10 | BaseTool,
|
|
20 | 20 | from mcp.types import Tool as MCPTool
|
21 | 21 | from pydantic import BaseModel, create_model
|
22 | 22 |
|
23 |
| - |
24 | 23 | from langchain_mcp_adapters.sessions import Connection, create_session
|
25 | 24 |
|
26 | 25 | NonTextContent = ImageContent | EmbeddedResource
|
@@ -136,66 +135,66 @@ async def call_tool(
|
136 | 135 | else:
|
137 | 136 | call_tool_result = await session.call_tool(tool.name, arguments)
|
138 | 137 | return _convert_call_tool_result(call_tool_result)
|
139 |
| - |
| 138 | + |
140 | 139 | # base types being mapped from JSON
|
141 | 140 | type_map = {
|
142 |
| - 'null': None, |
143 |
| - 'integer': int, |
144 |
| - 'float': float, |
145 |
| - 'string': str, |
146 |
| - 'bool': bool, |
147 |
| - 'object': dict, |
148 |
| - 'bytes': bytes, |
| 141 | + "null": None, |
| 142 | + "integer": int, |
| 143 | + "float": float, |
| 144 | + "string": str, |
| 145 | + "bool": bool, |
| 146 | + "object": dict, |
| 147 | + "bytes": bytes, |
149 | 148 | }
|
150 | 149 |
|
151 |
| - def _parse_model_fields(args, injected_state): |
152 |
| - """Parse a JSON field into a Pydantic Field, taking into account injected state |
| 150 | + def _parse_model_fields(args: dict, injected_state: str | None = None) -> dict: |
| 151 | + """Parse a JSON field into a Pydantic Field, taking into account injected state. |
153 | 152 |
|
154 | 153 | :param args: the function parameter schema
|
155 | 154 | :type args: dict
|
156 | 155 | :param injected_state: the name of the key used for the InjectedState
|
157 | 156 | :type injected_state: str
|
158 |
| - :return: returns a dict of fields with their pydantic type and default value if any |
| 157 | + :return: returns a dict of fields with their pydantic type |
| 158 | + and default value if any |
159 | 159 | :rtype: dict
|
160 | 160 | """
|
161 | 161 | model_fields = {}
|
162 | 162 |
|
163 |
| - def _parse_field(props): |
164 |
| - if 'anyOf' in props.keys(): |
165 |
| - types = tuple(_parse_field(p) for p in props['anyOf']) |
166 |
| - return Union[types] |
167 |
| - if 'enum' in props.keys(): |
168 |
| - return Literal[tuple(props['enum'])] |
169 |
| - if props['type'] == 'array': |
170 |
| - items = props['items'] |
171 |
| - return List[_parse_field(items)] |
172 |
| - else: |
173 |
| - return type_map.get(props['type'], dict) |
174 |
| - |
175 |
| - for field, props in args['properties'].items(): |
| 163 | + def _parse_field(props: dict) -> tuple[type, Any]: |
| 164 | + if "anyOf" in props: |
| 165 | + return Union[tuple(_parse_field(p) for p in props["anyOf"])] # noqa: UP007 |
| 166 | + if "enum" in props: |
| 167 | + return Literal[tuple(props["enum"])] |
| 168 | + if props["type"] == "array": |
| 169 | + items = props["items"] |
| 170 | + return list[_parse_field(items)] |
| 171 | + return type_map.get(props["type"], dict) |
| 172 | + |
| 173 | + for field, props in args["properties"].items(): |
176 | 174 | if field == injected_state:
|
177 | 175 | field_type = Annotated[dict, InjectedState]
|
178 | 176 | else:
|
179 | 177 | field_type = _parse_field(props)
|
180 |
| - if 'default' in props.keys(): |
181 |
| - default = props['default'] |
| 178 | + if "default" in props: |
| 179 | + default = props["default"] |
182 | 180 | model_fields[field] = (field_type, default)
|
183 | 181 | else:
|
184 | 182 | model_fields[field] = (field_type, ...)
|
185 | 183 | return model_fields
|
186 | 184 |
|
187 |
| - args = tool.inputSchema |
188 |
| - # check for the `injected_state`` annotation on the MCP tool. |
189 |
| - # The injected_state value is the name of the function parameter used as the injected state |
190 |
| - injected_state = tool.annotations.model_extra.get('injected_state') |
| 185 | + args = tool.inputSchema |
| 186 | + # check for the `injected_state`` annotation on the MCP tool. |
| 187 | + # The injected_state value is the name of the function parameter used |
| 188 | + # as the injected state |
| 189 | + injected_state = tool.annotations.model_extra.get("injected_state") |
191 | 190 | if injected_state:
|
192 | 191 | # import langgraph InjectedState only if we need it
|
193 | 192 | from langgraph.prebuilt import InjectedState
|
194 | 193 | model_fields = _parse_model_fields(args, injected_state)
|
195 | 194 |
|
196 |
| - # recreate a dynamic model based on the parsed JSON schema, which will be properly parsed |
197 |
| - # with annotations for the InjectedState |
198 |
| - args_schema = create_model(tool.name, **{k: v for k, v in model_fields.items()}) |
| 195 | + # recreate a dynamic model based on the parsed JSON schema, |
| 196 | + # which will be properly parsed with annotations for the InjectedState |
| 197 | + args_schema = create_model(tool.name, **model_fields) |
199 | 198 |
|
200 | 199 | return StructuredTool(
|
201 | 200 | name=tool.name,
|
|
0 commit comments