File tree Expand file tree Collapse file tree 3 files changed +48
-2
lines changed
Expand file tree Collapse file tree 3 files changed +48
-2
lines changed Original file line number Diff line number Diff line change @@ -557,7 +557,7 @@ def check_weather(location: str) -> str:
557557 if not is_dynamic_model :
558558 if isinstance (model , str ):
559559 try :
560- from langchain .chat_models import ( # type: ignore[import-not-found]
560+ from langchain .chat_models import (
561561 init_chat_model ,
562562 )
563563 except ImportError :
@@ -566,7 +566,7 @@ def check_weather(location: str) -> str:
566566 "use '<provider>:<model>' string syntax for `model` parameter."
567567 )
568568
569- model = cast ( BaseChatModel , init_chat_model (model ) )
569+ model = init_chat_model (model )
570570
571571 if (
572572 _should_bind_tools (model , tool_classes , num_builtin = len (llm_builtin_tools )) # type: ignore[arg-type]
Original file line number Diff line number Diff line change @@ -1748,6 +1748,11 @@ def _is_injection(
17481748 origin_ = get_origin (type_arg )
17491749 if origin_ is Union or origin_ is Annotated :
17501750 return any (_is_injection (ta , injection_type ) for ta in get_args (type_arg ))
1751+
1752+ if origin_ is not None and (
1753+ origin_ is injection_type or (isinstance (origin_ , type ) and issubclass (origin_ , injection_type ))
1754+ ):
1755+ return True
17511756 return False
17521757
17531758
Original file line number Diff line number Diff line change @@ -1860,3 +1860,44 @@ async def comprehensive_async_tool(
18601860 "foo_from_runtime=foo_value, "
18611861 "tool_call_id=test_call_789"
18621862 )
1863+
1864+ async def test_tool_node_tool_runtime_generic () -> None :
1865+ """Test that ToolRuntime with generic type arguments is correctly injected."""
1866+
1867+ @dataclasses .dataclass
1868+ class MyContext :
1869+ some_info : str
1870+
1871+ @dec_tool
1872+ def get_info (rt : ToolRuntime [MyContext ]):
1873+ """This tool returns info from context."""
1874+ return rt .context .some_info
1875+
1876+ # Create a mock runtime with context
1877+ mock_runtime = _create_mock_runtime ()
1878+ mock_runtime .context = MyContext (some_info = "test_info" )
1879+
1880+ config = {"configurable" : {"__pregel_runtime" : mock_runtime }}
1881+
1882+ result = await ToolNode ([get_info ]).ainvoke (
1883+ {
1884+ "messages" : [
1885+ AIMessage (
1886+ "call tool" ,
1887+ tool_calls = [
1888+ {
1889+ "name" : "get_info" ,
1890+ "args" : {},
1891+ "id" : "call_1" ,
1892+ }
1893+ ],
1894+ )
1895+ ]
1896+ },
1897+ config = config ,
1898+ )
1899+
1900+ tool_message = result ["messages" ][- 1 ]
1901+ assert tool_message .type == "tool"
1902+ assert tool_message .content == "test_info"
1903+ assert tool_message .tool_call_id == "call_1"
You can’t perform that action at this time.
0 commit comments