Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 61 additions & 2 deletions langchain_mcp_adapters/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
tools, handle tool execution, and manage tool conversion between the two formats.
"""

from typing import Any, cast, get_args
from typing import Annotated, Any, Literal, Union, cast, get_args

from langchain_core.tools import (
BaseTool,
Expand Down Expand Up @@ -146,6 +146,65 @@ async def call_tool(

return _convert_call_tool_result(call_tool_result)

# base types being mapped from JSON
type_map = {
"null": None,
"integer": int,
"float": float,
"string": str,
"bool": bool,
"object": dict,
"bytes": bytes,
}

def _parse_model_fields(args: dict, injected_state: str | None = None) -> dict:
"""Parse a JSON field into a Pydantic Field, taking into account injected state.

:param args: the function parameter schema
:type args: dict
:param injected_state: the name of the key used for the InjectedState
:type injected_state: str
:return: returns a dict of fields with their pydantic type
and default value if any
:rtype: dict
"""
model_fields = {}

def _parse_field(props: dict) -> tuple[type, Any]:
if "anyOf" in props:
return Union[tuple(_parse_field(p) for p in props["anyOf"])] # noqa: UP007
if "enum" in props:
return Literal[tuple(props["enum"])]
if props["type"] == "array":
items = props["items"]
return list[_parse_field(items)]
return type_map.get(props["type"], dict)

for field, props in args["properties"].items():
if field == injected_state:
field_type = Annotated[dict, InjectedState]
else:
field_type = _parse_field(props)
if "default" in props:
default = props["default"]
model_fields[field] = (field_type, default)
else:
model_fields[field] = (field_type, ...)
return model_fields

args = tool.inputSchema
# check for the `injected_state`` annotation on the MCP tool.
# The injected_state value is the name of the function parameter used
# as the injected state
injected_state = tool.annotations.model_extra.get("injected_state")
if injected_state:
# import langgraph InjectedState only if we need it
from langgraph.prebuilt import InjectedState
model_fields = _parse_model_fields(args, injected_state)

# recreate a dynamic model based on the parsed JSON schema,
# which will be properly parsed with annotations for the InjectedState
args_schema = create_model(tool.name, **model_fields)
meta = tool.meta if hasattr(tool, "meta") else None
base = tool.annotations.model_dump() if tool.annotations is not None else {}
meta = {"_meta": meta} if meta is not None else {}
Expand All @@ -154,7 +213,7 @@ async def call_tool(
return StructuredTool(
name=tool.name,
description=tool.description or "",
args_schema=tool.inputSchema,
args_schema=args_schema,
coroutine=call_tool,
response_format="content_and_artifact",
metadata=metadata,
Expand Down