diff --git a/langchain_mcp_adapters/tools.py b/langchain_mcp_adapters/tools.py index 2cdc6d4..af904a5 100644 --- a/langchain_mcp_adapters/tools.py +++ b/langchain_mcp_adapters/tools.py @@ -4,7 +4,7 @@ tools, handle tool execution, and manage tool conversion between the two formats. """ -from typing import Any, cast, get_args +from typing import Annotated, Any, Literal, Union, cast, get_args from langchain_core.tools import ( BaseTool, @@ -146,6 +146,65 @@ async def call_tool( return _convert_call_tool_result(call_tool_result) + # base types being mapped from JSON + type_map = { + "null": None, + "integer": int, + "float": float, + "string": str, + "bool": bool, + "object": dict, + "bytes": bytes, + } + + def _parse_model_fields(args: dict, injected_state: str | None = None) -> dict: + """Parse a JSON field into a Pydantic Field, taking into account injected state. + + :param args: the function parameter schema + :type args: dict + :param injected_state: the name of the key used for the InjectedState + :type injected_state: str + :return: returns a dict of fields with their pydantic type + and default value if any + :rtype: dict + """ + model_fields = {} + + def _parse_field(props: dict) -> tuple[type, Any]: + if "anyOf" in props: + return Union[tuple(_parse_field(p) for p in props["anyOf"])] # noqa: UP007 + if "enum" in props: + return Literal[tuple(props["enum"])] + if props["type"] == "array": + items = props["items"] + return list[_parse_field(items)] + return type_map.get(props["type"], dict) + + for field, props in args["properties"].items(): + if field == injected_state: + field_type = Annotated[dict, InjectedState] + else: + field_type = _parse_field(props) + if "default" in props: + default = props["default"] + model_fields[field] = (field_type, default) + else: + model_fields[field] = (field_type, ...) + return model_fields + + args = tool.inputSchema + # check for the `injected_state`` annotation on the MCP tool. + # The injected_state value is the name of the function parameter used + # as the injected state + injected_state = tool.annotations.model_extra.get("injected_state") + if injected_state: + # import langgraph InjectedState only if we need it + from langgraph.prebuilt import InjectedState + model_fields = _parse_model_fields(args, injected_state) + + # recreate a dynamic model based on the parsed JSON schema, + # which will be properly parsed with annotations for the InjectedState + args_schema = create_model(tool.name, **model_fields) meta = tool.meta if hasattr(tool, "meta") else None base = tool.annotations.model_dump() if tool.annotations is not None else {} meta = {"_meta": meta} if meta is not None else {} @@ -154,7 +213,7 @@ async def call_tool( return StructuredTool( name=tool.name, description=tool.description or "", - args_schema=tool.inputSchema, + args_schema=args_schema, coroutine=call_tool, response_format="content_and_artifact", metadata=metadata,