diff --git a/libs/prebuilt/langgraph/prebuilt/chat_agent_executor.py b/libs/prebuilt/langgraph/prebuilt/chat_agent_executor.py index 42352fefbc..89d8408144 100644 --- a/libs/prebuilt/langgraph/prebuilt/chat_agent_executor.py +++ b/libs/prebuilt/langgraph/prebuilt/chat_agent_executor.py @@ -557,7 +557,7 @@ def check_weather(location: str) -> str: if not is_dynamic_model: if isinstance(model, str): try: - from langchain.chat_models import ( # type: ignore[import-not-found] + from langchain.chat_models import ( init_chat_model, ) except ImportError: @@ -566,7 +566,7 @@ def check_weather(location: str) -> str: "use ':' string syntax for `model` parameter." ) - model = cast(BaseChatModel, init_chat_model(model)) + model = init_chat_model(model) if ( _should_bind_tools(model, tool_classes, num_builtin=len(llm_builtin_tools)) # type: ignore[arg-type] diff --git a/libs/prebuilt/langgraph/prebuilt/tool_node.py b/libs/prebuilt/langgraph/prebuilt/tool_node.py index 425ee38b0b..f598194476 100644 --- a/libs/prebuilt/langgraph/prebuilt/tool_node.py +++ b/libs/prebuilt/langgraph/prebuilt/tool_node.py @@ -1748,6 +1748,12 @@ def _is_injection( origin_ = get_origin(type_arg) if origin_ is Union or origin_ is Annotated: return any(_is_injection(ta, injection_type) for ta in get_args(type_arg)) + + if origin_ is not None and ( + origin_ is injection_type + or (isinstance(origin_, type) and issubclass(origin_, injection_type)) + ): + return True return False diff --git a/libs/prebuilt/tests/test_tool_node.py b/libs/prebuilt/tests/test_tool_node.py index 00845d3982..d1d83382a1 100644 --- a/libs/prebuilt/tests/test_tool_node.py +++ b/libs/prebuilt/tests/test_tool_node.py @@ -1860,3 +1860,45 @@ async def comprehensive_async_tool( "foo_from_runtime=foo_value, " "tool_call_id=test_call_789" ) + + +async def test_tool_node_tool_runtime_generic() -> None: + """Test that ToolRuntime with generic type arguments is correctly injected.""" + + @dataclasses.dataclass + class MyContext: + some_info: str + + @dec_tool + def get_info(rt: ToolRuntime[MyContext]): + """This tool returns info from context.""" + return rt.context.some_info + + # Create a mock runtime with context + mock_runtime = _create_mock_runtime() + mock_runtime.context = MyContext(some_info="test_info") + + config = {"configurable": {"__pregel_runtime": mock_runtime}} + + result = await ToolNode([get_info]).ainvoke( + { + "messages": [ + AIMessage( + "call tool", + tool_calls=[ + { + "name": "get_info", + "args": {}, + "id": "call_1", + } + ], + ) + ] + }, + config=config, + ) + + tool_message = result["messages"][-1] + assert tool_message.type == "tool" + assert tool_message.content == "test_info" + assert tool_message.tool_call_id == "call_1"