|
4 | 4 | tools, handle tool execution, and manage tool conversion between the two formats.
|
5 | 5 | """
|
6 | 6 |
|
7 |
| -from typing import Any, cast, get_args |
| 7 | +from typing import Any, cast, get_args, Annotated |
8 | 8 |
|
9 | 9 | from langchain_core.tools import (
|
10 | 10 | BaseTool,
|
|
20 | 20 | from mcp.types import Tool as MCPTool
|
21 | 21 | from pydantic import BaseModel, create_model
|
22 | 22 |
|
| 23 | + |
23 | 24 | from langchain_mcp_adapters.sessions import Connection, create_session
|
24 | 25 |
|
25 | 26 | NonTextContent = ImageContent | EmbeddedResource
|
@@ -135,11 +136,34 @@ async def call_tool(
|
135 | 136 | else:
|
136 | 137 | call_tool_result = await session.call_tool(tool.name, arguments)
|
137 | 138 | return _convert_call_tool_result(call_tool_result)
|
| 139 | + |
| 140 | + type_map = { |
| 141 | + 'integer': int, |
| 142 | + 'float': float, |
| 143 | + 'string': str, |
| 144 | + 'bool': bool, |
| 145 | + 'object': dict, |
| 146 | + 'bytes': bytes |
| 147 | + } |
| 148 | + |
| 149 | + args = tool.inputSchema |
| 150 | + |
| 151 | + model_fields = {} |
| 152 | + injected_state = tool.annotations.model_extra.get('injected_state') |
| 153 | + if injected_state: |
| 154 | + from langgraph.prebuilt import InjectedState |
| 155 | + for field, props in args['properties'].items(): |
| 156 | + field_type = type_map.get(props['type'], dict) |
| 157 | + if field == injected_state: |
| 158 | + field_type = Annotated[dict, InjectedState] |
| 159 | + model_fields[field] = field_type |
| 160 | + |
| 161 | + args_schema = create_model(tool.name, **{k: (v, ...) for k, v in model_fields.items()}) |
138 | 162 |
|
139 | 163 | return StructuredTool(
|
140 | 164 | name=tool.name,
|
141 | 165 | description=tool.description or "",
|
142 |
| - args_schema=tool.inputSchema, |
| 166 | + args_schema=args_schema, |
143 | 167 | coroutine=call_tool,
|
144 | 168 | response_format="content_and_artifact",
|
145 | 169 | metadata=tool.annotations.model_dump() if tool.annotations else None,
|
|
0 commit comments