Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions libs/prebuilt/langgraph/prebuilt/chat_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def check_weather(location: str) -> str:
if not is_dynamic_model:
if isinstance(model, str):
try:
from langchain.chat_models import ( # type: ignore[import-not-found]
from langchain.chat_models import (
init_chat_model,
)
except ImportError:
Expand All @@ -566,7 +566,7 @@ def check_weather(location: str) -> str:
"use '<provider>:<model>' string syntax for `model` parameter."
)

model = cast(BaseChatModel, init_chat_model(model))
model = init_chat_model(model)

if (
_should_bind_tools(model, tool_classes, num_builtin=len(llm_builtin_tools)) # type: ignore[arg-type]
Expand Down
6 changes: 6 additions & 0 deletions libs/prebuilt/langgraph/prebuilt/tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1748,6 +1748,12 @@ def _is_injection(
origin_ = get_origin(type_arg)
if origin_ is Union or origin_ is Annotated:
return any(_is_injection(ta, injection_type) for ta in get_args(type_arg))

if origin_ is not None and (
origin_ is injection_type
or (isinstance(origin_, type) and issubclass(origin_, injection_type))
):
return True
return False


Expand Down
42 changes: 42 additions & 0 deletions libs/prebuilt/tests/test_tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1860,3 +1860,45 @@ async def comprehensive_async_tool(
"foo_from_runtime=foo_value, "
"tool_call_id=test_call_789"
)


async def test_tool_node_tool_runtime_generic() -> None:
"""Test that ToolRuntime with generic type arguments is correctly injected."""

@dataclasses.dataclass
class MyContext:
some_info: str

@dec_tool
def get_info(rt: ToolRuntime[MyContext]):
"""This tool returns info from context."""
return rt.context.some_info

# Create a mock runtime with context
mock_runtime = _create_mock_runtime()
mock_runtime.context = MyContext(some_info="test_info")

config = {"configurable": {"__pregel_runtime": mock_runtime}}

result = await ToolNode([get_info]).ainvoke(
{
"messages": [
AIMessage(
"call tool",
tool_calls=[
{
"name": "get_info",
"args": {},
"id": "call_1",
}
],
)
]
},
config=config,
)

tool_message = result["messages"][-1]
assert tool_message.type == "tool"
assert tool_message.content == "test_info"
assert tool_message.tool_call_id == "call_1"