Skip to content

Commit 29f4d21

Browse files
committed
fix: Support ToolRuntime with generic type arguments
1 parent 630bd9a commit 29f4d21

File tree

3 files changed

+48
-2
lines changed

3 files changed

+48
-2
lines changed

libs/prebuilt/langgraph/prebuilt/chat_agent_executor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ def check_weather(location: str) -> str:
557557
if not is_dynamic_model:
558558
if isinstance(model, str):
559559
try:
560-
from langchain.chat_models import ( # type: ignore[import-not-found]
560+
from langchain.chat_models import (
561561
init_chat_model,
562562
)
563563
except ImportError:
@@ -566,7 +566,7 @@ def check_weather(location: str) -> str:
566566
"use '<provider>:<model>' string syntax for `model` parameter."
567567
)
568568

569-
model = cast(BaseChatModel, init_chat_model(model))
569+
model = init_chat_model(model)
570570

571571
if (
572572
_should_bind_tools(model, tool_classes, num_builtin=len(llm_builtin_tools)) # type: ignore[arg-type]

libs/prebuilt/langgraph/prebuilt/tool_node.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,6 +1748,11 @@ def _is_injection(
17481748
origin_ = get_origin(type_arg)
17491749
if origin_ is Union or origin_ is Annotated:
17501750
return any(_is_injection(ta, injection_type) for ta in get_args(type_arg))
1751+
1752+
if origin_ is not None and (
1753+
origin_ is injection_type or (isinstance(origin_, type) and issubclass(origin_, injection_type))
1754+
):
1755+
return True
17511756
return False
17521757

17531758

libs/prebuilt/tests/test_tool_node.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1860,3 +1860,44 @@ async def comprehensive_async_tool(
18601860
"foo_from_runtime=foo_value, "
18611861
"tool_call_id=test_call_789"
18621862
)
1863+
1864+
async def test_tool_node_tool_runtime_generic() -> None:
1865+
"""Test that ToolRuntime with generic type arguments is correctly injected."""
1866+
1867+
@dataclasses.dataclass
1868+
class MyContext:
1869+
some_info: str
1870+
1871+
@dec_tool
1872+
def get_info(rt: ToolRuntime[MyContext]):
1873+
"""This tool returns info from context."""
1874+
return rt.context.some_info
1875+
1876+
# Create a mock runtime with context
1877+
mock_runtime = _create_mock_runtime()
1878+
mock_runtime.context = MyContext(some_info="test_info")
1879+
1880+
config = {"configurable": {"__pregel_runtime": mock_runtime}}
1881+
1882+
result = await ToolNode([get_info]).ainvoke(
1883+
{
1884+
"messages": [
1885+
AIMessage(
1886+
"call tool",
1887+
tool_calls=[
1888+
{
1889+
"name": "get_info",
1890+
"args": {},
1891+
"id": "call_1",
1892+
}
1893+
],
1894+
)
1895+
]
1896+
},
1897+
config=config,
1898+
)
1899+
1900+
tool_message = result["messages"][-1]
1901+
assert tool_message.type == "tool"
1902+
assert tool_message.content == "test_info"
1903+
assert tool_message.tool_call_id == "call_1"

0 commit comments

Comments
 (0)