Skip to content

InjectedState in MCP tools #287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 62 additions & 2 deletions langchain_mcp_adapters/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
tools, handle tool execution, and manage tool conversion between the two formats.
"""

from typing import Any, cast, get_args
from typing import Annotated, Any, Literal, Union, cast, get_args

from langchain_core.tools import (
BaseTool,
Expand Down Expand Up @@ -136,10 +136,70 @@ async def call_tool(
call_tool_result = await session.call_tool(tool.name, arguments)
return _convert_call_tool_result(call_tool_result)

# base types being mapped from JSON
type_map = {
"null": None,
"integer": int,
"float": float,
"string": str,
"bool": bool,
"object": dict,
"bytes": bytes,
}

def _parse_model_fields(args: dict, injected_state: str | None = None) -> dict:
"""Parse a JSON field into a Pydantic Field, taking into account injected state.

:param args: the function parameter schema
:type args: dict
:param injected_state: the name of the key used for the InjectedState
:type injected_state: str
:return: returns a dict of fields with their pydantic type
and default value if any
:rtype: dict
"""
model_fields = {}

def _parse_field(props: dict) -> tuple[type, Any]:
if "anyOf" in props:
return Union[tuple(_parse_field(p) for p in props["anyOf"])] # noqa: UP007
if "enum" in props:
return Literal[tuple(props["enum"])]
if props["type"] == "array":
items = props["items"]
return list[_parse_field(items)]
return type_map.get(props["type"], dict)

for field, props in args["properties"].items():
if field == injected_state:
field_type = Annotated[dict, InjectedState]
else:
field_type = _parse_field(props)
if "default" in props:
default = props["default"]
model_fields[field] = (field_type, default)
else:
model_fields[field] = (field_type, ...)
return model_fields

args = tool.inputSchema
# check for the `injected_state`` annotation on the MCP tool.
# The injected_state value is the name of the function parameter used
# as the injected state
injected_state = tool.annotations.model_extra.get("injected_state")
if injected_state:
# import langgraph InjectedState only if we need it
from langgraph.prebuilt import InjectedState
model_fields = _parse_model_fields(args, injected_state)

# recreate a dynamic model based on the parsed JSON schema,
# which will be properly parsed with annotations for the InjectedState
args_schema = create_model(tool.name, **model_fields)

return StructuredTool(
name=tool.name,
description=tool.description or "",
args_schema=tool.inputSchema,
args_schema=args_schema,
coroutine=call_tool,
response_format="content_and_artifact",
metadata=tool.annotations.model_dump() if tool.annotations else None,
Expand Down