@@ -137,6 +137,7 @@ async def call_tool(
137
137
call_tool_result = await session .call_tool (tool .name , arguments )
138
138
return _convert_call_tool_result (call_tool_result )
139
139
140
+ # base types being mapped from JSON
140
141
type_map = {
141
142
'null' : None ,
142
143
'integer' : int ,
@@ -148,6 +149,15 @@ async def call_tool(
148
149
}
149
150
150
151
def _parse_model_fields (args , injected_state ):
152
+ """Parse a JSON field into a Pydantic Field, taking into account injected state
153
+
154
+ :param args: the function parameter schema
155
+ :type args: dict
156
+ :param injected_state: the name of the key used for the InjectedState
157
+ :type injected_state: str
158
+ :return: returns a dict of fields with their pydantic type and default value if any
159
+ :rtype: dict
160
+ """
151
161
model_fields = {}
152
162
153
163
def _parse_field (props ):
@@ -174,9 +184,12 @@ def _parse_field(props):
174
184
model_fields [field ] = (field_type , ...)
175
185
return model_fields
176
186
177
- args = tool .inputSchema
187
+ args = tool .inputSchema
188
+ # check for the `injected_state`` annotation on the MCP tool.
189
+ # The injected_state value is the name of the function parameter used as the injected state
178
190
injected_state = tool .annotations .model_extra .get ('injected_state' )
179
191
if injected_state :
192
+ # import langgraph InjectedState only if we need it
180
193
from langgraph .prebuilt import InjectedState
181
194
model_fields = _parse_model_fields (args , injected_state )
182
195
0 commit comments