From 06ce94ca06151e87e355208459cd70fb3c77131a Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Sun, 5 Oct 2025 22:34:18 -0400 Subject: [PATCH 01/12] x --- .../langchain/agents/middleware/__init__.py | 6 + .../agents/middleware/tool_error_handling.py | 409 ++++++++++++++++++ .../langchain/agents/middleware/types.py | 41 +- .../langchain/agents/middleware_agent.py | 138 +++++- .../langchain_v1/langchain/tools/tool_node.py | 290 +++++++++++-- .../agents/middleware/test_on_tool_call.py | 402 +++++++++++++++++ .../unit_tests/tools/test_on_tool_call.py | 389 +++++++++++++++++ 7 files changed, 1631 insertions(+), 44 deletions(-) create mode 100644 libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py create mode 100644 libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call.py create mode 100644 libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index 5915b75164a74..21b0debba87a6 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -7,6 +7,10 @@ from .prompt_caching import AnthropicPromptCachingMiddleware from .summarization import SummarizationMiddleware from .tool_call_limit import ToolCallLimitMiddleware +from .tool_error_handling import ( + ErrorToMessageMiddleware, + RetryMiddleware, +) from .tool_selection import LLMToolSelectorMiddleware from .types import ( AgentMiddleware, @@ -24,6 +28,7 @@ "AgentState", # should move to langchain-anthropic if we decide to keep it "AnthropicPromptCachingMiddleware", + "ErrorToMessageMiddleware", "HumanInTheLoopMiddleware", "LLMToolSelectorMiddleware", "ModelFallbackMiddleware", @@ -31,6 +36,7 @@ "PIIDetectionError", "PIIMiddleware", "PlanningMiddleware", + "RetryMiddleware", "SummarizationMiddleware", "ToolCallLimitMiddleware", "after_model", diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py b/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py new file mode 100644 index 0000000000000..bc55e9194de15 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py @@ -0,0 +1,409 @@ +"""Middleware for handling tool execution errors in agents. + +This module provides composable middleware for error handling, retries, +and error-to-message conversion in tool execution workflows. +""" + +from __future__ import annotations + +import inspect +import logging +import time +from collections.abc import Callable +from typing import TYPE_CHECKING, Union, get_args, get_origin, get_type_hints + +from langchain_core.messages import ToolMessage + +from langchain.agents.middleware.types import AgentMiddleware + +# Import ToolResponse locally to avoid circular import +from langchain.tools.tool_node import ToolResponse + +if TYPE_CHECKING: + from collections.abc import Generator + from types import UnionType + + from langchain.tools.tool_node import ToolRequest, ToolResponse + +logger = logging.getLogger(__name__) + + +# Default retriable exception types - transient errors that may succeed on retry +DEFAULT_RETRIABLE_EXCEPTIONS = ( + # Network and connection errors + ConnectionError, + TimeoutError, + # HTTP client errors are typically not retriable, but these are exceptions: + # - 429: Rate limit (temporary) + # - 503: Service unavailable (temporary) + # Note: Specific HTTP libraries may define their own exception types +) + + +def _infer_retriable_types( + predicate: Callable[[Exception], bool], +) -> tuple[type[Exception], ...]: + """Infer exception types from a retry predicate function's type annotations. + + Analyzes the type annotations of a predicate function to determine which + exception types it's designed to handle for retry decisions. + + Args: + predicate: A callable that takes an exception and returns whether to retry. + The first parameter should be type-annotated with exception type(s). + + Returns: + Tuple of exception types that the predicate handles. Returns (Exception,) + if no specific type information is available. + + Raises: + ValueError: If the predicate's annotation contains non-Exception types. + """ + sig = inspect.signature(predicate) + params = list(sig.parameters.values()) + if params: + # Skip self/cls if it's a method + if params[0].name in ["self", "cls"] and len(params) == 2: + first_param = params[1] + else: + first_param = params[0] + + type_hints = get_type_hints(predicate) + if first_param.name in type_hints: + origin = get_origin(first_param.annotation) + # Handle Union types + if origin in [Union, UnionType]: # type: ignore[has-type] + args = get_args(first_param.annotation) + if all(isinstance(arg, type) and issubclass(arg, Exception) for arg in args): + return tuple(args) + msg = ( + "All types in retry predicate annotation must be Exception types. " + "For example, `def should_retry(e: Union[TimeoutError, ConnectionError]) -> bool`. " + f"Got '{first_param.annotation}' instead." + ) + raise ValueError(msg) + + # Handle single exception type + exception_type = type_hints[first_param.name] + if isinstance(exception_type, type) and issubclass(exception_type, Exception): + return (exception_type,) + msg = ( + "Retry predicate must be annotated with Exception type(s). " + "For example, `def should_retry(e: TimeoutError) -> bool` or " + "`def should_retry(e: Union[TimeoutError, ConnectionError]) -> bool`. " + f"Got '{exception_type}' instead." + ) + raise ValueError(msg) + + # No type information - return Exception for backward compatibility + return (Exception,) + + +class RetryMiddleware(AgentMiddleware): + """Retry failed tool calls with constant delay. + + This middleware catches tool execution errors and retries them up to a maximum + number of attempts with a constant delay between retries. It operates at the + outermost layer of middleware composition to catch all errors. + + Examples: + Retry only network errors: + + ```python + from langchain.agents.middleware import RetryMiddleware + + middleware = RetryMiddleware( + max_retries=3, + delay=2.0, + retry_on=(TimeoutError, ConnectionError), + ) + + agent = create_agent( + model="openai:gpt-4o", + tools=[my_tool], + middleware=[middleware], + ) + ``` + + Use predicate function for custom retry logic: + + ```python + from langchain.tools.tool_node import ToolInvocationError + + + def should_retry(e: Exception) -> bool: + # Don't retry validation errors from LLM + if isinstance(e, ToolInvocationError): + return False + # Retry network errors + if isinstance(e, (TimeoutError, ConnectionError)): + return True + return False + + + middleware = RetryMiddleware( + max_retries=3, + retry_on=should_retry, + ) + ``` + + Compose with error conversion: + + ```python + from langchain.agents.middleware import ( + RetryMiddleware, + ErrorToMessageMiddleware, + ) + + agent = create_agent( + model="openai:gpt-4o", + tools=[my_tool], + middleware=[ + # Outer: retry network errors + RetryMiddleware( + max_retries=3, + delay=2.0, + retry_on=(TimeoutError, ConnectionError), + ), + # Inner: convert validation errors to messages + ErrorToMessageMiddleware( + exception_types=(ValidationError,), + ), + ], + ) + ``` + """ + + def __init__( + self, + *, + max_retries: int = 3, + delay: float = 1.0, + retry_on: type[Exception] + | tuple[type[Exception], ...] + | Callable[[Exception], bool] = DEFAULT_RETRIABLE_EXCEPTIONS, + ) -> None: + """Initialize retry middleware. + + Args: + max_retries: Maximum number of retry attempts. Total attempts will be + max_retries + 1 (initial attempt plus retries). + delay: Constant delay in seconds between retry attempts. + retry_on: Specifies which exceptions should be retried. Can be: + - **type[Exception]**: Retry only this exception type + - **tuple[type[Exception], ...]**: Retry these exception types + - **Callable[[Exception], bool]**: Predicate function that returns + True if the exception should be retried. Type annotations on the + callable are used to filter which exceptions are passed to it. + Defaults to ``DEFAULT_RETRIABLE_EXCEPTIONS`` (ConnectionError, TimeoutError). + """ + super().__init__() + if max_retries < 0: + msg = "max_retries must be non-negative" + raise ValueError(msg) + if delay < 0: + msg = "delay must be non-negative" + raise ValueError(msg) + + self.max_retries = max_retries + self.delay = delay + self._retry_on = retry_on + + # Determine which exception types to check + if isinstance(retry_on, type) and issubclass(retry_on, Exception): + self._retriable_types = (retry_on,) + self._retry_predicate = None + elif isinstance(retry_on, tuple): + if not retry_on: + msg = "retry_on tuple must not be empty" + raise ValueError(msg) + if not all(isinstance(t, type) and issubclass(t, Exception) for t in retry_on): + msg = "All elements in retry_on tuple must be Exception types" + raise ValueError(msg) + self._retriable_types = retry_on + self._retry_predicate = None + elif callable(retry_on): + self._retriable_types = _infer_retriable_types(retry_on) + self._retry_predicate = retry_on + else: + msg = ( + "retry_on must be an Exception type, tuple of Exception types, " + f"or callable. Got {type(retry_on)}" + ) + raise ValueError(msg) + + def on_tool_call( + self, request: ToolRequest + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + """Retry tool execution on failures.""" + for attempt in range(1, self.max_retries + 2): # +1 for initial, +1 for inclusive + response = yield request + + # Success - return immediately + if response.action == "return": + return response + + # Error - check if we should retry + if response.action == "raise": + exception = response.exception + if exception is None: + msg = "ToolResponse with action='raise' must have an exception" + raise ValueError(msg) + + # Check if this exception type is retriable + if not isinstance(exception, self._retriable_types): + logger.debug( + "Exception %s is not retriable for tool %s", + type(exception).__name__, + request.tool_call["name"], + ) + return response + + # If predicate is provided, check if we should retry + if self._retry_predicate is not None: + if not self._retry_predicate(exception): + logger.debug( + "Retry predicate returned False for %s in tool %s", + type(exception).__name__, + request.tool_call["name"], + ) + return response + + # Last attempt - return error + if attempt > self.max_retries: + logger.debug( + "Max retries (%d) reached for tool %s", + self.max_retries, + request.tool_call["name"], + ) + return response + + # Retry - log and delay + logger.debug( + "Retrying tool %s (attempt %d/%d) after error: %s", + request.tool_call["name"], + attempt, + self.max_retries + 1, + type(exception).__name__, + ) + time.sleep(self.delay) + continue + + # Should never reach here + msg = f"Unexpected control flow in RetryMiddleware for tool {request.tool_call['name']}" + raise RuntimeError(msg) + + +class ErrorToMessageMiddleware(AgentMiddleware): + """Convert specific exception types to ToolMessages. + + This middleware intercepts errors and converts them into ToolMessages that + can be sent back to the model as feedback. This is useful for errors caused + by invalid model inputs where the model needs feedback to correct its behavior. + + Examples: + Convert validation errors to messages: + + ```python + from langchain.agents.middleware import ErrorToMessageMiddleware + from langchain.tools.tool_node import ToolInvocationError + + middleware = ErrorToMessageMiddleware( + exception_types=(ToolInvocationError,), + message_template="Invalid arguments: {error}. Please fix and try again.", + ) + + agent = create_agent( + model="openai:gpt-4o", + tools=[my_tool], + middleware=[middleware], + ) + ``` + + Compose with retry for network errors: + + ```python + from langchain.agents.middleware import ( + RetryMiddleware, + ErrorToMessageMiddleware, + ) + + agent = create_agent( + model="openai:gpt-4o", + tools=[my_tool], + middleware=[ + # Outer: retry all errors + RetryMiddleware(max_retries=3), + # Inner: convert validation errors to messages + ErrorToMessageMiddleware( + exception_types=(ValidationError,), + ), + ], + ) + ``` + """ + + def __init__( + self, + *, + exception_types: tuple[type[Exception], ...], + message_template: str = "Error: {error}", + ) -> None: + """Initialize error conversion middleware. + + Args: + exception_types: Tuple of exception types to convert to messages. + message_template: Template string for error messages. Can use ``{error}`` + placeholder for the exception string representation. + """ + super().__init__() + if not exception_types: + msg = "exception_types must not be empty" + raise ValueError(msg) + + self.exception_types = exception_types + self.message_template = message_template + + def on_tool_call( + self, request: ToolRequest + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + """Convert matching errors to ToolMessages.""" + response = yield request + + # Success - pass through + if response.action == "return": + return response + + # Error - check if we should convert + if response.action == "raise": + exception = response.exception + if exception is None: + msg = "ToolResponse with action='raise' must have an exception" + raise ValueError(msg) + + # Check if exception type matches + if not isinstance(exception, self.exception_types): + return response + + # Convert to ToolMessage + logger.debug( + "Converting %s to ToolMessage for tool %s", + type(exception).__name__, + request.tool_call["name"], + ) + + error_message = self.message_template.format(error=str(exception)) + tool_message = ToolMessage( + content=error_message, + name=request.tool_call["name"], + tool_call_id=request.tool_call["id"], + status="error", + ) + + return ToolResponse( + action="return", + result=tool_message, + exception=exception, # Preserve for logging/debugging + ) + + return response diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index 6aea90ab9fc35..6457cc051b1ee 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -19,7 +19,7 @@ from langchain_core.runnables import run_in_executor if TYPE_CHECKING: - from collections.abc import Awaitable + from collections.abc import Awaitable, Generator # needed as top level import for pydantic schema generation on AgentState from langchain_core.messages import AnyMessage # noqa: TC002 @@ -36,6 +36,7 @@ from langgraph.types import Command from langchain.agents.structured_output import ResponseFormat + from langchain.tools.tool_node import ToolRequest, ToolResponse __all__ = [ "AgentMiddleware", @@ -215,6 +216,44 @@ async def aretry_model_request( None, self.retry_model_request, error, request, state, runtime, attempt ) + def on_tool_call( + self, + request: ToolRequest, + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + """Intercept tool execution to implement retry logic, monitoring, or request modification. + + Provides generator-based control over the complete tool execution lifecycle. + Multiple middleware can define this hook; they compose automatically with + outer middleware wrapping inner middleware (first defined = outermost layer). + + Generator Protocol: + 1. Yield a ToolRequest (potentially modified from the input) + 2. Receive a ToolResponse via .send() + 3. Optionally yield again to retry + 4. Return the final ToolResponse to propagate + + Args: + request: Tool invocation details including tool_call, tool instance, and config. + + Returns: + Generator for request/response interception. + + Example: + Retry on rate limit with exponential backoff: + + ```python + def on_tool_call(self, request): + for attempt in range(3): + response = yield request + if response.action == "return": + return response + if "rate limit" in str(response.exception): + time.sleep(2**attempt) + continue + return response + ``` + """ + class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]): """Callable with AgentState and Runtime as arguments.""" diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index b4f6370446b05..2cae5876a83fc 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -1,7 +1,7 @@ """Middleware agent implementation.""" import itertools -from collections.abc import Callable, Sequence +from collections.abc import Callable, Generator, Sequence from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints from langchain_core.language_models.chat_models import BaseChatModel @@ -35,12 +35,122 @@ ) from langchain.chat_models import init_chat_model from langchain.tools import ToolNode +from langchain.tools.tool_node import ToolCallHandler, ToolRequest, ToolResponse STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes." ResponseT = TypeVar("ResponseT") +def _chain_tool_call_handlers( + handlers: list[ToolCallHandler], +) -> ToolCallHandler | None: + """Compose multiple tool call handlers into a single middleware stack. + + Args: + handlers: Handlers in middleware order (first = outermost layer). + + Returns: + Single composed handler, or None if handlers is empty. + + Example: + ```python + # Auth middleware (outer) + rate limit (inner) + def auth(req): + resp = yield req + if "unauthorized" in str(resp.exception): + refresh_token() + resp = yield req # Retry + return resp + + + def rate_limit(req): + for attempt in range(3): + resp = yield req + if "rate limit" not in str(resp.exception): + return resp + time.sleep(2**attempt) + return resp + + + handler = _chain_tool_call_handlers([auth, rate_limit]) + # Request: auth -> rate_limit -> tool + # Response: tool -> rate_limit -> auth + ``` + """ + if not handlers: + return None + + if len(handlers) == 1: + return handlers[0] + + def _extract_return_value(stop_iteration: StopIteration) -> ToolResponse: + """Extract ToolResponse from StopIteration, validating protocol compliance.""" + if stop_iteration.value is None: + msg = "on_tool_call handler must explicitly return a ToolResponse" + raise ValueError(msg) + return stop_iteration.value + + def compose_two(outer: ToolCallHandler, inner: ToolCallHandler) -> ToolCallHandler: + """Compose two handlers where outer wraps inner.""" + + def composed( + request: ToolRequest, + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + outer_gen = outer(request) + + # Initialize outer generator + try: + outer_request = next(outer_gen) + except StopIteration as e: + return _extract_return_value(e) + + # Outer retry loop + while True: + inner_gen = inner(outer_request) + + # Initialize inner generator + try: + inner_request = next(inner_gen) + except StopIteration as e: + # Inner returned immediately - send to outer + inner_response = _extract_return_value(e) + try: + outer_request = outer_gen.send(inner_response) + continue # Outer retrying + except StopIteration as e: + return _extract_return_value(e) + + # Inner retry loop - yield to next layer (or tool) + while True: + tool_response = yield inner_request + + try: + inner_request = inner_gen.send(tool_response) + # Inner retrying - continue inner loop + except StopIteration as e: + # Inner done - send response to outer + inner_response = _extract_return_value(e) + break + + # Send inner's final response to outer + try: + outer_request = outer_gen.send(inner_response) + # Outer retrying - continue outer loop + except StopIteration as e: + # Outer done - return final response + return _extract_return_value(e) + + return composed + + # Compose right-to-left: handlers[0](handlers[1](...(handlers[-1](tool)))) + result = handlers[-1] + for handler in reversed(handlers[:-1]): + result = compose_two(handler, result) + + return result + + def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type: """Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations. @@ -226,6 +336,20 @@ def create_agent( # noqa: PLR0915 structured_output_tools[structured_tool_info.tool.name] = structured_tool_info middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])] + # Validate middleware and collect handlers + assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101 + "Please remove duplicate middleware instances." + ) + middleware_w_on_tool_call = [ + m for m in middleware if m.__class__.on_tool_call is not AgentMiddleware.on_tool_call + ] + + # Chain all on_tool_call handlers into a single composed handler + on_tool_call_handler = None + if middleware_w_on_tool_call: + handlers = [m.on_tool_call for m in middleware_w_on_tool_call] + on_tool_call_handler = _chain_tool_call_handlers(handlers) + # Setup tools tool_node: ToolNode | None = None if isinstance(tools, list): @@ -237,7 +361,11 @@ def create_agent( # noqa: PLR0915 available_tools = middleware_tools + regular_tools # Only create ToolNode if we have client-side tools - tool_node = ToolNode(tools=available_tools) if available_tools else None + tool_node = ( + ToolNode(tools=available_tools, on_tool_call=on_tool_call_handler) + if available_tools + else None + ) # Default tools for ModelRequest initialization # Include built-ins and regular tools (can be changed dynamically by middleware) @@ -248,7 +376,7 @@ def create_agent( # noqa: PLR0915 if tool_node: # Add middleware tools to existing ToolNode available_tools = list(tool_node.tools_by_name.values()) + middleware_tools - tool_node = ToolNode(available_tools) + tool_node = ToolNode(available_tools, on_tool_call=on_tool_call_handler) # default_tools includes all client-side tools (no built-ins or structured tools) default_tools = available_tools @@ -256,10 +384,6 @@ def create_agent( # noqa: PLR0915 # No tools provided, only middleware_tools available default_tools = middleware_tools - # validate middleware - assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101 - "Please remove duplicate middleware instances." - ) middleware_w_before = [ m for m in middleware diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index 2c1df81cf2962..1a392413b4095 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -38,15 +38,15 @@ def my_tool(x: int) -> str: import asyncio import inspect import json +from collections.abc import Callable, Generator, Sequence from copy import copy, deepcopy -from dataclasses import replace +from dataclasses import dataclass, replace from types import UnionType from typing import ( TYPE_CHECKING, Annotated, Any, Literal, - Optional, Union, cast, get_args, @@ -79,8 +79,6 @@ def my_tool(x: int) -> str: from pydantic import BaseModel, ValidationError if TYPE_CHECKING: - from collections.abc import Callable, Sequence - from langchain_core.runnables import RunnableConfig from langgraph.store.base import BaseStore @@ -100,6 +98,61 @@ def my_tool(x: int) -> str: ) +@dataclass() +class ToolRequest: + """Request passed to on_tool_call handler before tool execution. + + Attributes: + tool_call: The tool call dict containing name, args, and id. + tool: The BaseTool instance that will be invoked. + config: Runnable configuration for this execution. + + Note: + tool_call["args"] can be mutated directly to modify arguments. + """ + + tool_call: ToolCall + tool: BaseTool + config: RunnableConfig + + +@dataclass() +class ToolResponse: + """Response returned from on_tool_call handler after tool execution. + + The action field determines control flow: + - "return": Handler completed successfully, use result + - "raise": Handler wants to propagate the exception + + Attributes: + action: Control flow directive ("return" or "raise"). + result: ToolMessage or Command when action="return". + exception: The exception when action="raise", or for logging when + action="return" with an error ToolMessage. + """ + + action: Literal["return", "raise"] + result: ToolMessage | Command | None = None + exception: Exception | None = None + + def __post_init__(self) -> None: + """Validate that required fields are present based on action.""" + if self.action == "return" and self.result is None: + msg = "action='return' requires a result" + raise ValueError(msg) + if self.action == "raise" and self.exception is None: + msg = "action='raise' requires an exception" + raise ValueError(msg) + + +ToolCallHandler = Callable[[ToolRequest], Generator[ToolRequest, ToolResponse, ToolResponse]] +"""Generator-based handler that intercepts tool execution. + +Receives a ToolRequest, yields modified ToolRequests, receives ToolResponses, +and returns a final ToolResponse. Supports multiple yields for retry logic. +""" + + def msg_content_output(output: Any) -> str | list[dict]: """Convert tool output to valid message content format. @@ -300,8 +353,8 @@ class ToolNode(RunnableCallable): Output format depends on input type and tool behavior: **For Regular tools**: - - Dict input → ``{"messages": [ToolMessage(...)]}`` - - List input → ``[ToolMessage(...)]`` + - Dict input -> ``{"messages": [ToolMessage(...)]}`` + - List input -> ``[ToolMessage(...)]`` **For Command tools**: - Returns ``[Command(...)]`` or mixed list with regular tool outputs @@ -335,6 +388,12 @@ class ToolNode(RunnableCallable): - catches tool invocation errors (due to invalid arguments provided by the model) and returns a descriptive error message - ignores tool execution errors (they will be re-raised) + on_tool_call: Optional handler to intercept tool execution. Receives + ``ToolRequest``, yields potentially modified requests, receives + ``ToolResponse`` via ``.send()``, and returns final ``ToolResponse``. + Enables retries, argument modification, and custom error handling. + Defaults to ``None``. + messages_key: The key in the state dictionary that contains the message list. This same key will be used for the output ToolMessages. Defaults to "messages". @@ -378,6 +437,23 @@ def handle_errors(e: ValueError) -> str: tool_node = ToolNode([my_tool], handle_tool_errors=handle_errors) ``` + + Intercepting tool calls: + + ```python + from langchain.tools.tool_node import ToolRequest, ToolResponse + + def retry_handler(request): + \"\"\"Retry failed tool calls up to 3 times.\"\"\" + for attempt in range(3): + response = yield request + if response.action == "return": + return response + # Retry on error + return response # Final attempt + + tool_node = ToolNode([my_tool], on_tool_call=retry_handler) + ``` """ # noqa: E501 name: str = "tools" @@ -393,6 +469,7 @@ def __init__( | Callable[..., str] | type[Exception] | tuple[type[Exception], ...] = _default_handle_tool_errors, + on_tool_call: ToolCallHandler | None = None, messages_key: str = "messages", ) -> None: """Initialize the ToolNode with the provided tools and configuration. @@ -402,6 +479,7 @@ def __init__( name: Node name for graph identification. tags: Optional metadata tags. handle_tool_errors: Error handling configuration. + on_tool_call: Optional handler to intercept tool execution. messages_key: State key containing messages. """ super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False) @@ -409,6 +487,7 @@ def __init__( self._tool_to_state_args: dict[str, dict[str, str | None]] = {} self._tool_to_store_arg: dict[str, str | None] = {} self._handle_tool_errors = handle_tool_errors + self._on_tool_call = on_tool_call self._messages_key = messages_key for tool in tools: if not isinstance(tool, BaseTool): @@ -429,7 +508,7 @@ def _func( input: list[AnyMessage] | dict[str, Any] | BaseModel, config: RunnableConfig, *, - store: Optional[BaseStore], # noqa: UP045 + store: Optional[BaseStore], ) -> Any: tool_calls, input_type = self._parse_input(input, store) config_list = get_config_list(config, len(tool_calls)) @@ -444,7 +523,7 @@ async def _afunc( input: list[AnyMessage] | dict[str, Any] | BaseModel, config: RunnableConfig, *, - store: Optional[BaseStore], # noqa: UP045 + store: Optional[BaseStore], ) -> Any: tool_calls, input_type = self._parse_input(input, store) outputs = await asyncio.gather( @@ -495,20 +574,20 @@ def _combine_tool_outputs( combined_outputs.append(parent_command) return combined_outputs - def _run_one( - self, - call: ToolCall, - input_type: Literal["list", "dict", "tool_calls"], - config: RunnableConfig, - ) -> ToolMessage | Command: - """Run a single tool call synchronously.""" - if invalid_tool_message := self._validate_tool_call(call): - return invalid_tool_message + def _execute_tool_sync( + self, request: ToolRequest, input_type: Literal["list", "dict", "tool_calls"] + ) -> ToolResponse: + """Execute tool and return response. - try: - call_args = {**call, "type": "tool_call"} - tool = self.tools_by_name[call["name"]] + Applies handle_tool_errors configuration. When on_tool_call is configured, + unhandled errors return action="raise" instead of raising immediately. + """ + call = request.tool_call + tool = request.tool + config = request.config + call_args = {**call, "type": "tool_call"} + try: try: response = tool.invoke(call_args, config) except ValidationError as exc: @@ -541,40 +620,110 @@ def _run_one( # default behavior is catching all exceptions handled_types = (Exception,) - # Unhandled + # Check if error is handled if not self._handle_tool_errors or not isinstance(e, handled_types): + # Error is not handled + if self._on_tool_call is not None: + # If handler exists, return action="raise" so handler can decide + return ToolResponse(action="raise", exception=e) + # No handler - maintain backward compatibility by raising immediately raise - # Handled + + # Error is handled - create error ToolMessage content = _handle_tool_error(e, flag=self._handle_tool_errors) - return ToolMessage( + error_message = ToolMessage( content=content, name=call["name"], tool_call_id=call["id"], status="error", ) + return ToolResponse(action="return", result=error_message, exception=e) + # Process successful response if isinstance(response, Command): - return self._validate_tool_command(response, call, input_type) + # Validate Command before returning to handler + validated_command = self._validate_tool_command(response, request.tool_call, input_type) + return ToolResponse(action="return", result=validated_command) if isinstance(response, ToolMessage): response.content = cast("str | list", msg_content_output(response.content)) - return response + return ToolResponse(action="return", result=response) + msg = f"Tool {call['name']} returned unexpected type: {type(response)}" raise TypeError(msg) - async def _arun_one( + def _run_one( self, call: ToolCall, input_type: Literal["list", "dict", "tool_calls"], config: RunnableConfig, ) -> ToolMessage | Command: - """Run a single tool call asynchronously.""" + """Run a single tool call synchronously.""" if invalid_tool_message := self._validate_tool_call(call): return invalid_tool_message - try: - call_args = {**call, "type": "tool_call"} - tool = self.tools_by_name[call["name"]] + tool = self.tools_by_name[call["name"]] + + # Create the tool request + tool_request = ToolRequest( + tool_call=call, + tool=tool, + config=config, + ) + + if self._on_tool_call is None: + tool_response = self._execute_tool_sync(tool_request, input_type) + else: + # Generator protocol: start generator, send responses, receive requests + gen = self._on_tool_call(tool_request) + try: + request = next(gen) + except StopIteration: + msg = "on_tool_call handler must yield at least once before returning" + raise ValueError(msg) + + while True: + tool_response = self._execute_tool_sync(request, input_type) + try: + request = gen.send(tool_response) + except StopIteration as e: + if e.value is None: + msg = ( + "on_tool_call handler must explicitly return a ToolResponse. " + "Ensure your handler ends with 'return response'." + ) + raise ValueError(msg) + tool_response = e.value + break + + # Apply action directive + if tool_response.action == "raise": + if tool_response.exception is None: + msg = "ToolResponse with action='raise' must have an exception" + raise ValueError(msg) + raise tool_response.exception + + result = tool_response.result + if result is None: + msg = "ToolResponse with action='return' must have a result" + raise ValueError(msg) + + return result + + async def _execute_tool_async( + self, request: ToolRequest, input_type: Literal["list", "dict", "tool_calls"] + ) -> ToolResponse: + """Execute tool asynchronously and return response. + + Applies handle_tool_errors configuration. When on_tool_call is configured, + unhandled errors return action="raise" instead of raising immediately. + """ + call = request.tool_call + tool = request.tool + config = request.config + call_args = {**call, "type": "tool_call"} + + try: try: response = await tool.ainvoke(call_args, config) except ValidationError as exc: @@ -607,27 +756,96 @@ async def _arun_one( # default behavior is catching all exceptions handled_types = (Exception,) - # Unhandled + # Check if error is handled if not self._handle_tool_errors or not isinstance(e, handled_types): + # Error is not handled + if self._on_tool_call is not None: + # If handler exists, return action="raise" so handler can decide + return ToolResponse(action="raise", exception=e) + # No handler - maintain backward compatibility by raising immediately raise - # Handled - content = _handle_tool_error(e, flag=self._handle_tool_errors) - return ToolMessage( + # Error is handled - create error ToolMessage + content = _handle_tool_error(e, flag=self._handle_tool_errors) + error_message = ToolMessage( content=content, name=call["name"], tool_call_id=call["id"], status="error", ) + return ToolResponse(action="return", result=error_message, exception=e) + # Process successful response if isinstance(response, Command): - return self._validate_tool_command(response, call, input_type) + # Validate Command before returning to handler + validated_command = self._validate_tool_command(response, request.tool_call, input_type) + return ToolResponse(action="return", result=validated_command) if isinstance(response, ToolMessage): response.content = cast("str | list", msg_content_output(response.content)) - return response + return ToolResponse(action="return", result=response) + msg = f"Tool {call['name']} returned unexpected type: {type(response)}" raise TypeError(msg) + async def _arun_one( + self, + call: ToolCall, + input_type: Literal["list", "dict", "tool_calls"], + config: RunnableConfig, + ) -> ToolMessage | Command: + """Run a single tool call asynchronously.""" + if invalid_tool_message := self._validate_tool_call(call): + return invalid_tool_message + + tool = self.tools_by_name[call["name"]] + + # Create the tool request + tool_request = ToolRequest( + tool_call=call, + tool=tool, + config=config, + ) + + if self._on_tool_call is None: + tool_response = await self._execute_tool_async(tool_request, input_type) + else: + # Generator protocol: handler is sync generator, tool execution is async + gen = self._on_tool_call(tool_request) + + try: + request = next(gen) + except StopIteration: + msg = "on_tool_call handler must yield at least once before returning" + raise ValueError(msg) + + while True: + tool_response = await self._execute_tool_async(request, input_type) + try: + request = gen.send(tool_response) + except StopIteration as e: + if e.value is None: + msg = ( + "on_tool_call handler must explicitly return a ToolResponse. " + "Ensure your handler ends with 'return response'." + ) + raise ValueError(msg) + tool_response = e.value + break + + # Apply action directive + if tool_response.action == "raise": + if tool_response.exception is None: + msg = "ToolResponse with action='raise' must have an exception" + raise ValueError(msg) + raise tool_response.exception + + result = tool_response.result + if result is None: + msg = "ToolResponse with action='return' must have a result" + raise ValueError(msg) + + return result + def _parse_input( self, input: list[AnyMessage] | dict[str, Any] | BaseModel, diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call.py new file mode 100644 index 0000000000000..2f15a58cec66a --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call.py @@ -0,0 +1,402 @@ +"""Unit tests for on_tool_call middleware hook.""" + +from collections.abc import Generator +from typing import Any, Literal, Union +import typing + +from pydantic import BaseModel +import pytest +from langchain_core.language_models import LanguageModelInput +from langchain_core.language_models.fake_chat_models import GenericFakeChatModel +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool, tool + +from langchain.agents.middleware.types import AgentMiddleware +from langchain.agents.middleware_agent import create_agent +from langchain.tools.tool_node import ToolRequest, ToolResponse + + +class FakeModel(GenericFakeChatModel): + """Fake chat model for testing.""" + + tool_style: Literal["openai", "anthropic"] = "openai" + + def bind_tools( + self, + tools: typing.Sequence[Union[dict[str, Any], type[BaseModel], typing.Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + if len(tools) == 0: + msg = "Must provide at least one tool" + raise ValueError(msg) + + tool_dicts = [] + for tool in tools: + if isinstance(tool, dict): + tool_dicts.append(tool) + continue + if not isinstance(tool, BaseTool): + msg = "Only BaseTool and dict is supported by FakeModel.bind_tools" + raise TypeError(msg) + + # NOTE: this is a simplified tool spec for testing purposes only + if self.tool_style == "openai": + tool_dicts.append( + { + "type": "function", + "function": { + "name": tool.name, + }, + } + ) + elif self.tool_style == "anthropic": + tool_dicts.append( + { + "name": tool.name, + } + ) + + return self.bind(tools=tool_dicts) + + +@tool +def add_tool(x: int, y: int) -> int: + """Add two numbers.""" + return x + y + + +@tool +def failing_tool(x: int) -> int: + """Tool that raises an error.""" + msg = "Intentional failure" + raise ValueError(msg) + + +def test_single_middleware_on_tool_call(): + """Test that a single middleware can intercept tool calls.""" + call_log = [] + + class LoggingMiddleware(AgentMiddleware): + """Middleware that logs tool calls.""" + + def on_tool_call( + self, request: ToolRequest + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + call_log.append(f"before_{request.tool.name}") + response = yield request + call_log.append(f"after_{request.tool.name}") + return response + + model = FakeModel( + messages=iter( + [ + AIMessage( + content="", + tool_calls=[{"name": "add_tool", "args": {"x": 2, "y": 3}, "id": "1"}], + ), + AIMessage(content="Done"), + ] + ) + ) + + agent = create_agent( + model=model, + tools=[add_tool], + middleware=[LoggingMiddleware()], + ) + + result = agent.compile().invoke({"messages": [HumanMessage("Add 2 and 3")]}) + + assert "before_add_tool" in call_log + assert "after_add_tool" in call_log + assert call_log.index("before_add_tool") < call_log.index("after_add_tool") + + # Check that tool executed successfully + tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)] + assert len(tool_messages) == 1 + assert tool_messages[0].content == "5" + + +def test_multiple_middleware_chaining(): + """Test that multiple middleware chain correctly (outer wraps inner).""" + call_order = [] + + class OuterMiddleware(AgentMiddleware): + """Outer middleware in the chain.""" + + def on_tool_call( + self, request: ToolRequest + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + call_order.append("outer_start") + response = yield request + call_order.append("outer_end") + return response + + class InnerMiddleware(AgentMiddleware): + """Inner middleware in the chain.""" + + def on_tool_call( + self, request: ToolRequest + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + call_order.append("inner_start") + response = yield request + call_order.append("inner_end") + return response + + model = FakeModel( + messages=iter( + [ + AIMessage( + content="", + tool_calls=[{"name": "add_tool", "args": {"x": 1, "y": 1}, "id": "1"}], + ), + AIMessage(content="Done"), + ] + ) + ) + + agent = create_agent( + model=model, + tools=[add_tool], + middleware=[OuterMiddleware(), InnerMiddleware()], + ) + + agent.compile().invoke({"messages": [HumanMessage("Add 1 and 1")]}) + + # Verify order: outer_start -> inner_start -> tool -> inner_end -> outer_end + assert call_order == ["outer_start", "inner_start", "inner_end", "outer_end"] + + +def test_middleware_retry_logic(): + """Test that middleware can retry tool calls.""" + attempt_count = 0 + + class RetryMiddleware(AgentMiddleware): + """Middleware that retries on failure.""" + + def on_tool_call( + self, request: ToolRequest + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + nonlocal attempt_count + max_retries = 2 + + for attempt in range(max_retries): + attempt_count += 1 + response = yield request + + if response.action == "return": + return response + + if response.action == "raise" and attempt < max_retries - 1: + # Retry + continue + + # Convert error to success message + return ToolResponse( + action="return", + result=ToolMessage( + content=f"Failed after {max_retries} attempts", + name=request.tool_call["name"], + tool_call_id=request.tool_call["id"], + status="error", + ), + ) + + # Should never reach here + return response + + model = FakeModel( + messages=iter( + [ + AIMessage( + content="", + tool_calls=[{"name": "failing_tool", "args": {"x": 1}, "id": "1"}], + ), + AIMessage(content="Done"), + ] + ) + ) + + agent = create_agent( + model=model, + tools=[failing_tool], + middleware=[RetryMiddleware()], + ) + + result = agent.compile().invoke({"messages": [HumanMessage("Test retry")]}) + + # Should have attempted twice + assert attempt_count == 2 + + # Check that we got an error message + tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)] + assert len(tool_messages) == 1 + assert "Failed after 2 attempts" in tool_messages[0].content + + +def test_middleware_request_modification(): + """Test that middleware can modify tool requests.""" + + class RequestModifierMiddleware(AgentMiddleware): + """Middleware that doubles the input.""" + + def on_tool_call( + self, request: ToolRequest + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + # Modify the arguments + modified_tool_call = { + **request.tool_call, + "args": { + "x": request.tool_call["args"]["x"] * 2, + "y": request.tool_call["args"]["y"] * 2, + }, + } + modified_request = ToolRequest( + tool_call=modified_tool_call, + tool=request.tool, + config=request.config, + ) + response = yield modified_request + return response + + model = FakeModel( + messages=iter( + [ + AIMessage( + content="", + tool_calls=[{"name": "add_tool", "args": {"x": 1, "y": 2}, "id": "1"}], + ), + AIMessage(content="Done"), + ] + ) + ) + + agent = create_agent( + model=model, + tools=[add_tool], + middleware=[RequestModifierMiddleware()], + ) + + result = agent.compile().invoke({"messages": [HumanMessage("Add 1 and 2")]}) + + # Original: 1 + 2 = 3, Modified: 2 + 4 = 6 + tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)] + assert len(tool_messages) == 1 + assert tool_messages[0].content == "6" + + +def test_multiple_middleware_with_retry(): + """Test complex scenario with multiple middleware and retry logic.""" + call_log = [] + + class MonitoringMiddleware(AgentMiddleware): + """Outer middleware for monitoring.""" + + def on_tool_call( + self, request: ToolRequest + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + call_log.append("monitoring_start") + response = yield request + call_log.append("monitoring_end") + return response + + class RetryMiddleware(AgentMiddleware): + """Inner middleware for retries.""" + + def on_tool_call( + self, request: ToolRequest + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + call_log.append("retry_start") + for attempt in range(2): + call_log.append(f"retry_attempt_{attempt + 1}") + response = yield request + + if response.action == "return": + call_log.append("retry_success") + return response + + if attempt == 0: # Retry once + call_log.append("retry_retry") + continue + + call_log.append("retry_failed") + return response + + model = FakeModel( + messages=iter( + [ + AIMessage( + content="", + tool_calls=[{"name": "add_tool", "args": {"x": 5, "y": 7}, "id": "1"}], + ), + AIMessage(content="Done"), + ] + ) + ) + + agent = create_agent( + model=model, + tools=[add_tool], + middleware=[MonitoringMiddleware(), RetryMiddleware()], + ) + + agent.compile().invoke({"messages": [HumanMessage("Add 5 and 7")]}) + + # Verify the call sequence + assert call_log[0] == "monitoring_start" + assert call_log[1] == "retry_start" + assert "retry_attempt_1" in call_log + assert "retry_success" in call_log + assert call_log[-1] == "monitoring_end" + + +def test_mixed_middleware(): + """Test middleware with both before_model and on_tool_call hooks.""" + call_log = [] + + class MixedMiddleware(AgentMiddleware): + """Middleware with multiple hooks.""" + + def before_model(self, state, runtime): + call_log.append("before_model") + return None + + def on_tool_call( + self, request: ToolRequest + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + call_log.append("on_tool_call_start") + response = yield request + call_log.append("on_tool_call_end") + return response + + model = FakeModel( + messages=iter( + [ + AIMessage( + content="", + tool_calls=[{"name": "add_tool", "args": {"x": 10, "y": 20}, "id": "1"}], + ), + AIMessage(content="Done"), + ] + ) + ) + + agent = create_agent( + model=model, + tools=[add_tool], + middleware=[MixedMiddleware()], + ) + + agent.compile().invoke({"messages": [HumanMessage("Add 10 and 20")]}) + + # Both hooks should have been called + assert "before_model" in call_log + assert "on_tool_call_start" in call_log + assert "on_tool_call_end" in call_log + # before_model runs before on_tool_call + assert call_log.index("before_model") < call_log.index("on_tool_call_start") + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py new file mode 100644 index 0000000000000..d33d3c41490ce --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py @@ -0,0 +1,389 @@ +"""Tests for on_tool_call handler functionality.""" + +from collections.abc import Generator + +import pytest +from langchain_core.messages import AIMessage, ToolMessage +from langchain_core.tools import tool + +from langchain.tools import ToolNode +from langchain.tools.tool_node import ToolRequest, ToolResponse + + +# Test tools +@tool +def success_tool(x: int) -> int: + """A tool that always succeeds.""" + return x * 2 + + +@tool +def error_tool(x: int) -> int: + """A tool that always raises ValueError.""" + msg = f"Error with value: {x}" + raise ValueError(msg) + + +@tool +def rate_limit_tool(x: int) -> int: + """A tool that simulates rate limit errors.""" + if not hasattr(rate_limit_tool, "_call_count"): + rate_limit_tool._call_count = 0 + rate_limit_tool._call_count += 1 + + if rate_limit_tool._call_count < 3: # Fail first 2 times + msg = "Rate limit exceeded" + raise ValueError(msg) + return x * 2 + + +def test_on_tool_call_passthrough() -> None: + """Test that a simple passthrough handler works.""" + + def passthrough_handler( + request: ToolRequest, + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + """Simply pass through without modification.""" + response = yield request + return response + + tool_node = ToolNode([success_tool], on_tool_call=passthrough_handler) + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "success_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + + assert len(result["messages"]) == 1 + tool_message: ToolMessage = result["messages"][0] + assert tool_message.content == "10" + assert tool_message.status != "error" + + +def test_on_tool_call_retry_success(): + """Test that retry handler can recover from transient errors.""" + # Reset counter + if hasattr(rate_limit_tool, "_call_count"): + rate_limit_tool._call_count = 0 + + def retry_handler( + request: ToolRequest, + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + """Retry up to 3 times.""" + max_retries = 3 + + for attempt in range(max_retries): + response = yield request + + if response.action == "return": + return response + + # Retry on error + if attempt < max_retries - 1: + continue + + # Final attempt failed - convert to error message + return ToolResponse( + action="return", + result=ToolMessage( + content=f"Failed after {max_retries} attempts", + name=request.tool_call["name"], + tool_call_id=request.tool_call["id"], + status="error", + ), + ) + + return response # Should never reach here + + tool_node = ToolNode([rate_limit_tool], on_tool_call=retry_handler, handle_tool_errors=False) + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "rate_limit_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + + assert len(result["messages"]) == 1 + tool_message: ToolMessage = result["messages"][0] + assert tool_message.content == "10" # Should succeed on 3rd attempt + assert tool_message.status != "error" + + +def test_on_tool_call_convert_error_to_message(): + """Test that handler can convert raised errors to error messages.""" + + def error_to_message_handler( + request: ToolRequest, + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + """Convert any error to a user-friendly message.""" + response = yield request + + if response.action == "raise": + return ToolResponse( + action="return", + result=ToolMessage( + content=f"Tool failed: {response.exception}", + name=request.tool_call["name"], + tool_call_id=request.tool_call["id"], + status="error", + ), + exception=response.exception, + ) + + return response + + tool_node = ToolNode( + [error_tool], on_tool_call=error_to_message_handler, handle_tool_errors=False + ) + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + + assert len(result["messages"]) == 1 + tool_message: ToolMessage = result["messages"][0] + assert "Tool failed" in tool_message.content + assert "Error with value: 5" in tool_message.content + assert tool_message.status == "error" + + +def test_on_tool_call_let_error_raise(): + """Test that handler can let errors propagate.""" + + def let_raise_handler( + request: ToolRequest, + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + """Just return the response as-is, letting errors raise.""" + response = yield request + return response + + tool_node = ToolNode([error_tool], on_tool_call=let_raise_handler, handle_tool_errors=False) + + with pytest.raises(ValueError) as exc_info: + tool_node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + + assert "Error with value: 5" in str(exc_info.value) + + +def test_on_tool_call_with_handled_errors(): + """Test interaction between on_tool_call and handle_tool_errors.""" + call_count = {"count": 0} + + def counting_handler( + request: ToolRequest, + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + """Count how many times we're called.""" + call_count["count"] += 1 + response = yield request + return response + + # When handle_tool_errors=True, errors are converted to ToolMessages + # so handler sees action="return" + tool_node = ToolNode([error_tool], on_tool_call=counting_handler, handle_tool_errors=True) + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + + assert call_count["count"] == 1 + assert len(result["messages"]) == 1 + tool_message: ToolMessage = result["messages"][0] + assert tool_message.status == "error" + assert "Please fix your mistakes" in tool_message.content + + +def test_on_tool_call_must_return_value(): + """Test that handler must return a ToolResponse.""" + + def no_return_handler( + request: ToolRequest, + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + """Handler that doesn't return anything.""" + response = yield request + # Implicit return None + + tool_node = ToolNode([success_tool], on_tool_call=no_return_handler) + + with pytest.raises(ValueError) as exc_info: + tool_node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "success_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + + assert "must explicitly return a ToolResponse" in str(exc_info.value) + + +def test_on_tool_call_request_modification(): + """Test that handler can modify the request before execution.""" + + def double_input_handler( + request: ToolRequest, + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + """Double the input value.""" + # Modify the tool call args + modified_tool_call = { + **request.tool_call, + "args": {**request.tool_call["args"], "x": request.tool_call["args"]["x"] * 2}, + } + modified_request = ToolRequest( + tool_call=modified_tool_call, + tool=request.tool, + config=request.config, + ) + response = yield modified_request + return response + + tool_node = ToolNode([success_tool], on_tool_call=double_input_handler) + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "success_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + + assert len(result["messages"]) == 1 + tool_message: ToolMessage = result["messages"][0] + # Input was 5, doubled to 10, then tool multiplies by 2 = 20 + assert tool_message.content == "20" + + +def test_on_tool_call_response_validation(): + """Test that ToolResponse validates action and required fields.""" + # Test action="return" requires result + with pytest.raises(ValueError) as exc_info: + ToolResponse(action="return") + assert "action='return' requires a result" in str(exc_info.value) + + # Test action="raise" requires exception + with pytest.raises(ValueError) as exc_info: + ToolResponse(action="raise") + assert "action='raise' requires an exception" in str(exc_info.value) + + # Valid responses should work + ToolResponse( + action="return", + result=ToolMessage(content="test", tool_call_id="1", name="test"), + ) + ToolResponse(action="raise", exception=ValueError("test")) + + +def test_on_tool_call_without_handler_backward_compat(): + """Test that tools work without on_tool_call handler (backward compatibility).""" + # Success case + tool_node = ToolNode([success_tool]) + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "success_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + assert result["messages"][0].content == "10" + + # Error case with handle_tool_errors=False + tool_node_error = ToolNode([error_tool], handle_tool_errors=False) + with pytest.raises(ValueError): + tool_node_error.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + + # Error case with handle_tool_errors=True + tool_node_handled = ToolNode([error_tool], handle_tool_errors=True) + result = tool_node_handled.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + assert result["messages"][0].status == "error" + + +def test_on_tool_call_multiple_yields(): + """Test that handler can yield multiple times for retries.""" + attempts = {"count": 0} + + def multi_yield_handler( + request: ToolRequest, + ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + """Yield multiple times to track attempts.""" + max_attempts = 3 + + for _ in range(max_attempts): + attempts["count"] += 1 + response = yield request + + if response.action == "return": + return response + + # All attempts failed + return response + + tool_node = ToolNode([error_tool], on_tool_call=multi_yield_handler, handle_tool_errors=False) + + with pytest.raises(ValueError): + tool_node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + + assert attempts["count"] == 3 From 0927ae4be11ec466819b2d864709971bdc64d0a1 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 6 Oct 2025 15:26:43 -0400 Subject: [PATCH 02/12] x --- .../agents/middleware/tool_error_handling.py | 20 +++---- .../langchain/agents/middleware/types.py | 12 ++-- .../langchain/agents/middleware_agent.py | 12 ++-- .../langchain_v1/langchain/tools/tool_node.py | 56 +++++++++--------- ...all.py => test_on_tool_call_middleware.py} | 42 +++++++------- .../unit_tests/tools/test_on_tool_call.py | 57 +++++++++---------- 6 files changed, 98 insertions(+), 101 deletions(-) rename libs/langchain_v1/tests/unit_tests/agents/middleware/{test_on_tool_call.py => test_on_tool_call_middleware.py} (90%) diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py b/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py index bc55e9194de15..37441d9be593f 100644 --- a/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py +++ b/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py @@ -16,14 +16,14 @@ from langchain.agents.middleware.types import AgentMiddleware -# Import ToolResponse locally to avoid circular import -from langchain.tools.tool_node import ToolResponse +# Import ToolCallResponse locally to avoid circular import +from langchain.tools.tool_node import ToolCallResponse if TYPE_CHECKING: from collections.abc import Generator from types import UnionType - from langchain.tools.tool_node import ToolRequest, ToolResponse + from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse logger = logging.getLogger(__name__) @@ -233,8 +233,8 @@ def __init__( raise ValueError(msg) def on_tool_call( - self, request: ToolRequest - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + self, request: ToolCallRequest + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Retry tool execution on failures.""" for attempt in range(1, self.max_retries + 2): # +1 for initial, +1 for inclusive response = yield request @@ -247,7 +247,7 @@ def on_tool_call( if response.action == "raise": exception = response.exception if exception is None: - msg = "ToolResponse with action='raise' must have an exception" + msg = "ToolCallResponse with action='raise' must have an exception" raise ValueError(msg) # Check if this exception type is retriable @@ -365,8 +365,8 @@ def __init__( self.message_template = message_template def on_tool_call( - self, request: ToolRequest - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + self, request: ToolCallRequest + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Convert matching errors to ToolMessages.""" response = yield request @@ -378,7 +378,7 @@ def on_tool_call( if response.action == "raise": exception = response.exception if exception is None: - msg = "ToolResponse with action='raise' must have an exception" + msg = "ToolCallResponse with action='raise' must have an exception" raise ValueError(msg) # Check if exception type matches @@ -400,7 +400,7 @@ def on_tool_call( status="error", ) - return ToolResponse( + return ToolCallResponse( action="return", result=tool_message, exception=exception, # Preserve for logging/debugging diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index 6457cc051b1ee..323f2d2b74763 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -36,7 +36,7 @@ from langgraph.types import Command from langchain.agents.structured_output import ResponseFormat - from langchain.tools.tool_node import ToolRequest, ToolResponse + from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse __all__ = [ "AgentMiddleware", @@ -218,8 +218,8 @@ async def aretry_model_request( def on_tool_call( self, - request: ToolRequest, - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + request: ToolCallRequest, + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Intercept tool execution to implement retry logic, monitoring, or request modification. Provides generator-based control over the complete tool execution lifecycle. @@ -227,10 +227,10 @@ def on_tool_call( outer middleware wrapping inner middleware (first defined = outermost layer). Generator Protocol: - 1. Yield a ToolRequest (potentially modified from the input) - 2. Receive a ToolResponse via .send() + 1. Yield a ToolCallRequest (potentially modified from the input) + 2. Receive a ToolCallResponse via .send() 3. Optionally yield again to retry - 4. Return the final ToolResponse to propagate + 4. Return the final ToolCallResponse to propagate Args: request: Tool invocation details including tool_call, tool instance, and config. diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index 2cae5876a83fc..592e30df6cc50 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -35,7 +35,7 @@ ) from langchain.chat_models import init_chat_model from langchain.tools import ToolNode -from langchain.tools.tool_node import ToolCallHandler, ToolRequest, ToolResponse +from langchain.tools.tool_node import ToolCallHandler, ToolCallRequest, ToolCallResponse STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes." @@ -84,10 +84,10 @@ def rate_limit(req): if len(handlers) == 1: return handlers[0] - def _extract_return_value(stop_iteration: StopIteration) -> ToolResponse: - """Extract ToolResponse from StopIteration, validating protocol compliance.""" + def _extract_return_value(stop_iteration: StopIteration) -> ToolCallResponse: + """Extract ToolCallResponse from StopIteration, validating protocol compliance.""" if stop_iteration.value is None: - msg = "on_tool_call handler must explicitly return a ToolResponse" + msg = "on_tool_call handler must explicitly return a ToolCallResponse" raise ValueError(msg) return stop_iteration.value @@ -95,8 +95,8 @@ def compose_two(outer: ToolCallHandler, inner: ToolCallHandler) -> ToolCallHandl """Compose two handlers where outer wraps inner.""" def composed( - request: ToolRequest, - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + request: ToolCallRequest, + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: outer_gen = outer(request) # Initialize outer generator diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index 1a392413b4095..d047fe3d9bc32 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -99,7 +99,7 @@ def my_tool(x: int) -> str: @dataclass() -class ToolRequest: +class ToolCallRequest: """Request passed to on_tool_call handler before tool execution. Attributes: @@ -117,7 +117,7 @@ class ToolRequest: @dataclass() -class ToolResponse: +class ToolCallResponse: """Response returned from on_tool_call handler after tool execution. The action field determines control flow: @@ -145,11 +145,11 @@ def __post_init__(self) -> None: raise ValueError(msg) -ToolCallHandler = Callable[[ToolRequest], Generator[ToolRequest, ToolResponse, ToolResponse]] +ToolCallHandler = Callable[[ToolCallRequest], Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]] """Generator-based handler that intercepts tool execution. -Receives a ToolRequest, yields modified ToolRequests, receives ToolResponses, -and returns a final ToolResponse. Supports multiple yields for retry logic. +Receives a ToolCallRequest, yields modified ToolCallRequests, receives ToolCallResponses, +and returns a final ToolCallResponse. Supports multiple yields for retry logic. """ @@ -389,8 +389,8 @@ class ToolNode(RunnableCallable): - ignores tool execution errors (they will be re-raised) on_tool_call: Optional handler to intercept tool execution. Receives - ``ToolRequest``, yields potentially modified requests, receives - ``ToolResponse`` via ``.send()``, and returns final ``ToolResponse``. + ``ToolCallRequest``, yields potentially modified requests, receives + ``ToolCallResponse`` via ``.send()``, and returns final ``ToolCallResponse``. Enables retries, argument modification, and custom error handling. Defaults to ``None``. @@ -441,7 +441,7 @@ def handle_errors(e: ValueError) -> str: Intercepting tool calls: ```python - from langchain.tools.tool_node import ToolRequest, ToolResponse + from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse def retry_handler(request): \"\"\"Retry failed tool calls up to 3 times.\"\"\" @@ -575,8 +575,8 @@ def _combine_tool_outputs( return combined_outputs def _execute_tool_sync( - self, request: ToolRequest, input_type: Literal["list", "dict", "tool_calls"] - ) -> ToolResponse: + self, request: ToolCallRequest, input_type: Literal["list", "dict", "tool_calls"] + ) -> ToolCallResponse: """Execute tool and return response. Applies handle_tool_errors configuration. When on_tool_call is configured, @@ -625,7 +625,7 @@ def _execute_tool_sync( # Error is not handled if self._on_tool_call is not None: # If handler exists, return action="raise" so handler can decide - return ToolResponse(action="raise", exception=e) + return ToolCallResponse(action="raise", exception=e) # No handler - maintain backward compatibility by raising immediately raise @@ -637,16 +637,16 @@ def _execute_tool_sync( tool_call_id=call["id"], status="error", ) - return ToolResponse(action="return", result=error_message, exception=e) + return ToolCallResponse(action="return", result=error_message, exception=e) # Process successful response if isinstance(response, Command): # Validate Command before returning to handler validated_command = self._validate_tool_command(response, request.tool_call, input_type) - return ToolResponse(action="return", result=validated_command) + return ToolCallResponse(action="return", result=validated_command) if isinstance(response, ToolMessage): response.content = cast("str | list", msg_content_output(response.content)) - return ToolResponse(action="return", result=response) + return ToolCallResponse(action="return", result=response) msg = f"Tool {call['name']} returned unexpected type: {type(response)}" raise TypeError(msg) @@ -664,7 +664,7 @@ def _run_one( tool = self.tools_by_name[call["name"]] # Create the tool request - tool_request = ToolRequest( + tool_request = ToolCallRequest( tool_call=call, tool=tool, config=config, @@ -689,7 +689,7 @@ def _run_one( except StopIteration as e: if e.value is None: msg = ( - "on_tool_call handler must explicitly return a ToolResponse. " + "on_tool_call handler must explicitly return a ToolCallResponse. " "Ensure your handler ends with 'return response'." ) raise ValueError(msg) @@ -699,20 +699,20 @@ def _run_one( # Apply action directive if tool_response.action == "raise": if tool_response.exception is None: - msg = "ToolResponse with action='raise' must have an exception" + msg = "ToolCallResponse with action='raise' must have an exception" raise ValueError(msg) raise tool_response.exception result = tool_response.result if result is None: - msg = "ToolResponse with action='return' must have a result" + msg = "ToolCallResponse with action='return' must have a result" raise ValueError(msg) return result async def _execute_tool_async( - self, request: ToolRequest, input_type: Literal["list", "dict", "tool_calls"] - ) -> ToolResponse: + self, request: ToolCallRequest, input_type: Literal["list", "dict", "tool_calls"] + ) -> ToolCallResponse: """Execute tool asynchronously and return response. Applies handle_tool_errors configuration. When on_tool_call is configured, @@ -761,7 +761,7 @@ async def _execute_tool_async( # Error is not handled if self._on_tool_call is not None: # If handler exists, return action="raise" so handler can decide - return ToolResponse(action="raise", exception=e) + return ToolCallResponse(action="raise", exception=e) # No handler - maintain backward compatibility by raising immediately raise @@ -773,16 +773,16 @@ async def _execute_tool_async( tool_call_id=call["id"], status="error", ) - return ToolResponse(action="return", result=error_message, exception=e) + return ToolCallResponse(action="return", result=error_message, exception=e) # Process successful response if isinstance(response, Command): # Validate Command before returning to handler validated_command = self._validate_tool_command(response, request.tool_call, input_type) - return ToolResponse(action="return", result=validated_command) + return ToolCallResponse(action="return", result=validated_command) if isinstance(response, ToolMessage): response.content = cast("str | list", msg_content_output(response.content)) - return ToolResponse(action="return", result=response) + return ToolCallResponse(action="return", result=response) msg = f"Tool {call['name']} returned unexpected type: {type(response)}" raise TypeError(msg) @@ -800,7 +800,7 @@ async def _arun_one( tool = self.tools_by_name[call["name"]] # Create the tool request - tool_request = ToolRequest( + tool_request = ToolCallRequest( tool_call=call, tool=tool, config=config, @@ -825,7 +825,7 @@ async def _arun_one( except StopIteration as e: if e.value is None: msg = ( - "on_tool_call handler must explicitly return a ToolResponse. " + "on_tool_call handler must explicitly return a ToolCallResponse. " "Ensure your handler ends with 'return response'." ) raise ValueError(msg) @@ -835,13 +835,13 @@ async def _arun_one( # Apply action directive if tool_response.action == "raise": if tool_response.exception is None: - msg = "ToolResponse with action='raise' must have an exception" + msg = "ToolCallResponse with action='raise' must have an exception" raise ValueError(msg) raise tool_response.exception result = tool_response.result if result is None: - msg = "ToolResponse with action='return' must have a result" + msg = "ToolCallResponse with action='return' must have a result" raise ValueError(msg) return result diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py similarity index 90% rename from libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call.py rename to libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py index 2f15a58cec66a..ba822d885adc5 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py @@ -14,7 +14,7 @@ from langchain.agents.middleware.types import AgentMiddleware from langchain.agents.middleware_agent import create_agent -from langchain.tools.tool_node import ToolRequest, ToolResponse +from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse class FakeModel(GenericFakeChatModel): @@ -81,8 +81,8 @@ class LoggingMiddleware(AgentMiddleware): """Middleware that logs tool calls.""" def on_tool_call( - self, request: ToolRequest - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + self, request: ToolCallRequest + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: call_log.append(f"before_{request.tool.name}") response = yield request call_log.append(f"after_{request.tool.name}") @@ -126,8 +126,8 @@ class OuterMiddleware(AgentMiddleware): """Outer middleware in the chain.""" def on_tool_call( - self, request: ToolRequest - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + self, request: ToolCallRequest + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: call_order.append("outer_start") response = yield request call_order.append("outer_end") @@ -137,8 +137,8 @@ class InnerMiddleware(AgentMiddleware): """Inner middleware in the chain.""" def on_tool_call( - self, request: ToolRequest - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + self, request: ToolCallRequest + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: call_order.append("inner_start") response = yield request call_order.append("inner_end") @@ -174,10 +174,9 @@ def test_middleware_retry_logic(): class RetryMiddleware(AgentMiddleware): """Middleware that retries on failure.""" - def on_tool_call( - self, request: ToolRequest - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + self, request: ToolCallRequest + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: nonlocal attempt_count max_retries = 2 @@ -193,7 +192,7 @@ def on_tool_call( continue # Convert error to success message - return ToolResponse( + return ToolCallResponse( action="return", result=ToolMessage( content=f"Failed after {max_retries} attempts", @@ -203,8 +202,7 @@ def on_tool_call( ), ) - # Should never reach here - return response + raise AssertionError("Unreachable") model = FakeModel( messages=iter( @@ -242,8 +240,8 @@ class RequestModifierMiddleware(AgentMiddleware): """Middleware that doubles the input.""" def on_tool_call( - self, request: ToolRequest - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + self, request: ToolCallRequest + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: # Modify the arguments modified_tool_call = { **request.tool_call, @@ -252,7 +250,7 @@ def on_tool_call( "y": request.tool_call["args"]["y"] * 2, }, } - modified_request = ToolRequest( + modified_request = ToolCallRequest( tool_call=modified_tool_call, tool=request.tool, config=request.config, @@ -294,8 +292,8 @@ class MonitoringMiddleware(AgentMiddleware): """Outer middleware for monitoring.""" def on_tool_call( - self, request: ToolRequest - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + self, request: ToolCallRequest + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: call_log.append("monitoring_start") response = yield request call_log.append("monitoring_end") @@ -305,8 +303,8 @@ class RetryMiddleware(AgentMiddleware): """Inner middleware for retries.""" def on_tool_call( - self, request: ToolRequest - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + self, request: ToolCallRequest + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: call_log.append("retry_start") for attempt in range(2): call_log.append(f"retry_attempt_{attempt + 1}") @@ -363,8 +361,8 @@ def before_model(self, state, runtime): return None def on_tool_call( - self, request: ToolRequest - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + self, request: ToolCallRequest + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: call_log.append("on_tool_call_start") response = yield request call_log.append("on_tool_call_end") diff --git a/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py index d33d3c41490ce..fc80b36094431 100644 --- a/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py +++ b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py @@ -7,7 +7,7 @@ from langchain_core.tools import tool from langchain.tools import ToolNode -from langchain.tools.tool_node import ToolRequest, ToolResponse +from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse # Test tools @@ -41,8 +41,8 @@ def test_on_tool_call_passthrough() -> None: """Test that a simple passthrough handler works.""" def passthrough_handler( - request: ToolRequest, - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + request: ToolCallRequest, + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Simply pass through without modification.""" response = yield request return response @@ -72,8 +72,8 @@ def test_on_tool_call_retry_success(): rate_limit_tool._call_count = 0 def retry_handler( - request: ToolRequest, - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + request: ToolCallRequest, + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Retry up to 3 times.""" max_retries = 3 @@ -88,7 +88,7 @@ def retry_handler( continue # Final attempt failed - convert to error message - return ToolResponse( + return ToolCallResponse( action="return", result=ToolMessage( content=f"Failed after {max_retries} attempts", @@ -97,8 +97,7 @@ def retry_handler( status="error", ), ) - - return response # Should never reach here + raise AssertionError("Unreachable code") tool_node = ToolNode([rate_limit_tool], on_tool_call=retry_handler, handle_tool_errors=False) result = tool_node.invoke( @@ -122,13 +121,13 @@ def test_on_tool_call_convert_error_to_message(): """Test that handler can convert raised errors to error messages.""" def error_to_message_handler( - request: ToolRequest, - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + request: ToolCallRequest, + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Convert any error to a user-friendly message.""" response = yield request if response.action == "raise": - return ToolResponse( + return ToolCallResponse( action="return", result=ToolMessage( content=f"Tool failed: {response.exception}", @@ -166,8 +165,8 @@ def test_on_tool_call_let_error_raise(): """Test that handler can let errors propagate.""" def let_raise_handler( - request: ToolRequest, - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + request: ToolCallRequest, + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Just return the response as-is, letting errors raise.""" response = yield request return response @@ -194,8 +193,8 @@ def test_on_tool_call_with_handled_errors(): call_count = {"count": 0} def counting_handler( - request: ToolRequest, - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + request: ToolCallRequest, + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Count how many times we're called.""" call_count["count"] += 1 response = yield request @@ -223,11 +222,11 @@ def counting_handler( def test_on_tool_call_must_return_value(): - """Test that handler must return a ToolResponse.""" + """Test that handler must return a ToolCallResponse.""" def no_return_handler( - request: ToolRequest, - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + request: ToolCallRequest, + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Handler that doesn't return anything.""" response = yield request # Implicit return None @@ -246,22 +245,22 @@ def no_return_handler( } ) - assert "must explicitly return a ToolResponse" in str(exc_info.value) + assert "must explicitly return a ToolCallResponse" in str(exc_info.value) def test_on_tool_call_request_modification(): """Test that handler can modify the request before execution.""" def double_input_handler( - request: ToolRequest, - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + request: ToolCallRequest, + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Double the input value.""" # Modify the tool call args modified_tool_call = { **request.tool_call, "args": {**request.tool_call["args"], "x": request.tool_call["args"]["x"] * 2}, } - modified_request = ToolRequest( + modified_request = ToolCallRequest( tool_call=modified_tool_call, tool=request.tool, config=request.config, @@ -288,23 +287,23 @@ def double_input_handler( def test_on_tool_call_response_validation(): - """Test that ToolResponse validates action and required fields.""" + """Test that ToolCallResponse validates action and required fields.""" # Test action="return" requires result with pytest.raises(ValueError) as exc_info: - ToolResponse(action="return") + ToolCallResponse(action="return") assert "action='return' requires a result" in str(exc_info.value) # Test action="raise" requires exception with pytest.raises(ValueError) as exc_info: - ToolResponse(action="raise") + ToolCallResponse(action="raise") assert "action='raise' requires an exception" in str(exc_info.value) # Valid responses should work - ToolResponse( + ToolCallResponse( action="return", result=ToolMessage(content="test", tool_call_id="1", name="test"), ) - ToolResponse(action="raise", exception=ValueError("test")) + ToolCallResponse(action="raise", exception=ValueError("test")) def test_on_tool_call_without_handler_backward_compat(): @@ -357,8 +356,8 @@ def test_on_tool_call_multiple_yields(): attempts = {"count": 0} def multi_yield_handler( - request: ToolRequest, - ) -> Generator[ToolRequest, ToolResponse, ToolResponse]: + request: ToolCallRequest, + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Yield multiple times to track attempts.""" max_attempts = 3 From a9ff8e0b67c6d78d0f1527681dc9e6b0b11a5834 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 6 Oct 2025 15:31:16 -0400 Subject: [PATCH 03/12] x --- libs/langchain_v1/langchain/tools/tool_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index d047fe3d9bc32..2398ca3b30d43 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -209,7 +209,7 @@ def __init__( super().__init__(self.message) -def _default_handle_tool_errors(e: Exception) -> str: +def _default_handle_tool_errors(e: ToolInvocationError) -> str: """Default error handler for tool errors. If the tool is a tool invocation error, return its message. From 65e073e85c8d37b586e30144cf88cfd579609128 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 6 Oct 2025 16:08:09 -0400 Subject: [PATCH 04/12] x --- .../langchain/agents/middleware/types.py | 6 +++- .../langchain/agents/middleware_agent.py | 31 +++------------- .../langchain_v1/langchain/tools/tool_node.py | 35 +++++++++++++++---- .../test_on_tool_call_middleware.py | 17 ++++----- .../unit_tests/tools/test_on_tool_call.py | 16 ++++----- 5 files changed, 54 insertions(+), 51 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index 323f2d2b74763..2e005c2478a83 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -219,6 +219,8 @@ async def aretry_model_request( def on_tool_call( self, request: ToolCallRequest, + state: StateT, + runtime: Runtime[ContextT], ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Intercept tool execution to implement retry logic, monitoring, or request modification. @@ -234,6 +236,8 @@ def on_tool_call( Args: request: Tool invocation details including tool_call, tool instance, and config. + state: Current agent state (readonly context). + runtime: LangGraph runtime for accessing user context (readonly context). Returns: Generator for request/response interception. @@ -242,7 +246,7 @@ def on_tool_call( Retry on rate limit with exponential backoff: ```python - def on_tool_call(self, request): + def on_tool_call(self, request, state, runtime): for attempt in range(3): response = yield request if response.action == "return": diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index 592e30df6cc50..ac892f14a3bc5 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -52,31 +52,6 @@ def _chain_tool_call_handlers( Returns: Single composed handler, or None if handlers is empty. - - Example: - ```python - # Auth middleware (outer) + rate limit (inner) - def auth(req): - resp = yield req - if "unauthorized" in str(resp.exception): - refresh_token() - resp = yield req # Retry - return resp - - - def rate_limit(req): - for attempt in range(3): - resp = yield req - if "rate limit" not in str(resp.exception): - return resp - time.sleep(2**attempt) - return resp - - - handler = _chain_tool_call_handlers([auth, rate_limit]) - # Request: auth -> rate_limit -> tool - # Response: tool -> rate_limit -> auth - ``` """ if not handlers: return None @@ -96,8 +71,10 @@ def compose_two(outer: ToolCallHandler, inner: ToolCallHandler) -> ToolCallHandl def composed( request: ToolCallRequest, + state: Any, + runtime: Any, ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: - outer_gen = outer(request) + outer_gen = outer(request, state, runtime) # Initialize outer generator try: @@ -107,7 +84,7 @@ def composed( # Outer retry loop while True: - inner_gen = inner(outer_request) + inner_gen = inner(outer_request, state, runtime) # Initialize inner generator try: diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index 2398ca3b30d43..56d29fc95ce23 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -75,6 +75,7 @@ def my_tool(x: int) -> str: from langgraph._internal._runnable import RunnableCallable from langgraph.errors import GraphBubbleUp from langgraph.graph.message import REMOVE_ALL_MESSAGES +from langgraph.runtime import get_runtime from langgraph.types import Command, Send from pydantic import BaseModel, ValidationError @@ -145,11 +146,13 @@ def __post_init__(self) -> None: raise ValueError(msg) -ToolCallHandler = Callable[[ToolCallRequest], Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]] +ToolCallHandler = Callable[ + [ToolCallRequest, Any, Any], Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse] +] """Generator-based handler that intercepts tool execution. -Receives a ToolCallRequest, yields modified ToolCallRequests, receives ToolCallResponses, -and returns a final ToolCallResponse. Supports multiple yields for retry logic. +Receives a ToolCallRequest, state, and runtime; yields modified ToolCallRequests; +receives ToolCallResponses; and returns a final ToolCallResponse. Supports multiple yields for retry logic. """ @@ -510,11 +513,20 @@ def _func( *, store: Optional[BaseStore], ) -> Any: + try: + runtime = get_runtime() + except RuntimeError: + # Running outside of the LangGrah runtime context (e.g., unit-tests) + runtime = None tool_calls, input_type = self._parse_input(input, store) config_list = get_config_list(config, len(tool_calls)) input_types = [input_type] * len(tool_calls) + inputs = [input] * len(tool_calls) + runtimes = [runtime] * len(tool_calls) with get_executor_for_config(config) as executor: - outputs = [*executor.map(self._run_one, tool_calls, input_types, config_list)] + outputs = [ + *executor.map(self._run_one, tool_calls, input_types, config_list, inputs, runtimes) + ] return self._combine_tool_outputs(outputs, input_type) @@ -525,9 +537,14 @@ async def _afunc( *, store: Optional[BaseStore], ) -> Any: + try: + runtime = get_runtime() + except RuntimeError: + # Running outside of the LangGrah runtime context (e.g., unit-tests) + runtime = None tool_calls, input_type = self._parse_input(input, store) outputs = await asyncio.gather( - *(self._arun_one(call, input_type, config) for call in tool_calls) + *(self._arun_one(call, input_type, config, input, runtime) for call in tool_calls) ) return self._combine_tool_outputs(outputs, input_type) @@ -656,6 +673,8 @@ def _run_one( call: ToolCall, input_type: Literal["list", "dict", "tool_calls"], config: RunnableConfig, + input: list[AnyMessage] | dict[str, Any] | BaseModel, + runtime: Any, ) -> ToolMessage | Command: """Run a single tool call synchronously.""" if invalid_tool_message := self._validate_tool_call(call): @@ -674,7 +693,7 @@ def _run_one( tool_response = self._execute_tool_sync(tool_request, input_type) else: # Generator protocol: start generator, send responses, receive requests - gen = self._on_tool_call(tool_request) + gen = self._on_tool_call(tool_request, input, runtime) try: request = next(gen) @@ -792,6 +811,8 @@ async def _arun_one( call: ToolCall, input_type: Literal["list", "dict", "tool_calls"], config: RunnableConfig, + input: list[AnyMessage] | dict[str, Any] | BaseModel, + runtime: Any, ) -> ToolMessage | Command: """Run a single tool call asynchronously.""" if invalid_tool_message := self._validate_tool_call(call): @@ -810,7 +831,7 @@ async def _arun_one( tool_response = await self._execute_tool_async(tool_request, input_type) else: # Generator protocol: handler is sync generator, tool execution is async - gen = self._on_tool_call(tool_request) + gen = self._on_tool_call(tool_request, input, runtime) try: request = next(gen) diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py index ba822d885adc5..ad982f4ec7314 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py @@ -81,7 +81,7 @@ class LoggingMiddleware(AgentMiddleware): """Middleware that logs tool calls.""" def on_tool_call( - self, request: ToolCallRequest + self, request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: call_log.append(f"before_{request.tool.name}") response = yield request @@ -126,7 +126,7 @@ class OuterMiddleware(AgentMiddleware): """Outer middleware in the chain.""" def on_tool_call( - self, request: ToolCallRequest + self, request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: call_order.append("outer_start") response = yield request @@ -137,7 +137,7 @@ class InnerMiddleware(AgentMiddleware): """Inner middleware in the chain.""" def on_tool_call( - self, request: ToolCallRequest + self, request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: call_order.append("inner_start") response = yield request @@ -174,8 +174,9 @@ def test_middleware_retry_logic(): class RetryMiddleware(AgentMiddleware): """Middleware that retries on failure.""" + def on_tool_call( - self, request: ToolCallRequest + self, request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: nonlocal attempt_count max_retries = 2 @@ -240,7 +241,7 @@ class RequestModifierMiddleware(AgentMiddleware): """Middleware that doubles the input.""" def on_tool_call( - self, request: ToolCallRequest + self, request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: # Modify the arguments modified_tool_call = { @@ -292,7 +293,7 @@ class MonitoringMiddleware(AgentMiddleware): """Outer middleware for monitoring.""" def on_tool_call( - self, request: ToolCallRequest + self, request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: call_log.append("monitoring_start") response = yield request @@ -303,7 +304,7 @@ class RetryMiddleware(AgentMiddleware): """Inner middleware for retries.""" def on_tool_call( - self, request: ToolCallRequest + self, request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: call_log.append("retry_start") for attempt in range(2): @@ -361,7 +362,7 @@ def before_model(self, state, runtime): return None def on_tool_call( - self, request: ToolCallRequest + self, request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: call_log.append("on_tool_call_start") response = yield request diff --git a/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py index fc80b36094431..f844aff7352e3 100644 --- a/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py +++ b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py @@ -41,7 +41,7 @@ def test_on_tool_call_passthrough() -> None: """Test that a simple passthrough handler works.""" def passthrough_handler( - request: ToolCallRequest, + request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Simply pass through without modification.""" response = yield request @@ -72,7 +72,7 @@ def test_on_tool_call_retry_success(): rate_limit_tool._call_count = 0 def retry_handler( - request: ToolCallRequest, + request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Retry up to 3 times.""" max_retries = 3 @@ -121,7 +121,7 @@ def test_on_tool_call_convert_error_to_message(): """Test that handler can convert raised errors to error messages.""" def error_to_message_handler( - request: ToolCallRequest, + request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Convert any error to a user-friendly message.""" response = yield request @@ -165,7 +165,7 @@ def test_on_tool_call_let_error_raise(): """Test that handler can let errors propagate.""" def let_raise_handler( - request: ToolCallRequest, + request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Just return the response as-is, letting errors raise.""" response = yield request @@ -193,7 +193,7 @@ def test_on_tool_call_with_handled_errors(): call_count = {"count": 0} def counting_handler( - request: ToolCallRequest, + request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Count how many times we're called.""" call_count["count"] += 1 @@ -225,7 +225,7 @@ def test_on_tool_call_must_return_value(): """Test that handler must return a ToolCallResponse.""" def no_return_handler( - request: ToolCallRequest, + request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Handler that doesn't return anything.""" response = yield request @@ -252,7 +252,7 @@ def test_on_tool_call_request_modification(): """Test that handler can modify the request before execution.""" def double_input_handler( - request: ToolCallRequest, + request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Double the input value.""" # Modify the tool call args @@ -356,7 +356,7 @@ def test_on_tool_call_multiple_yields(): attempts = {"count": 0} def multi_yield_handler( - request: ToolCallRequest, + request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Yield multiple times to track attempts.""" max_attempts = 3 From def2f147aed4782066e9eec2af0fc377da112b18 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 6 Oct 2025 16:16:19 -0400 Subject: [PATCH 05/12] x --- .../agents/middleware/tool_error_handling.py | 29 +++++----- .../langchain_v1/langchain/tools/tool_node.py | 7 ++- .../unit_tests/tools/test_on_tool_call.py | 58 +++++++++---------- 3 files changed, 44 insertions(+), 50 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py b/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py index 37441d9be593f..f165d49938545 100644 --- a/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py +++ b/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py @@ -9,7 +9,7 @@ import inspect import logging import time -from collections.abc import Callable +from types import UnionType from typing import TYPE_CHECKING, Union, get_args, get_origin, get_type_hints from langchain_core.messages import ToolMessage @@ -17,13 +17,10 @@ from langchain.agents.middleware.types import AgentMiddleware # Import ToolCallResponse locally to avoid circular import -from langchain.tools.tool_node import ToolCallResponse +from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse if TYPE_CHECKING: - from collections.abc import Generator - from types import UnionType - - from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse + from collections.abc import Callable, Generator logger = logging.getLogger(__name__) @@ -72,13 +69,14 @@ def _infer_retriable_types( if first_param.name in type_hints: origin = get_origin(first_param.annotation) # Handle Union types - if origin in [Union, UnionType]: # type: ignore[has-type] + if origin in [Union, UnionType]: args = get_args(first_param.annotation) if all(isinstance(arg, type) and issubclass(arg, Exception) for arg in args): return tuple(args) msg = ( "All types in retry predicate annotation must be Exception types. " - "For example, `def should_retry(e: Union[TimeoutError, ConnectionError]) -> bool`. " + "For example, `def should_retry(e: Union[TimeoutError, " + "ConnectionError]) -> bool`. " f"Got '{first_param.annotation}' instead." ) raise ValueError(msg) @@ -260,14 +258,13 @@ def on_tool_call( return response # If predicate is provided, check if we should retry - if self._retry_predicate is not None: - if not self._retry_predicate(exception): - logger.debug( - "Retry predicate returned False for %s in tool %s", - type(exception).__name__, - request.tool_call["name"], - ) - return response + if self._retry_predicate is not None and not self._retry_predicate(exception): + logger.debug( + "Retry predicate returned False for %s in tool %s", + type(exception).__name__, + request.tool_call["name"], + ) + return response # Last attempt - return error if attempt > self.max_retries: diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index 56d29fc95ce23..6be59457ecc06 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -152,7 +152,8 @@ def __post_init__(self) -> None: """Generator-based handler that intercepts tool execution. Receives a ToolCallRequest, state, and runtime; yields modified ToolCallRequests; -receives ToolCallResponses; and returns a final ToolCallResponse. Supports multiple yields for retry logic. +receives ToolCallResponses; and returns a final ToolCallResponse. Supports multiple +yields for retry logic. """ @@ -511,7 +512,7 @@ def _func( input: list[AnyMessage] | dict[str, Any] | BaseModel, config: RunnableConfig, *, - store: Optional[BaseStore], + store: BaseStore | None, ) -> Any: try: runtime = get_runtime() @@ -535,7 +536,7 @@ async def _afunc( input: list[AnyMessage] | dict[str, Any] | BaseModel, config: RunnableConfig, *, - store: Optional[BaseStore], + store: BaseStore | None, ) -> Any: try: runtime = get_runtime() diff --git a/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py index f844aff7352e3..5a21df58b404f 100644 --- a/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py +++ b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py @@ -1,6 +1,7 @@ """Tests for on_tool_call handler functionality.""" from collections.abc import Generator +from typing import Any import pytest from langchain_core.messages import AIMessage, ToolMessage @@ -41,7 +42,7 @@ def test_on_tool_call_passthrough() -> None: """Test that a simple passthrough handler works.""" def passthrough_handler( - request: ToolCallRequest, state, runtime + request: ToolCallRequest, _state: Any, _runtime: Any ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Simply pass through without modification.""" response = yield request @@ -65,14 +66,14 @@ def passthrough_handler( assert tool_message.status != "error" -def test_on_tool_call_retry_success(): +def test_on_tool_call_retry_success() -> None: """Test that retry handler can recover from transient errors.""" # Reset counter if hasattr(rate_limit_tool, "_call_count"): rate_limit_tool._call_count = 0 def retry_handler( - request: ToolCallRequest, state, runtime + request: ToolCallRequest, _state: Any, _runtime: Any ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Retry up to 3 times.""" max_retries = 3 @@ -97,7 +98,8 @@ def retry_handler( status="error", ), ) - raise AssertionError("Unreachable code") + msg = "Unreachable code" + raise AssertionError(msg) tool_node = ToolNode([rate_limit_tool], on_tool_call=retry_handler, handle_tool_errors=False) result = tool_node.invoke( @@ -117,11 +119,11 @@ def retry_handler( assert tool_message.status != "error" -def test_on_tool_call_convert_error_to_message(): +def test_on_tool_call_convert_error_to_message() -> None: """Test that handler can convert raised errors to error messages.""" def error_to_message_handler( - request: ToolCallRequest, state, runtime + request: ToolCallRequest, _state: Any, _runtime: Any ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Convert any error to a user-friendly message.""" response = yield request @@ -161,11 +163,11 @@ def error_to_message_handler( assert tool_message.status == "error" -def test_on_tool_call_let_error_raise(): +def test_on_tool_call_let_error_raise() -> None: """Test that handler can let errors propagate.""" def let_raise_handler( - request: ToolCallRequest, state, runtime + request: ToolCallRequest, _state: Any, _runtime: Any ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Just return the response as-is, letting errors raise.""" response = yield request @@ -173,7 +175,7 @@ def let_raise_handler( tool_node = ToolNode([error_tool], on_tool_call=let_raise_handler, handle_tool_errors=False) - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValueError, match=r"Error with value: 5"): tool_node.invoke( { "messages": [ @@ -185,15 +187,13 @@ def let_raise_handler( } ) - assert "Error with value: 5" in str(exc_info.value) - -def test_on_tool_call_with_handled_errors(): +def test_on_tool_call_with_handled_errors() -> None: """Test interaction between on_tool_call and handle_tool_errors.""" call_count = {"count": 0} def counting_handler( - request: ToolCallRequest, state, runtime + request: ToolCallRequest, _state: Any, _runtime: Any ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Count how many times we're called.""" call_count["count"] += 1 @@ -221,19 +221,19 @@ def counting_handler( assert "Please fix your mistakes" in tool_message.content -def test_on_tool_call_must_return_value(): +def test_on_tool_call_must_return_value() -> None: """Test that handler must return a ToolCallResponse.""" def no_return_handler( - request: ToolCallRequest, state, runtime + request: ToolCallRequest, _state: Any, _runtime: Any ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Handler that doesn't return anything.""" - response = yield request + _ = yield request # Implicit return None tool_node = ToolNode([success_tool], on_tool_call=no_return_handler) - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValueError, match=r"must explicitly return a ToolCallResponse"): tool_node.invoke( { "messages": [ @@ -245,14 +245,12 @@ def no_return_handler( } ) - assert "must explicitly return a ToolCallResponse" in str(exc_info.value) - -def test_on_tool_call_request_modification(): +def test_on_tool_call_request_modification() -> None: """Test that handler can modify the request before execution.""" def double_input_handler( - request: ToolCallRequest, state, runtime + request: ToolCallRequest, _state: Any, _runtime: Any ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Double the input value.""" # Modify the tool call args @@ -286,17 +284,15 @@ def double_input_handler( assert tool_message.content == "20" -def test_on_tool_call_response_validation(): +def test_on_tool_call_response_validation() -> None: """Test that ToolCallResponse validates action and required fields.""" # Test action="return" requires result - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValueError, match=r"action='return' requires a result"): ToolCallResponse(action="return") - assert "action='return' requires a result" in str(exc_info.value) # Test action="raise" requires exception - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValueError, match=r"action='raise' requires an exception"): ToolCallResponse(action="raise") - assert "action='raise' requires an exception" in str(exc_info.value) # Valid responses should work ToolCallResponse( @@ -306,7 +302,7 @@ def test_on_tool_call_response_validation(): ToolCallResponse(action="raise", exception=ValueError("test")) -def test_on_tool_call_without_handler_backward_compat(): +def test_on_tool_call_without_handler_backward_compat() -> None: """Test that tools work without on_tool_call handler (backward compatibility).""" # Success case tool_node = ToolNode([success_tool]) @@ -324,7 +320,7 @@ def test_on_tool_call_without_handler_backward_compat(): # Error case with handle_tool_errors=False tool_node_error = ToolNode([error_tool], handle_tool_errors=False) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=r"Error with value: 5"): tool_node_error.invoke( { "messages": [ @@ -351,12 +347,12 @@ def test_on_tool_call_without_handler_backward_compat(): assert result["messages"][0].status == "error" -def test_on_tool_call_multiple_yields(): +def test_on_tool_call_multiple_yields() -> None: """Test that handler can yield multiple times for retries.""" attempts = {"count": 0} def multi_yield_handler( - request: ToolCallRequest, state, runtime + request: ToolCallRequest, _state: Any, _runtime: Any ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Yield multiple times to track attempts.""" max_attempts = 3 @@ -373,7 +369,7 @@ def multi_yield_handler( tool_node = ToolNode([error_tool], on_tool_call=multi_yield_handler, handle_tool_errors=False) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=r"Error with value: 5"): tool_node.invoke( { "messages": [ From 4f53ed3e9a4b2d645e3d83e09c2045a3b3a8a529 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 6 Oct 2025 16:51:24 -0400 Subject: [PATCH 06/12] x --- libs/langchain_v1/langchain/tools/tool_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index 6be59457ecc06..1dab50527c65c 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -512,7 +512,7 @@ def _func( input: list[AnyMessage] | dict[str, Any] | BaseModel, config: RunnableConfig, *, - store: BaseStore | None, + store: Optional[BaseStore], ) -> Any: try: runtime = get_runtime() From fa533c44b756431d3dda284a1114db57ef25c52e Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 6 Oct 2025 16:54:00 -0400 Subject: [PATCH 07/12] x --- .../agents/middleware/test_on_tool_call_middleware.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py index ad982f4ec7314..35c582845359c 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py @@ -395,7 +395,3 @@ def on_tool_call( assert "on_tool_call_end" in call_log # before_model runs before on_tool_call assert call_log.index("before_model") < call_log.index("on_tool_call_start") - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) From ba9ec6d8951a41522a5d3deac3acf7cee2349b39 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 6 Oct 2025 16:56:16 -0400 Subject: [PATCH 08/12] x --- libs/langchain_v1/langchain/tools/tool_node.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index 1dab50527c65c..d44a956d3effd 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -53,6 +53,7 @@ def my_tool(x: int) -> str: get_origin, get_type_hints, ) +from typing import Optional as Optional from langchain_core.messages import ( AIMessage, @@ -512,7 +513,9 @@ def _func( input: list[AnyMessage] | dict[str, Any] | BaseModel, config: RunnableConfig, *, - store: Optional[BaseStore], + # Optional[BaseStore] should not change to BaseStore | None + # until we support injection of store using `BaseStore | None` annotation + store: Optional[BaseStore], # noqa: UP045 ) -> Any: try: runtime = get_runtime() @@ -536,7 +539,9 @@ async def _afunc( input: list[AnyMessage] | dict[str, Any] | BaseModel, config: RunnableConfig, *, - store: BaseStore | None, + # Optional[BaseStore] should not change to BaseStore | None + # until we support injection of store using `BaseStore | None` annotation + store: Optional[BaseStore], # noqa: UP045 ) -> Any: try: runtime = get_runtime() From db58bfa5436803ea50fbbe918c48adb1f88ebd20 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 6 Oct 2025 17:32:04 -0400 Subject: [PATCH 09/12] rename to continue --- .../agents/middleware/tool_error_handling.py | 6 ++-- .../langchain/agents/middleware/types.py | 2 +- .../langchain_v1/langchain/tools/tool_node.py | 34 +++++++++---------- .../test_on_tool_call_middleware.py | 16 ++++----- .../unit_tests/tools/test_on_tool_call.py | 18 +++++----- 5 files changed, 37 insertions(+), 39 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py b/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py index f165d49938545..447a2c27e824f 100644 --- a/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py +++ b/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py @@ -238,7 +238,7 @@ def on_tool_call( response = yield request # Success - return immediately - if response.action == "return": + if response.action == "continue": return response # Error - check if we should retry @@ -368,7 +368,7 @@ def on_tool_call( response = yield request # Success - pass through - if response.action == "return": + if response.action == "continue": return response # Error - check if we should convert @@ -398,7 +398,7 @@ def on_tool_call( ) return ToolCallResponse( - action="return", + action="continue", result=tool_message, exception=exception, # Preserve for logging/debugging ) diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index 2e005c2478a83..d70ba37141c35 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -249,7 +249,7 @@ def on_tool_call( def on_tool_call(self, request, state, runtime): for attempt in range(3): response = yield request - if response.action == "return": + if response.action == "continue": return response if "rate limit" in str(response.exception): time.sleep(2**attempt) diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index d44a956d3effd..79f1e60ce6b4d 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -107,7 +107,6 @@ class ToolCallRequest: Attributes: tool_call: The tool call dict containing name, args, and id. tool: The BaseTool instance that will be invoked. - config: Runnable configuration for this execution. Note: tool_call["args"] can be mutated directly to modify arguments. @@ -115,7 +114,6 @@ class ToolCallRequest: tool_call: ToolCall tool: BaseTool - config: RunnableConfig @dataclass() @@ -123,24 +121,24 @@ class ToolCallResponse: """Response returned from on_tool_call handler after tool execution. The action field determines control flow: - - "return": Handler completed successfully, use result + - "continue": Handler completed successfully, use result - "raise": Handler wants to propagate the exception Attributes: - action: Control flow directive ("return" or "raise"). - result: ToolMessage or Command when action="return". + action: Control flow directive ("continue" or "raise"). + result: ToolMessage or Command when action="continue". exception: The exception when action="raise", or for logging when - action="return" with an error ToolMessage. + action="continue" with an error ToolMessage. """ - action: Literal["return", "raise"] + action: Literal["continue", "raise"] result: ToolMessage | Command | None = None exception: Exception | None = None def __post_init__(self) -> None: """Validate that required fields are present based on action.""" - if self.action == "return" and self.result is None: - msg = "action='return' requires a result" + if self.action == "continue" and self.result is None: + msg = "action='continue' requires a result" raise ValueError(msg) if self.action == "raise" and self.exception is None: msg = "action='raise' requires an exception" @@ -452,7 +450,7 @@ def retry_handler(request): \"\"\"Retry failed tool calls up to 3 times.\"\"\" for attempt in range(3): response = yield request - if response.action == "return": + if response.action == "continue": return response # Retry on error return response # Final attempt @@ -660,16 +658,16 @@ def _execute_tool_sync( tool_call_id=call["id"], status="error", ) - return ToolCallResponse(action="return", result=error_message, exception=e) + return ToolCallResponse(action="continue", result=error_message, exception=e) # Process successful response if isinstance(response, Command): # Validate Command before returning to handler validated_command = self._validate_tool_command(response, request.tool_call, input_type) - return ToolCallResponse(action="return", result=validated_command) + return ToolCallResponse(action="continue", result=validated_command) if isinstance(response, ToolMessage): response.content = cast("str | list", msg_content_output(response.content)) - return ToolCallResponse(action="return", result=response) + return ToolCallResponse(action="continue", result=response) msg = f"Tool {call['name']} returned unexpected type: {type(response)}" raise TypeError(msg) @@ -730,7 +728,7 @@ def _run_one( result = tool_response.result if result is None: - msg = "ToolCallResponse with action='return' must have a result" + msg = "ToolCallResponse with action='continue' must have a result" raise ValueError(msg) return result @@ -798,16 +796,16 @@ async def _execute_tool_async( tool_call_id=call["id"], status="error", ) - return ToolCallResponse(action="return", result=error_message, exception=e) + return ToolCallResponse(action="continue", result=error_message, exception=e) # Process successful response if isinstance(response, Command): # Validate Command before returning to handler validated_command = self._validate_tool_command(response, request.tool_call, input_type) - return ToolCallResponse(action="return", result=validated_command) + return ToolCallResponse(action="continue", result=validated_command) if isinstance(response, ToolMessage): response.content = cast("str | list", msg_content_output(response.content)) - return ToolCallResponse(action="return", result=response) + return ToolCallResponse(action="continue", result=response) msg = f"Tool {call['name']} returned unexpected type: {type(response)}" raise TypeError(msg) @@ -868,7 +866,7 @@ async def _arun_one( result = tool_response.result if result is None: - msg = "ToolCallResponse with action='return' must have a result" + msg = "ToolCallResponse with action='continue' must have a result" raise ValueError(msg) return result diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py index 35c582845359c..5c63b9b0df3ba 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py @@ -185,7 +185,7 @@ def on_tool_call( attempt_count += 1 response = yield request - if response.action == "return": + if response.action == "continue": return response if response.action == "raise" and attempt < max_retries - 1: @@ -194,7 +194,7 @@ def on_tool_call( # Convert error to success message return ToolCallResponse( - action="return", + action="continue", result=ToolMessage( content=f"Failed after {max_retries} attempts", name=request.tool_call["name"], @@ -311,7 +311,7 @@ def on_tool_call( call_log.append(f"retry_attempt_{attempt + 1}") response = yield request - if response.action == "return": + if response.action == "continue": call_log.append("retry_success") return response @@ -357,15 +357,15 @@ def test_mixed_middleware(): class MixedMiddleware(AgentMiddleware): """Middleware with multiple hooks.""" - def before_model(self, state, runtime): - call_log.append("before_model") - return None - def on_tool_call( self, request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: call_log.append("on_tool_call_start") - response = yield request + for _ in range(3): + response = yield request + if response.action == "continue": + break + # response = yield request call_log.append("on_tool_call_end") return response diff --git a/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py index 5a21df58b404f..103f67e734e34 100644 --- a/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py +++ b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py @@ -81,7 +81,7 @@ def retry_handler( for attempt in range(max_retries): response = yield request - if response.action == "return": + if response.action == "continue": return response # Retry on error @@ -90,7 +90,7 @@ def retry_handler( # Final attempt failed - convert to error message return ToolCallResponse( - action="return", + action="continue", result=ToolMessage( content=f"Failed after {max_retries} attempts", name=request.tool_call["name"], @@ -130,7 +130,7 @@ def error_to_message_handler( if response.action == "raise": return ToolCallResponse( - action="return", + action="continue", result=ToolMessage( content=f"Tool failed: {response.exception}", name=request.tool_call["name"], @@ -201,7 +201,7 @@ def counting_handler( return response # When handle_tool_errors=True, errors are converted to ToolMessages - # so handler sees action="return" + # so handler sees action="continue" tool_node = ToolNode([error_tool], on_tool_call=counting_handler, handle_tool_errors=True) result = tool_node.invoke( { @@ -286,9 +286,9 @@ def double_input_handler( def test_on_tool_call_response_validation() -> None: """Test that ToolCallResponse validates action and required fields.""" - # Test action="return" requires result - with pytest.raises(ValueError, match=r"action='return' requires a result"): - ToolCallResponse(action="return") + # Test action="continue" requires result + with pytest.raises(ValueError, match=r"action='continue' requires a result"): + ToolCallResponse(action="continue") # Test action="raise" requires exception with pytest.raises(ValueError, match=r"action='raise' requires an exception"): @@ -296,7 +296,7 @@ def test_on_tool_call_response_validation() -> None: # Valid responses should work ToolCallResponse( - action="return", + action="continue", result=ToolMessage(content="test", tool_call_id="1", name="test"), ) ToolCallResponse(action="raise", exception=ValueError("test")) @@ -361,7 +361,7 @@ def multi_yield_handler( attempts["count"] += 1 response = yield request - if response.action == "return": + if response.action == "continue": return response # All attempts failed From 076c6f6b417611b32146942b9c0a74b1664f093c Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 6 Oct 2025 17:35:27 -0400 Subject: [PATCH 10/12] x --- libs/langchain_v1/langchain/tools/tool_node.py | 16 ++++++---------- .../middleware/test_on_tool_call_middleware.py | 1 - .../tests/unit_tests/tools/test_on_tool_call.py | 1 - 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index 79f1e60ce6b4d..c1610f51f713a 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -596,7 +596,7 @@ def _combine_tool_outputs( return combined_outputs def _execute_tool_sync( - self, request: ToolCallRequest, input_type: Literal["list", "dict", "tool_calls"] + self, request: ToolCallRequest, input_type: Literal["list", "dict", "tool_calls"], config: RunnableConfig ) -> ToolCallResponse: """Execute tool and return response. @@ -605,7 +605,6 @@ def _execute_tool_sync( """ call = request.tool_call tool = request.tool - config = request.config call_args = {**call, "type": "tool_call"} try: @@ -690,11 +689,10 @@ def _run_one( tool_request = ToolCallRequest( tool_call=call, tool=tool, - config=config, ) if self._on_tool_call is None: - tool_response = self._execute_tool_sync(tool_request, input_type) + tool_response = self._execute_tool_sync(tool_request, input_type, config) else: # Generator protocol: start generator, send responses, receive requests gen = self._on_tool_call(tool_request, input, runtime) @@ -706,7 +704,7 @@ def _run_one( raise ValueError(msg) while True: - tool_response = self._execute_tool_sync(request, input_type) + tool_response = self._execute_tool_sync(request, input_type, config) try: request = gen.send(tool_response) except StopIteration as e: @@ -734,7 +732,7 @@ def _run_one( return result async def _execute_tool_async( - self, request: ToolCallRequest, input_type: Literal["list", "dict", "tool_calls"] + self, request: ToolCallRequest, input_type: Literal["list", "dict", "tool_calls"], config: RunnableConfig ) -> ToolCallResponse: """Execute tool asynchronously and return response. @@ -743,7 +741,6 @@ async def _execute_tool_async( """ call = request.tool_call tool = request.tool - config = request.config call_args = {**call, "type": "tool_call"} try: @@ -828,11 +825,10 @@ async def _arun_one( tool_request = ToolCallRequest( tool_call=call, tool=tool, - config=config, ) if self._on_tool_call is None: - tool_response = await self._execute_tool_async(tool_request, input_type) + tool_response = await self._execute_tool_async(tool_request, input_type, config) else: # Generator protocol: handler is sync generator, tool execution is async gen = self._on_tool_call(tool_request, input, runtime) @@ -844,7 +840,7 @@ async def _arun_one( raise ValueError(msg) while True: - tool_response = await self._execute_tool_async(request, input_type) + tool_response = await self._execute_tool_async(request, input_type, config) try: request = gen.send(tool_response) except StopIteration as e: diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py index 5c63b9b0df3ba..c8ddaf7c3d746 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py @@ -254,7 +254,6 @@ def on_tool_call( modified_request = ToolCallRequest( tool_call=modified_tool_call, tool=request.tool, - config=request.config, ) response = yield modified_request return response diff --git a/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py index 103f67e734e34..4111f9e9483bc 100644 --- a/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py +++ b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py @@ -261,7 +261,6 @@ def double_input_handler( modified_request = ToolCallRequest( tool_call=modified_tool_call, tool=request.tool, - config=request.config, ) response = yield modified_request return response From 40b4c69a5a9aecdcb8ec58539a80f964ea2b1a0a Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 6 Oct 2025 17:38:01 -0400 Subject: [PATCH 11/12] x --- .../agents/middleware/test_on_tool_call_middleware.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py index c8ddaf7c3d746..0698f5506048a 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py @@ -356,15 +356,15 @@ def test_mixed_middleware(): class MixedMiddleware(AgentMiddleware): """Middleware with multiple hooks.""" + def before_model(self, state, runtime): + call_log.append("before_model") + return None + def on_tool_call( self, request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: call_log.append("on_tool_call_start") - for _ in range(3): - response = yield request - if response.action == "continue": - break - # response = yield request + response = yield request call_log.append("on_tool_call_end") return response From 3c92e986f6cae2027b7ff662bb2107b847116daf Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 6 Oct 2025 21:52:18 -0400 Subject: [PATCH 12/12] x --- .../langchain/agents/middleware/__init__.py | 4 ++-- .../agents/middleware/tool_error_handling.py | 22 +++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index 21b0debba87a6..e4329f5894104 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -9,7 +9,7 @@ from .tool_call_limit import ToolCallLimitMiddleware from .tool_error_handling import ( ErrorToMessageMiddleware, - RetryMiddleware, + ToolRetryMiddleware, ) from .tool_selection import LLMToolSelectorMiddleware from .types import ( @@ -36,7 +36,7 @@ "PIIDetectionError", "PIIMiddleware", "PlanningMiddleware", - "RetryMiddleware", + "ToolRetryMiddleware", "SummarizationMiddleware", "ToolCallLimitMiddleware", "after_model", diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py b/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py index 447a2c27e824f..e0e825f94e201 100644 --- a/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py +++ b/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py @@ -97,7 +97,7 @@ def _infer_retriable_types( return (Exception,) -class RetryMiddleware(AgentMiddleware): +class ToolRetryMiddleware(AgentMiddleware): """Retry failed tool calls with constant delay. This middleware catches tool execution errors and retries them up to a maximum @@ -108,9 +108,9 @@ class RetryMiddleware(AgentMiddleware): Retry only network errors: ```python - from langchain.agents.middleware import RetryMiddleware + from langchain.agents.middleware import ToolRetryMiddleware - middleware = RetryMiddleware( + middleware = ToolRetryMiddleware( max_retries=3, delay=2.0, retry_on=(TimeoutError, ConnectionError), @@ -139,7 +139,7 @@ def should_retry(e: Exception) -> bool: return False - middleware = RetryMiddleware( + middleware = ToolRetryMiddleware( max_retries=3, retry_on=should_retry, ) @@ -149,7 +149,7 @@ def should_retry(e: Exception) -> bool: ```python from langchain.agents.middleware import ( - RetryMiddleware, + ToolRetryMiddleware, ErrorToMessageMiddleware, ) @@ -158,7 +158,7 @@ def should_retry(e: Exception) -> bool: tools=[my_tool], middleware=[ # Outer: retry network errors - RetryMiddleware( + ToolRetryMiddleware( max_retries=3, delay=2.0, retry_on=(TimeoutError, ConnectionError), @@ -231,7 +231,7 @@ def __init__( raise ValueError(msg) def on_tool_call( - self, request: ToolCallRequest + self, request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Retry tool execution on failures.""" for attempt in range(1, self.max_retries + 2): # +1 for initial, +1 for inclusive @@ -287,7 +287,7 @@ def on_tool_call( continue # Should never reach here - msg = f"Unexpected control flow in RetryMiddleware for tool {request.tool_call['name']}" + msg = f"Unexpected control flow in ToolRetryMiddleware for tool {request.tool_call['name']}" raise RuntimeError(msg) @@ -321,7 +321,7 @@ class ErrorToMessageMiddleware(AgentMiddleware): ```python from langchain.agents.middleware import ( - RetryMiddleware, + ToolRetryMiddleware, ErrorToMessageMiddleware, ) @@ -330,7 +330,7 @@ class ErrorToMessageMiddleware(AgentMiddleware): tools=[my_tool], middleware=[ # Outer: retry all errors - RetryMiddleware(max_retries=3), + ToolRetryMiddleware(max_retries=3), # Inner: convert validation errors to messages ErrorToMessageMiddleware( exception_types=(ValidationError,), @@ -362,7 +362,7 @@ def __init__( self.message_template = message_template def on_tool_call( - self, request: ToolCallRequest + self, request: ToolCallRequest, state, runtime ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Convert matching errors to ToolMessages.""" response = yield request