From 17f7eee0600b825dd0d285de4bc6b31a14cd925d Mon Sep 17 00:00:00 2001 From: Emmanuel Leroy Date: Thu, 14 Aug 2025 16:08:00 -0700 Subject: [PATCH 1/6] InjectedState in MCP tools usage: On the MCP server side, specify the key to map to the state with annotations: @mcp.tool(annotations={"injected_state": "state"}) def tool(a, b, state): pass On the client side, it is handled automatically --- langchain_mcp_adapters/tools.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/langchain_mcp_adapters/tools.py b/langchain_mcp_adapters/tools.py index 0c6fa45..667f8c0 100644 --- a/langchain_mcp_adapters/tools.py +++ b/langchain_mcp_adapters/tools.py @@ -4,7 +4,7 @@ tools, handle tool execution, and manage tool conversion between the two formats. """ -from typing import Any, cast, get_args +from typing import Any, cast, get_args, Annotated from langchain_core.tools import ( BaseTool, @@ -20,6 +20,7 @@ from mcp.types import Tool as MCPTool from pydantic import BaseModel, create_model + from langchain_mcp_adapters.sessions import Connection, create_session NonTextContent = ImageContent | EmbeddedResource @@ -135,11 +136,34 @@ async def call_tool( else: call_tool_result = await session.call_tool(tool.name, arguments) return _convert_call_tool_result(call_tool_result) + + type_map = { + 'integer': int, + 'float': float, + 'string': str, + 'bool': bool, + 'object': dict, + 'bytes': bytes + } + + args = tool.inputSchema + + model_fields = {} + injected_state = tool.annotations.model_extra.get('injected_state') + if injected_state: + from langgraph.prebuilt import InjectedState + for field, props in args['properties'].items(): + field_type = type_map.get(props['type'], dict) + if field == injected_state: + field_type = Annotated[dict, InjectedState] + model_fields[field] = field_type + + args_schema = create_model(tool.name, **{k: (v, ...) for k, v in model_fields.items()}) return StructuredTool( name=tool.name, description=tool.description or "", - args_schema=tool.inputSchema, + args_schema=args_schema, coroutine=call_tool, response_format="content_and_artifact", metadata=tool.annotations.model_dump() if tool.annotations else None, From e4c64c8e973fad638497adee2a1c2fbc5d23c56f Mon Sep 17 00:00:00 2001 From: Emmanuel Leroy Date: Fri, 15 Aug 2025 14:03:59 -0700 Subject: [PATCH 2/6] better types parsing --- langchain_mcp_adapters/tools.py | 46 ++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 12 deletions(-) diff --git a/langchain_mcp_adapters/tools.py b/langchain_mcp_adapters/tools.py index 667f8c0..66c4564 100644 --- a/langchain_mcp_adapters/tools.py +++ b/langchain_mcp_adapters/tools.py @@ -4,7 +4,7 @@ tools, handle tool execution, and manage tool conversion between the two formats. """ -from typing import Any, cast, get_args, Annotated +from typing import Any, cast, get_args, Annotated, List, Union, Literal from langchain_core.tools import ( BaseTool, @@ -138,27 +138,49 @@ async def call_tool( return _convert_call_tool_result(call_tool_result) type_map = { + 'null': None, 'integer': int, 'float': float, 'string': str, 'bool': bool, 'object': dict, - 'bytes': bytes + 'bytes': bytes, } - args = tool.inputSchema - - model_fields = {} + def _parse_model_fields(args, injected_state): + model_fields = {} + + def _parse_field(props): + if 'anyOf' in props.keys(): + types = tuple(_parse_field(p) for p in props['anyOf']) + return Union[types] + if 'enum' in props.keys(): + return Literal[tuple(props['enum'])] + if props['type'] == 'array': + items = props['items'] + return List[tuple(_parse_field(items))] + else: + return type_map.get(props['type'], dict) + + for field, props in args['properties'].items(): + if field == injected_state: + field_type = Annotated[dict, InjectedState] + else: + field_type = _parse_field(props) + if 'default' in props.keys(): + default = props['default'] + model_fields[field] = (field_type, default) + else: + model_fields[field] = (field_type, ...) + return model_fields + + args = tool.inputSchema injected_state = tool.annotations.model_extra.get('injected_state') if injected_state: from langgraph.prebuilt import InjectedState - for field, props in args['properties'].items(): - field_type = type_map.get(props['type'], dict) - if field == injected_state: - field_type = Annotated[dict, InjectedState] - model_fields[field] = field_type - - args_schema = create_model(tool.name, **{k: (v, ...) for k, v in model_fields.items()}) + model_fields = _parse_model_fields(args, injected_state) + + args_schema = create_model(tool.name, **{k: v for k, v in model_fields.items()}) return StructuredTool( name=tool.name, From c1225c18c8bc7f9f14d9c354a0f8a0145fd395de Mon Sep 17 00:00:00 2001 From: Emmanuel Leroy Date: Fri, 15 Aug 2025 14:10:19 -0700 Subject: [PATCH 3/6] more comments to explain the logic --- langchain_mcp_adapters/tools.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/langchain_mcp_adapters/tools.py b/langchain_mcp_adapters/tools.py index 66c4564..7c4a423 100644 --- a/langchain_mcp_adapters/tools.py +++ b/langchain_mcp_adapters/tools.py @@ -137,6 +137,7 @@ async def call_tool( call_tool_result = await session.call_tool(tool.name, arguments) return _convert_call_tool_result(call_tool_result) + # base types being mapped from JSON type_map = { 'null': None, 'integer': int, @@ -148,6 +149,15 @@ async def call_tool( } def _parse_model_fields(args, injected_state): + """Parse a JSON field into a Pydantic Field, taking into account injected state + + :param args: the function parameter schema + :type args: dict + :param injected_state: the name of the key used for the InjectedState + :type injected_state: str + :return: returns a dict of fields with their pydantic type and default value if any + :rtype: dict + """ model_fields = {} def _parse_field(props): @@ -174,9 +184,12 @@ def _parse_field(props): model_fields[field] = (field_type, ...) return model_fields - args = tool.inputSchema + args = tool.inputSchema + # check for the `injected_state`` annotation on the MCP tool. + # The injected_state value is the name of the function parameter used as the injected state injected_state = tool.annotations.model_extra.get('injected_state') if injected_state: + # import langgraph InjectedState only if we need it from langgraph.prebuilt import InjectedState model_fields = _parse_model_fields(args, injected_state) From 94bc9ebde4a8ea62c2d8c14388b389f806e7ffd3 Mon Sep 17 00:00:00 2001 From: Emmanuel Leroy Date: Fri, 15 Aug 2025 14:11:46 -0700 Subject: [PATCH 4/6] more comments --- langchain_mcp_adapters/tools.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/langchain_mcp_adapters/tools.py b/langchain_mcp_adapters/tools.py index 7c4a423..eca02a6 100644 --- a/langchain_mcp_adapters/tools.py +++ b/langchain_mcp_adapters/tools.py @@ -193,6 +193,8 @@ def _parse_field(props): from langgraph.prebuilt import InjectedState model_fields = _parse_model_fields(args, injected_state) + # recreate a dynamic model based on the parsed JSON schema, which will be properly parsed + # with annotations for the InjectedState args_schema = create_model(tool.name, **{k: v for k, v in model_fields.items()}) return StructuredTool( From f9ec252aec4c54a0041836356596ef51d6f59560 Mon Sep 17 00:00:00 2001 From: Emmanuel Leroy Date: Fri, 15 Aug 2025 14:24:56 -0700 Subject: [PATCH 5/6] fix List type parsing --- langchain_mcp_adapters/tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain_mcp_adapters/tools.py b/langchain_mcp_adapters/tools.py index eca02a6..31c5d01 100644 --- a/langchain_mcp_adapters/tools.py +++ b/langchain_mcp_adapters/tools.py @@ -168,7 +168,7 @@ def _parse_field(props): return Literal[tuple(props['enum'])] if props['type'] == 'array': items = props['items'] - return List[tuple(_parse_field(items))] + return List[_parse_field(items)] else: return type_map.get(props['type'], dict) From faa300aaebbd4911933b0d43e6bc1b769ceb05fb Mon Sep 17 00:00:00 2001 From: Emmanuel Leroy Date: Fri, 15 Aug 2025 14:37:36 -0700 Subject: [PATCH 6/6] lint / format --- langchain_mcp_adapters/tools.py | 69 ++++++++++++++++----------------- 1 file changed, 34 insertions(+), 35 deletions(-) diff --git a/langchain_mcp_adapters/tools.py b/langchain_mcp_adapters/tools.py index 31c5d01..d5f96e8 100644 --- a/langchain_mcp_adapters/tools.py +++ b/langchain_mcp_adapters/tools.py @@ -4,7 +4,7 @@ tools, handle tool execution, and manage tool conversion between the two formats. """ -from typing import Any, cast, get_args, Annotated, List, Union, Literal +from typing import Annotated, Any, Literal, Union, cast, get_args from langchain_core.tools import ( BaseTool, @@ -20,7 +20,6 @@ from mcp.types import Tool as MCPTool from pydantic import BaseModel, create_model - from langchain_mcp_adapters.sessions import Connection, create_session NonTextContent = ImageContent | EmbeddedResource @@ -136,66 +135,66 @@ async def call_tool( else: call_tool_result = await session.call_tool(tool.name, arguments) return _convert_call_tool_result(call_tool_result) - + # base types being mapped from JSON type_map = { - 'null': None, - 'integer': int, - 'float': float, - 'string': str, - 'bool': bool, - 'object': dict, - 'bytes': bytes, + "null": None, + "integer": int, + "float": float, + "string": str, + "bool": bool, + "object": dict, + "bytes": bytes, } - def _parse_model_fields(args, injected_state): - """Parse a JSON field into a Pydantic Field, taking into account injected state + def _parse_model_fields(args: dict, injected_state: str | None = None) -> dict: + """Parse a JSON field into a Pydantic Field, taking into account injected state. :param args: the function parameter schema :type args: dict :param injected_state: the name of the key used for the InjectedState :type injected_state: str - :return: returns a dict of fields with their pydantic type and default value if any + :return: returns a dict of fields with their pydantic type + and default value if any :rtype: dict """ model_fields = {} - def _parse_field(props): - if 'anyOf' in props.keys(): - types = tuple(_parse_field(p) for p in props['anyOf']) - return Union[types] - if 'enum' in props.keys(): - return Literal[tuple(props['enum'])] - if props['type'] == 'array': - items = props['items'] - return List[_parse_field(items)] - else: - return type_map.get(props['type'], dict) - - for field, props in args['properties'].items(): + def _parse_field(props: dict) -> tuple[type, Any]: + if "anyOf" in props: + return Union[tuple(_parse_field(p) for p in props["anyOf"])] # noqa: UP007 + if "enum" in props: + return Literal[tuple(props["enum"])] + if props["type"] == "array": + items = props["items"] + return list[_parse_field(items)] + return type_map.get(props["type"], dict) + + for field, props in args["properties"].items(): if field == injected_state: field_type = Annotated[dict, InjectedState] else: field_type = _parse_field(props) - if 'default' in props.keys(): - default = props['default'] + if "default" in props: + default = props["default"] model_fields[field] = (field_type, default) else: model_fields[field] = (field_type, ...) return model_fields - args = tool.inputSchema - # check for the `injected_state`` annotation on the MCP tool. - # The injected_state value is the name of the function parameter used as the injected state - injected_state = tool.annotations.model_extra.get('injected_state') + args = tool.inputSchema + # check for the `injected_state`` annotation on the MCP tool. + # The injected_state value is the name of the function parameter used + # as the injected state + injected_state = tool.annotations.model_extra.get("injected_state") if injected_state: # import langgraph InjectedState only if we need it from langgraph.prebuilt import InjectedState model_fields = _parse_model_fields(args, injected_state) - # recreate a dynamic model based on the parsed JSON schema, which will be properly parsed - # with annotations for the InjectedState - args_schema = create_model(tool.name, **{k: v for k, v in model_fields.items()}) + # recreate a dynamic model based on the parsed JSON schema, + # which will be properly parsed with annotations for the InjectedState + args_schema = create_model(tool.name, **model_fields) return StructuredTool( name=tool.name,