Skip to content

Commit 17f0716

Browse files
authored
fix(langchain_v1): remove non llm controllable params from tool message on invocation failure (#33625)
The LLM shouldn't see parameters it cannot control in the ToolMessage error it receives when it invokes a tool with incorrect args. This change fixes the behavior within langchain to address the immediate issue. We may also want to change the behavior in langchain_core to prevent validation of injected arguments, but that would be done in a separate change.
1 parent 5acd34a commit 17f0716

File tree

3 files changed

+951
-23
lines changed

3 files changed

+951
-23
lines changed

libs/langchain_v1/langchain/tools/tool_node.py

Lines changed: 103 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def my_tool(x: int) -> str:
8989
from collections.abc import Sequence
9090

9191
from langgraph.runtime import Runtime
92+
from pydantic_core import ErrorDetails
9293

9394
# right now we use a dict as the default, can change this to AgentState, but depends
9495
# on if this lives in LangChain or LangGraph... ideally would have some typed
@@ -303,21 +304,40 @@ class ToolInvocationError(ToolException):
303304
"""
304305

305306
def __init__(
    self,
    tool_name: str,
    source: ValidationError,
    tool_kwargs: dict[str, Any],
    filtered_errors: list[ErrorDetails] | None = None,
) -> None:
    """Build the error and its LLM-facing message.

    Args:
        tool_name: Name of the tool whose invocation failed.
        source: The validation exception raised by the invocation.
        tool_kwargs: Keyword arguments the tool was called with; they are
            interpolated into the message.
        filtered_errors: Optional pre-filtered validation error details. When
            given (even if empty), the message is built from these entries as
            plain ``loc: msg`` lines instead of from ``str(source)``.
    """
    if filtered_errors is None:
        # No filtered view supplied: fall back to the exception's own text.
        rendered_error = str(source)
    else:
        # One plain "dotted.location: message" line per entry, without the
        # URLs and extra decoration of the exception's default rendering.
        rendered_error = "\n".join(
            "{}: {}".format(
                ".".join(map(str, detail.get("loc", ()))),
                detail.get("msg", "Unknown error"),
            )
            for detail in filtered_errors
        )

    self.tool_name = tool_name
    self.tool_kwargs = tool_kwargs
    self.source = source
    self.filtered_errors = filtered_errors
    self.message = TOOL_INVOCATION_ERROR_TEMPLATE.format(
        tool_name=tool_name, tool_kwargs=tool_kwargs, error=rendered_error
    )
    super().__init__(self.message)
322342

323343

@@ -442,6 +462,59 @@ def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception],
442462
return (Exception,)
443463

444464

465+
def _filter_validation_errors(
    validation_error: ValidationError,
    tool_to_state_args: dict[str, str | None],
    tool_to_store_arg: str | None,
    tool_to_runtime_arg: str | None,
) -> list[ErrorDetails]:
    """Filter validation errors down to those the LLM can act on.

    Errors whose location starts at a system-injected argument (state, store,
    runtime) are dropped, since the LLM has no control over those values.
    Every other error is kept, including model-level errors that carry an
    empty location: they are not attributable to an injected argument, and
    dropping them could leave the LLM with an empty error message.

    For each kept error, a dict-valued ``input`` is replaced by a copy with the
    injected argument keys removed, so injected values do not travel with the
    error details. The error dicts returned by ``validation_error.errors()``
    are never mutated.

    Args:
        validation_error: The Pydantic ValidationError raised during tool invocation.
        tool_to_state_args: Mapping whose keys are the tool's state-injected
            argument names.
        tool_to_store_arg: Name of the store argument, if any.
        tool_to_runtime_arg: Name of the runtime argument, if any.

    Returns:
        Shallow copies of the error details that do not originate from an
        injected argument, with injected keys stripped from dict inputs.
    """
    injected_args = set(tool_to_state_args)
    if tool_to_store_arg:
        injected_args.add(tool_to_store_arg)
    if tool_to_runtime_arg:
        injected_args.add(tool_to_runtime_arg)

    filtered_errors: list[ErrorDetails] = []
    for error in validation_error.errors():
        # error["loc"] is a tuple like ('field_name',) or ('field_name', 'nested');
        # only its first element identifies the top-level argument.
        loc = error["loc"]
        if loc and loc[0] in injected_args:
            continue

        # Shallow copy so the caller's error dict is left untouched.
        error_copy: dict[str, Any] = {**error}

        # The input of e.g. a "missing" error is the whole argument dict, which
        # at this point also holds the injected values; strip them out.
        raw_input = error_copy.get("input")
        if isinstance(raw_input, dict):
            error_copy["input"] = {
                key: value for key, value in raw_input.items() if key not in injected_args
            }

        # A plain dict with the same keys is structurally an ErrorDetails TypedDict.
        filtered_errors.append(error_copy)  # type: ignore[arg-type]

    return filtered_errors
516+
517+
445518
class _ToolNode(RunnableCallable):
446519
"""A node for executing tools in LangGraph workflows.
447520
@@ -623,17 +696,10 @@ def _func(
623696
)
624697
tool_runtimes.append(tool_runtime)
625698

626-
# Inject tool arguments (including runtime)
627-
628-
injected_tool_calls = []
699+
# Pass original tool calls without injection
629700
input_types = [input_type] * len(tool_calls)
630-
for call, tool_runtime in zip(tool_calls, tool_runtimes, strict=False):
631-
injected_call = self._inject_tool_args(call, tool_runtime) # type: ignore[arg-type]
632-
injected_tool_calls.append(injected_call)
633701
with get_executor_for_config(config) as executor:
634-
outputs = list(
635-
executor.map(self._run_one, injected_tool_calls, input_types, tool_runtimes)
636-
)
702+
outputs = list(executor.map(self._run_one, tool_calls, input_types, tool_runtimes))
637703

638704
return self._combine_tool_outputs(outputs, input_type)
639705

@@ -660,12 +726,10 @@ async def _afunc(
660726
)
661727
tool_runtimes.append(tool_runtime)
662728

663-
injected_tool_calls = []
729+
# Pass original tool calls without injection
664730
coros = []
665731
for call, tool_runtime in zip(tool_calls, tool_runtimes, strict=False):
666-
injected_call = self._inject_tool_args(call, tool_runtime) # type: ignore[arg-type]
667-
injected_tool_calls.append(injected_call)
668-
coros.append(self._arun_one(injected_call, input_type, tool_runtime)) # type: ignore[arg-type]
732+
coros.append(self._arun_one(call, input_type, tool_runtime)) # type: ignore[arg-type]
669733
outputs = await asyncio.gather(*coros)
670734

671735
return self._combine_tool_outputs(outputs, input_type)
@@ -742,13 +806,23 @@ def _execute_tool_sync(
742806
msg = f"Tool {call['name']} is not registered with ToolNode"
743807
raise TypeError(msg)
744808

745-
call_args = {**call, "type": "tool_call"}
809+
# Inject state, store, and runtime right before invocation
810+
injected_call = self._inject_tool_args(call, request.runtime)
811+
call_args = {**injected_call, "type": "tool_call"}
746812

747813
try:
748814
try:
749815
response = tool.invoke(call_args, config)
750816
except ValidationError as exc:
751-
raise ToolInvocationError(call["name"], exc, call["args"]) from exc
817+
# Filter out errors for injected arguments
818+
filtered_errors = _filter_validation_errors(
819+
exc,
820+
self._tool_to_state_args.get(call["name"], {}),
821+
self._tool_to_store_arg.get(call["name"]),
822+
self._tool_to_runtime_arg.get(call["name"]),
823+
)
824+
# Use original call["args"] without injected values for error reporting
825+
raise ToolInvocationError(call["name"], exc, call["args"], filtered_errors) from exc
752826

753827
# GraphInterrupt is a special exception that will always be raised.
754828
# It can be triggered in the following scenarios,
@@ -887,13 +961,23 @@ async def _execute_tool_async(
887961
msg = f"Tool {call['name']} is not registered with ToolNode"
888962
raise TypeError(msg)
889963

890-
call_args = {**call, "type": "tool_call"}
964+
# Inject state, store, and runtime right before invocation
965+
injected_call = self._inject_tool_args(call, request.runtime)
966+
call_args = {**injected_call, "type": "tool_call"}
891967

892968
try:
893969
try:
894970
response = await tool.ainvoke(call_args, config)
895971
except ValidationError as exc:
896-
raise ToolInvocationError(call["name"], exc, call["args"]) from exc
972+
# Filter out errors for injected arguments
973+
filtered_errors = _filter_validation_errors(
974+
exc,
975+
self._tool_to_state_args.get(call["name"], {}),
976+
self._tool_to_store_arg.get(call["name"]),
977+
self._tool_to_runtime_arg.get(call["name"]),
978+
)
979+
# Use original call["args"] without injected values for error reporting
980+
raise ToolInvocationError(call["name"], exc, call["args"], filtered_errors) from exc
897981

898982
# GraphInterrupt is a special exception that will always be raised.
899983
# It can be triggered in the following scenarios,

libs/langchain_v1/tests/unit_tests/agents/test_tool_node.py

Lines changed: 170 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
TypeVar,
1111
)
1212
from unittest.mock import Mock
13+
from langchain.agents import create_agent
14+
from langchain.agents.middleware.types import AgentState
1315

1416
import pytest
1517
from langchain_core.messages import (
@@ -302,6 +304,172 @@ def test_tool_node_error_handling_default_exception() -> None:
302304
)
303305

304306

307+
@pytest.mark.skipif(
    sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14"
)
def test_tool_invocation_error_excludes_injected_state() -> None:
    """Test that tool invocation errors only include LLM-controllable arguments.

    When a tool has InjectedState parameters and the LLM makes an incorrect
    invocation (e.g., missing required arguments), the error message should
    contain the arguments the LLM supplied and the validation problem it can
    fix, and must not expose system-injected state.

    This test uses create_agent to ensure the behavior works in a full agent context.
    """

    # Custom state schema carrying data the LLM never controls
    class TestState(AgentState):
        secret_data: str

    @dec_tool
    def tool_with_injected_state(
        some_val: int,
        state: Annotated[TestState, InjectedState],
    ) -> str:
        """Tool that uses injected state."""
        return f"some_val: {some_val}"

    # First turn: an incorrect tool call (required 'some_val' is missing).
    # Second turn: no tool calls, which ends the agent loop.
    model = FakeToolCallingModel(
        tool_calls=[
            [
                {
                    "name": "tool_with_injected_state",
                    "args": {"wrong_arg": "value"},
                    "id": "call_1",
                }
            ],
            [],
        ]
    )

    agent = create_agent(
        model=model,
        tools=[tool_with_injected_state],
        state_schema=TestState,
    )

    # Invoke the agent with state data that will be injected into the tool
    result = agent.invoke(
        {
            "messages": [HumanMessage("Test message")],
            "secret_data": "sensitive_secret_123",
        }
    )

    # Exactly one tool message is expected, and it must be an error
    tool_messages = [m for m in result["messages"] if m.type == "tool"]
    assert len(tool_messages) == 1
    tool_message = tool_messages[0]
    assert tool_message.status == "error"

    # The error shows only the LLM-provided args (wrong_arg) ...
    assert "{'wrong_arg': 'value'}" in tool_message.content
    # ... still reports the problem the LLM can fix (the missing field), so a
    # filter that dropped every error would not pass this test ...
    assert "some_val" in tool_message.content
    # ... and never exposes the system-injected state.
    assert "secret_data" not in tool_message.content
    assert "sensitive_secret_123" not in tool_message.content
375+
376+
377+
@pytest.mark.skipif(
    sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14"
)
async def test_tool_invocation_error_excludes_injected_state_async() -> None:
    """Test that async tool invocation errors only include LLM-controllable arguments.

    Exercises the async agent path (``agent.ainvoke``) and checks that the
    resulting tool error mentions the LLM-controlled parameters while leaving
    out the system-injected state argument and its values.
    """

    # Custom state schema carrying data the LLM never controls
    class TestState(AgentState):
        internal_data: str

    @dec_tool
    async def async_tool_with_injected_state(
        query: str,
        max_results: int,
        state: Annotated[TestState, InjectedState],
    ) -> str:
        """Async tool that uses injected state."""
        return f"query: {query}, max_results: {max_results}"

    # First turn: an incorrect tool call
    # - query has the wrong type (int instead of str)
    # - max_results is missing
    # Second turn: no tool calls, which ends the agent loop.
    model = FakeToolCallingModel(
        tool_calls=[
            [
                {
                    "name": "async_tool_with_injected_state",
                    "args": {"query": 999},
                    "id": "call_async_1",
                }
            ],
            [],
        ]
    )

    agent = create_agent(
        model=model,
        tools=[async_tool_with_injected_state],
        state_schema=TestState,
    )

    # Invoke with state data that will be injected into the tool
    result = await agent.ainvoke(
        {
            "messages": [HumanMessage("Test async")],
            "internal_data": "secret_internal_value_xyz",
        }
    )

    # Exactly one tool message is expected, and it must be an error
    tool_messages = [m for m in result["messages"] if m.type == "tool"]
    assert len(tool_messages) == 1
    tool_message = tool_messages[0]
    assert tool_message.status == "error"

    # Verify error mentions LLM-controlled parameters only
    content = tool_message.content
    assert "query" in content.lower(), "Error should mention 'query' (LLM-controlled)"
    assert "max_results" in content.lower(), "Error should mention 'max_results' (LLM-controlled)"

    # Verify system-injected state does not appear in the validation errors
    # This keeps the error focused on what the LLM can actually fix
    assert "internal_data" not in content, (
        "Error should NOT mention 'internal_data' (system-injected field)"
    )
    assert "secret_internal_value" not in content, (
        "Error should NOT contain system-injected state values"
    )

    # Filtered errors are rendered one per line as "<dotted.loc>: <msg>", so an
    # error on the injected argument would start a line with "state:" (the
    # argument itself) or "state." (one of its nested fields).
    error_lines = [line.strip() for line in content.split("\n") if line.strip()]
    assert not any(line.lower().startswith(("state:", "state.")) for line in error_lines), (
        "The field 'state' (system-injected) should not appear in validation errors"
    )
471+
472+
305473
async def test_tool_node_error_handling() -> None:
306474
def handle_all(e: ValueError | ToolException | ToolInvocationError):
307475
return TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e))
@@ -355,10 +523,8 @@ def handle_all(e: ValueError | ToolException | ToolInvocationError):
355523
result_error["messages"][1].content
356524
== f"Error: {ToolException('Test error')!r}\n Please fix your mistakes."
357525
)
358-
assert (
359-
"ValidationError" in result_error["messages"][2].content
360-
or "validation error" in result_error["messages"][2].content
361-
)
526+
# Check that the validation error contains the field name
527+
assert "some_other_val" in result_error["messages"][2].content
362528

363529
assert result_error["messages"][0].tool_call_id == "some id"
364530
assert result_error["messages"][1].tool_call_id == "some other id"

0 commit comments

Comments
 (0)