Skip to content

Commit 32bbe99

Browse files
chore: Support tool runtime injection when custom args schema is prov… (#33999)
Support injection of injected args (like `InjectedToolCallId`, `ToolRuntime`) when an `args_schema` is specified that doesn't contain said args. This allows for pydantic validation of other args while retaining the ability to inject langchain specific arguments. fixes #33646 fixes #31688 Taking a deep dive here reminded me that we definitely need to revisit our internal tooling logic, but I don't think we should do that in this PR. --------- Co-authored-by: Sydney Runkle <[email protected]> Co-authored-by: Sydney Runkle <[email protected]>
1 parent 990e346 commit 32bbe99

File tree

3 files changed

+138
-2
lines changed

3 files changed

+138
-2
lines changed

libs/core/langchain_core/tools/base.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,8 @@ class ToolException(Exception): # noqa: N818
386386

387387
ArgsSchema = TypeBaseModel | dict[str, Any]
388388

389+
_EMPTY_SET: frozenset[str] = frozenset()
390+
389391

390392
class BaseTool(RunnableSerializable[str | dict | ToolCall, Any]):
391393
"""Base class for all LangChain tools.
@@ -569,6 +571,11 @@ def tool_call_schema(self) -> ArgsSchema:
569571
self.name, full_schema, fields, fn_description=self.description
570572
)
571573

574+
@functools.cached_property
575+
def _injected_args_keys(self) -> frozenset[str]:
576+
# base implementation doesn't manage injected args
577+
return _EMPTY_SET
578+
572579
# --- Runnable ---
573580

574581
@override
@@ -649,6 +656,7 @@ def _parse_input(
649656
if isinstance(input_args, dict):
650657
return tool_input
651658
if issubclass(input_args, BaseModel):
659+
# Check args_schema for InjectedToolCallId
652660
for k, v in get_all_basemodel_annotations(input_args).items():
653661
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
654662
if tool_call_id is None:
@@ -664,6 +672,7 @@ def _parse_input(
664672
result = input_args.model_validate(tool_input)
665673
result_dict = result.model_dump()
666674
elif issubclass(input_args, BaseModelV1):
675+
# Check args_schema for InjectedToolCallId
667676
for k, v in get_all_basemodel_annotations(input_args).items():
668677
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
669678
if tool_call_id is None:
@@ -683,9 +692,25 @@ def _parse_input(
683692
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
684693
)
685694
raise NotImplementedError(msg)
686-
return {
687-
k: getattr(result, k) for k, v in result_dict.items() if k in tool_input
695+
validated_input = {
696+
k: getattr(result, k) for k in result_dict if k in tool_input
688697
}
698+
for k in self._injected_args_keys:
699+
if k == "tool_call_id":
700+
if tool_call_id is None:
701+
msg = (
702+
"When tool includes an InjectedToolCallId "
703+
"argument, tool must always be invoked with a full "
704+
"model ToolCall of the form: {'args': {...}, "
705+
"'name': '...', 'type': 'tool_call', "
706+
"'tool_call_id': '...'}"
707+
)
708+
raise ValueError(msg)
709+
validated_input[k] = tool_call_id
710+
if k in tool_input:
711+
injected_val = tool_input[k]
712+
validated_input[k] = injected_val
713+
return validated_input
689714
return tool_input
690715

691716
@abstractmethod

libs/core/langchain_core/tools/structured.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import functools
56
import textwrap
67
from collections.abc import Awaitable, Callable
78
from inspect import signature
@@ -21,10 +22,12 @@
2122
)
2223
from langchain_core.runnables import RunnableConfig, run_in_executor
2324
from langchain_core.tools.base import (
25+
_EMPTY_SET,
2426
FILTERED_ARGS,
2527
ArgsSchema,
2628
BaseTool,
2729
_get_runnable_config_param,
30+
_is_injected_arg_type,
2831
create_schema_from_function,
2932
)
3033
from langchain_core.utils.pydantic import is_basemodel_subclass
@@ -241,6 +244,17 @@ def add(a: int, b: int) -> int:
241244
**kwargs,
242245
)
243246

247+
@functools.cached_property
248+
def _injected_args_keys(self) -> frozenset[str]:
249+
fn = self.func or self.coroutine
250+
if fn is None:
251+
return _EMPTY_SET
252+
return frozenset(
253+
k
254+
for k, v in signature(fn).parameters.items()
255+
if _is_injected_arg_type(v.annotation)
256+
)
257+
244258

245259
def _filter_schema_args(func: Callable) -> list[str]:
246260
filter_args = list(FILTERED_ARGS)

libs/core/tests/unit_tests/test_tools.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import textwrap
77
import threading
88
from collections.abc import Callable
9+
from dataclasses import dataclass
910
from datetime import datetime
1011
from enum import Enum
1112
from functools import partial
@@ -55,6 +56,7 @@
5556
InjectedToolArg,
5657
InjectedToolCallId,
5758
SchemaAnnotationError,
59+
_DirectlyInjectedToolArg,
5860
_is_message_content_block,
5961
_is_message_content_type,
6062
get_all_basemodel_annotations,
@@ -2331,6 +2333,101 @@ def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str:
23312333
assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar"
23322334

23332335

2336+
@pytest.mark.parametrize("schema_format", ["model", "json_schema"])
2337+
def test_tool_allows_extra_runtime_args_with_custom_schema(
2338+
schema_format: Literal["model", "json_schema"],
2339+
) -> None:
2340+
"""Ensure runtime args are preserved even if not in the args schema."""
2341+
2342+
class InputSchema(BaseModel):
2343+
query: str
2344+
2345+
captured: dict[str, Any] = {}
2346+
2347+
@dataclass
2348+
class MyRuntime(_DirectlyInjectedToolArg):
2349+
some_obj: object
2350+
2351+
args_schema = (
2352+
InputSchema if schema_format == "model" else InputSchema.model_json_schema()
2353+
)
2354+
2355+
@tool(args_schema=args_schema)
2356+
def runtime_tool(query: str, runtime: MyRuntime) -> str:
2357+
"""Echo the query and capture runtime value."""
2358+
captured["runtime"] = runtime
2359+
return query
2360+
2361+
runtime_obj = object()
2362+
runtime = MyRuntime(some_obj=runtime_obj)
2363+
assert runtime_tool.invoke({"query": "hello", "runtime": runtime}) == "hello"
2364+
assert captured["runtime"] is runtime
2365+
2366+
2367+
def test_tool_injected_tool_call_id_with_custom_schema() -> None:
2368+
"""Ensure InjectedToolCallId works with custom args schema."""
2369+
2370+
class InputSchema(BaseModel):
2371+
x: int
2372+
2373+
@tool(args_schema=InputSchema)
2374+
def injected_tool(
2375+
x: int, tool_call_id: Annotated[str, InjectedToolCallId]
2376+
) -> ToolMessage:
2377+
"""Tool with injected tool_call_id and custom schema."""
2378+
return ToolMessage(str(x), tool_call_id=tool_call_id)
2379+
2380+
# Test that tool_call_id is properly injected even though not in custom schema
2381+
result = injected_tool.invoke(
2382+
{
2383+
"type": "tool_call",
2384+
"args": {"x": 42},
2385+
"name": "injected_tool",
2386+
"id": "test_call_id",
2387+
}
2388+
)
2389+
assert result == ToolMessage("42", tool_call_id="test_call_id")
2390+
2391+
# Test that it still raises error when invoked without a ToolCall
2392+
with pytest.raises(
2393+
ValueError,
2394+
match="When tool includes an InjectedToolCallId argument, "
2395+
"tool must always be invoked with a full model ToolCall",
2396+
):
2397+
injected_tool.invoke({"x": 42})
2398+
2399+
2400+
def test_tool_injected_arg_with_custom_schema() -> None:
2401+
"""Ensure InjectedToolArg works with custom args schema."""
2402+
2403+
class InputSchema(BaseModel):
2404+
query: str
2405+
2406+
class CustomContext:
2407+
"""Custom context object to be injected."""
2408+
2409+
def __init__(self, value: str) -> None:
2410+
self.value = value
2411+
2412+
captured: dict[str, Any] = {}
2413+
2414+
@tool(args_schema=InputSchema)
2415+
def search_tool(
2416+
query: str, context: Annotated[CustomContext, InjectedToolArg]
2417+
) -> str:
2418+
"""Search with custom context."""
2419+
captured["context"] = context
2420+
return f"Results for {query} with context {context.value}"
2421+
2422+
# Test that context is properly injected even though not in custom schema
2423+
ctx = CustomContext("test_context")
2424+
result = search_tool.invoke({"query": "hello", "context": ctx})
2425+
2426+
assert result == "Results for hello with context test_context"
2427+
assert captured["context"] is ctx
2428+
assert captured["context"].value == "test_context"
2429+
2430+
23342431
def test_tool_injected_tool_call_id() -> None:
23352432
@tool
23362433
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage:

0 commit comments

Comments
 (0)