diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index 5915b75164a74..e4329f5894104 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -7,6 +7,10 @@ from .prompt_caching import AnthropicPromptCachingMiddleware from .summarization import SummarizationMiddleware from .tool_call_limit import ToolCallLimitMiddleware +from .tool_error_handling import ( + ErrorToMessageMiddleware, + ToolRetryMiddleware, +) from .tool_selection import LLMToolSelectorMiddleware from .types import ( AgentMiddleware, @@ -24,6 +28,7 @@ "AgentState", # should move to langchain-anthropic if we decide to keep it "AnthropicPromptCachingMiddleware", + "ErrorToMessageMiddleware", "HumanInTheLoopMiddleware", "LLMToolSelectorMiddleware", "ModelFallbackMiddleware", @@ -31,6 +36,7 @@ "PIIDetectionError", "PIIMiddleware", "PlanningMiddleware", + "ToolRetryMiddleware", "SummarizationMiddleware", "ToolCallLimitMiddleware", "after_model", diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py b/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py new file mode 100644 index 0000000000000..e0e825f94e201 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py @@ -0,0 +1,406 @@ +"""Middleware for handling tool execution errors in agents. + +This module provides composable middleware for error handling, retries, +and error-to-message conversion in tool execution workflows. 
+""" + +from __future__ import annotations + +import inspect +import logging +import time +from types import UnionType +from typing import TYPE_CHECKING, Union, get_args, get_origin, get_type_hints + +from langchain_core.messages import ToolMessage + +from langchain.agents.middleware.types import AgentMiddleware + +# Import ToolCallResponse locally to avoid circular import +from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse + +if TYPE_CHECKING: + from collections.abc import Callable, Generator + +logger = logging.getLogger(__name__) + + +# Default retriable exception types - transient errors that may succeed on retry +DEFAULT_RETRIABLE_EXCEPTIONS = ( + # Network and connection errors + ConnectionError, + TimeoutError, + # HTTP client errors are typically not retriable, but these are exceptions: + # - 429: Rate limit (temporary) + # - 503: Service unavailable (temporary) + # Note: Specific HTTP libraries may define their own exception types +) + + +def _infer_retriable_types( + predicate: Callable[[Exception], bool], +) -> tuple[type[Exception], ...]: + """Infer exception types from a retry predicate function's type annotations. + + Analyzes the type annotations of a predicate function to determine which + exception types it's designed to handle for retry decisions. + + Args: + predicate: A callable that takes an exception and returns whether to retry. + The first parameter should be type-annotated with exception type(s). + + Returns: + Tuple of exception types that the predicate handles. Returns (Exception,) + if no specific type information is available. + + Raises: + ValueError: If the predicate's annotation contains non-Exception types. 
+ """ + sig = inspect.signature(predicate) + params = list(sig.parameters.values()) + if params: + # Skip self/cls if it's a method + if params[0].name in ["self", "cls"] and len(params) == 2: + first_param = params[1] + else: + first_param = params[0] + + type_hints = get_type_hints(predicate) + if first_param.name in type_hints: + origin = get_origin(first_param.annotation) + # Handle Union types + if origin in [Union, UnionType]: + args = get_args(first_param.annotation) + if all(isinstance(arg, type) and issubclass(arg, Exception) for arg in args): + return tuple(args) + msg = ( + "All types in retry predicate annotation must be Exception types. " + "For example, `def should_retry(e: Union[TimeoutError, " + "ConnectionError]) -> bool`. " + f"Got '{first_param.annotation}' instead." + ) + raise ValueError(msg) + + # Handle single exception type + exception_type = type_hints[first_param.name] + if isinstance(exception_type, type) and issubclass(exception_type, Exception): + return (exception_type,) + msg = ( + "Retry predicate must be annotated with Exception type(s). " + "For example, `def should_retry(e: TimeoutError) -> bool` or " + "`def should_retry(e: Union[TimeoutError, ConnectionError]) -> bool`. " + f"Got '{exception_type}' instead." + ) + raise ValueError(msg) + + # No type information - return Exception for backward compatibility + return (Exception,) + + +class ToolRetryMiddleware(AgentMiddleware): + """Retry failed tool calls with constant delay. + + This middleware catches tool execution errors and retries them up to a maximum + number of attempts with a constant delay between retries. It operates at the + outermost layer of middleware composition to catch all errors. 
+ + Examples: + Retry only network errors: + + ```python + from langchain.agents.middleware import ToolRetryMiddleware + + middleware = ToolRetryMiddleware( + max_retries=3, + delay=2.0, + retry_on=(TimeoutError, ConnectionError), + ) + + agent = create_agent( + model="openai:gpt-4o", + tools=[my_tool], + middleware=[middleware], + ) + ``` + + Use predicate function for custom retry logic: + + ```python + from langchain.tools.tool_node import ToolInvocationError + + + def should_retry(e: Exception) -> bool: + # Don't retry validation errors from LLM + if isinstance(e, ToolInvocationError): + return False + # Retry network errors + if isinstance(e, (TimeoutError, ConnectionError)): + return True + return False + + + middleware = ToolRetryMiddleware( + max_retries=3, + retry_on=should_retry, + ) + ``` + + Compose with error conversion: + + ```python + from langchain.agents.middleware import ( + ToolRetryMiddleware, + ErrorToMessageMiddleware, + ) + + agent = create_agent( + model="openai:gpt-4o", + tools=[my_tool], + middleware=[ + # Outer: retry network errors + ToolRetryMiddleware( + max_retries=3, + delay=2.0, + retry_on=(TimeoutError, ConnectionError), + ), + # Inner: convert validation errors to messages + ErrorToMessageMiddleware( + exception_types=(ValidationError,), + ), + ], + ) + ``` + """ + + def __init__( + self, + *, + max_retries: int = 3, + delay: float = 1.0, + retry_on: type[Exception] + | tuple[type[Exception], ...] + | Callable[[Exception], bool] = DEFAULT_RETRIABLE_EXCEPTIONS, + ) -> None: + """Initialize retry middleware. + + Args: + max_retries: Maximum number of retry attempts. Total attempts will be + max_retries + 1 (initial attempt plus retries). + delay: Constant delay in seconds between retry attempts. + retry_on: Specifies which exceptions should be retried. 
Can be: + - **type[Exception]**: Retry only this exception type + - **tuple[type[Exception], ...]**: Retry these exception types + - **Callable[[Exception], bool]**: Predicate function that returns + True if the exception should be retried. Type annotations on the + callable are used to filter which exceptions are passed to it. + Defaults to ``DEFAULT_RETRIABLE_EXCEPTIONS`` (ConnectionError, TimeoutError). + """ + super().__init__() + if max_retries < 0: + msg = "max_retries must be non-negative" + raise ValueError(msg) + if delay < 0: + msg = "delay must be non-negative" + raise ValueError(msg) + + self.max_retries = max_retries + self.delay = delay + self._retry_on = retry_on + + # Determine which exception types to check + if isinstance(retry_on, type) and issubclass(retry_on, Exception): + self._retriable_types = (retry_on,) + self._retry_predicate = None + elif isinstance(retry_on, tuple): + if not retry_on: + msg = "retry_on tuple must not be empty" + raise ValueError(msg) + if not all(isinstance(t, type) and issubclass(t, Exception) for t in retry_on): + msg = "All elements in retry_on tuple must be Exception types" + raise ValueError(msg) + self._retriable_types = retry_on + self._retry_predicate = None + elif callable(retry_on): + self._retriable_types = _infer_retriable_types(retry_on) + self._retry_predicate = retry_on + else: + msg = ( + "retry_on must be an Exception type, tuple of Exception types, " + f"or callable. 
class ErrorToMessageMiddleware(AgentMiddleware):
    """Convert specific exception types to ToolMessages.

    This middleware intercepts errors and converts them into ToolMessages that
    can be sent back to the model as feedback. This is useful for errors caused
    by invalid model inputs where the model needs feedback to correct its behavior.

    Examples:
        Convert validation errors to messages:

        ```python
        from langchain.agents.middleware import ErrorToMessageMiddleware
        from langchain.tools.tool_node import ToolInvocationError

        middleware = ErrorToMessageMiddleware(
            exception_types=(ToolInvocationError,),
            message_template="Invalid arguments: {error}. Please fix and try again.",
        )

        agent = create_agent(
            model="openai:gpt-4o",
            tools=[my_tool],
            middleware=[middleware],
        )
        ```

        Compose with retry for network errors:

        ```python
        from langchain.agents.middleware import (
            ToolRetryMiddleware,
            ErrorToMessageMiddleware,
        )
        from langchain.tools.tool_node import ToolInvocationError

        agent = create_agent(
            model="openai:gpt-4o",
            tools=[my_tool],
            middleware=[
                # Outer: retry transient errors (by default only
                # ConnectionError and TimeoutError are retried)
                ToolRetryMiddleware(max_retries=3),
                # Inner: convert invalid-argument errors to messages
                ErrorToMessageMiddleware(
                    exception_types=(ToolInvocationError,),
                ),
            ],
        )
        ```
    """

    def __init__(
        self,
        *,
        exception_types: type[Exception] | tuple[type[Exception], ...],
        message_template: str = "Error: {error}",
    ) -> None:
        """Initialize error conversion middleware.

        Args:
            exception_types: Exception type, or tuple of exception types, to convert
                to messages. A single type is normalized to a one-element tuple.
            message_template: Template string for error messages. Can use ``{error}``
                placeholder for the exception string representation.

        Raises:
            ValueError: If ``exception_types`` is empty or contains anything that
                is not an ``Exception`` subclass.
        """
        super().__init__()

        # Accept a bare exception class, mirroring ToolRetryMiddleware's ``retry_on``
        if isinstance(exception_types, type):
            exception_types = (exception_types,)

        if not exception_types:
            msg = "exception_types must not be empty"
            raise ValueError(msg)

        # Fail at construction time rather than with a TypeError from
        # ``isinstance`` in the middle of a tool call.
        if not all(isinstance(t, type) and issubclass(t, Exception) for t in exception_types):
            msg = "All elements in exception_types must be Exception types"
            raise ValueError(msg)

        self.exception_types = tuple(exception_types)
        self.message_template = message_template

    def on_tool_call(
        self, request: ToolCallRequest, state, runtime
    ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
        """Convert matching errors to ToolMessages.

        Yields the request unchanged exactly once, then inspects the response:
        anything other than ``action="raise"`` is passed through untouched, and
        a raised exception is only converted when it is an instance of one of
        ``self.exception_types``.

        Args:
            request: The tool call being executed.
            state: Current agent state (unused).
            runtime: LangGraph runtime (unused).

        Returns:
            The original response, or a ``"continue"`` response carrying an
            error ToolMessage built from ``message_template``.

        Raises:
            ValueError: If a ``"raise"`` response carries no exception.
        """
        response = yield request

        # Success (or any non-error action) - pass through
        if response.action != "raise":
            return response

        exception = response.exception
        if exception is None:
            msg = "ToolCallResponse with action='raise' must have an exception"
            raise ValueError(msg)

        # Leave errors we were not asked to convert for outer layers to handle
        if not isinstance(exception, self.exception_types):
            return response

        # Convert to ToolMessage
        logger.debug(
            "Converting %s to ToolMessage for tool %s",
            type(exception).__name__,
            request.tool_call["name"],
        )

        error_message = self.message_template.format(error=str(exception))
        tool_message = ToolMessage(
            content=error_message,
            name=request.tool_call["name"],
            tool_call_id=request.tool_call["id"],
            status="error",
        )

        return ToolCallResponse(
            action="continue",
            result=tool_message,
            exception=exception,  # Preserve for logging/debugging
        )
langgraph.types import Command from langchain.agents.structured_output import ResponseFormat + from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse __all__ = [ "AgentMiddleware", @@ -215,6 +216,48 @@ async def aretry_model_request( None, self.retry_model_request, error, request, state, runtime, attempt ) + def on_tool_call( + self, + request: ToolCallRequest, + state: StateT, + runtime: Runtime[ContextT], + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: + """Intercept tool execution to implement retry logic, monitoring, or request modification. + + Provides generator-based control over the complete tool execution lifecycle. + Multiple middleware can define this hook; they compose automatically with + outer middleware wrapping inner middleware (first defined = outermost layer). + + Generator Protocol: + 1. Yield a ToolCallRequest (potentially modified from the input) + 2. Receive a ToolCallResponse via .send() + 3. Optionally yield again to retry + 4. Return the final ToolCallResponse to propagate + + Args: + request: Tool invocation details including tool_call, tool instance, and config. + state: Current agent state (readonly context). + runtime: LangGraph runtime for accessing user context (readonly context). + + Returns: + Generator for request/response interception. 
+ + Example: + Retry on rate limit with exponential backoff: + + ```python + def on_tool_call(self, request, state, runtime): + for attempt in range(3): + response = yield request + if response.action == "continue": + return response + if "rate limit" in str(response.exception): + time.sleep(2**attempt) + continue + return response + ``` + """ + class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]): """Callable with AgentState and Runtime as arguments.""" diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index b4f6370446b05..ac892f14a3bc5 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -1,7 +1,7 @@ """Middleware agent implementation.""" import itertools -from collections.abc import Callable, Sequence +from collections.abc import Callable, Generator, Sequence from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints from langchain_core.language_models.chat_models import BaseChatModel @@ -35,12 +35,99 @@ ) from langchain.chat_models import init_chat_model from langchain.tools import ToolNode +from langchain.tools.tool_node import ToolCallHandler, ToolCallRequest, ToolCallResponse STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes." ResponseT = TypeVar("ResponseT") +def _chain_tool_call_handlers( + handlers: list[ToolCallHandler], +) -> ToolCallHandler | None: + """Compose multiple tool call handlers into a single middleware stack. + + Args: + handlers: Handlers in middleware order (first = outermost layer). + + Returns: + Single composed handler, or None if handlers is empty. 
+ """ + if not handlers: + return None + + if len(handlers) == 1: + return handlers[0] + + def _extract_return_value(stop_iteration: StopIteration) -> ToolCallResponse: + """Extract ToolCallResponse from StopIteration, validating protocol compliance.""" + if stop_iteration.value is None: + msg = "on_tool_call handler must explicitly return a ToolCallResponse" + raise ValueError(msg) + return stop_iteration.value + + def compose_two(outer: ToolCallHandler, inner: ToolCallHandler) -> ToolCallHandler: + """Compose two handlers where outer wraps inner.""" + + def composed( + request: ToolCallRequest, + state: Any, + runtime: Any, + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: + outer_gen = outer(request, state, runtime) + + # Initialize outer generator + try: + outer_request = next(outer_gen) + except StopIteration as e: + return _extract_return_value(e) + + # Outer retry loop + while True: + inner_gen = inner(outer_request, state, runtime) + + # Initialize inner generator + try: + inner_request = next(inner_gen) + except StopIteration as e: + # Inner returned immediately - send to outer + inner_response = _extract_return_value(e) + try: + outer_request = outer_gen.send(inner_response) + continue # Outer retrying + except StopIteration as e: + return _extract_return_value(e) + + # Inner retry loop - yield to next layer (or tool) + while True: + tool_response = yield inner_request + + try: + inner_request = inner_gen.send(tool_response) + # Inner retrying - continue inner loop + except StopIteration as e: + # Inner done - send response to outer + inner_response = _extract_return_value(e) + break + + # Send inner's final response to outer + try: + outer_request = outer_gen.send(inner_response) + # Outer retrying - continue outer loop + except StopIteration as e: + # Outer done - return final response + return _extract_return_value(e) + + return composed + + # Compose right-to-left: handlers[0](handlers[1](...(handlers[-1](tool)))) + result = 
handlers[-1] + for handler in reversed(handlers[:-1]): + result = compose_two(handler, result) + + return result + + def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type: """Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations. @@ -226,6 +313,20 @@ def create_agent( # noqa: PLR0915 structured_output_tools[structured_tool_info.tool.name] = structured_tool_info middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])] + # Validate middleware and collect handlers + assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101 + "Please remove duplicate middleware instances." + ) + middleware_w_on_tool_call = [ + m for m in middleware if m.__class__.on_tool_call is not AgentMiddleware.on_tool_call + ] + + # Chain all on_tool_call handlers into a single composed handler + on_tool_call_handler = None + if middleware_w_on_tool_call: + handlers = [m.on_tool_call for m in middleware_w_on_tool_call] + on_tool_call_handler = _chain_tool_call_handlers(handlers) + # Setup tools tool_node: ToolNode | None = None if isinstance(tools, list): @@ -237,7 +338,11 @@ def create_agent( # noqa: PLR0915 available_tools = middleware_tools + regular_tools # Only create ToolNode if we have client-side tools - tool_node = ToolNode(tools=available_tools) if available_tools else None + tool_node = ( + ToolNode(tools=available_tools, on_tool_call=on_tool_call_handler) + if available_tools + else None + ) # Default tools for ModelRequest initialization # Include built-ins and regular tools (can be changed dynamically by middleware) @@ -248,7 +353,7 @@ def create_agent( # noqa: PLR0915 if tool_node: # Add middleware tools to existing ToolNode available_tools = list(tool_node.tools_by_name.values()) + middleware_tools - tool_node = ToolNode(available_tools) + tool_node = ToolNode(available_tools, on_tool_call=on_tool_call_handler) # default_tools includes all client-side tools (no built-ins 
or structured tools) default_tools = available_tools @@ -256,10 +361,6 @@ def create_agent( # noqa: PLR0915 # No tools provided, only middleware_tools available default_tools = middleware_tools - # validate middleware - assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101 - "Please remove duplicate middleware instances." - ) middleware_w_before = [ m for m in middleware diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index 2c1df81cf2962..c1610f51f713a 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -38,21 +38,22 @@ def my_tool(x: int) -> str: import asyncio import inspect import json +from collections.abc import Callable, Generator, Sequence from copy import copy, deepcopy -from dataclasses import replace +from dataclasses import dataclass, replace from types import UnionType from typing import ( TYPE_CHECKING, Annotated, Any, Literal, - Optional, Union, cast, get_args, get_origin, get_type_hints, ) +from typing import Optional as Optional from langchain_core.messages import ( AIMessage, @@ -75,12 +76,11 @@ def my_tool(x: int) -> str: from langgraph._internal._runnable import RunnableCallable from langgraph.errors import GraphBubbleUp from langgraph.graph.message import REMOVE_ALL_MESSAGES +from langgraph.runtime import get_runtime from langgraph.types import Command, Send from pydantic import BaseModel, ValidationError if TYPE_CHECKING: - from collections.abc import Callable, Sequence - from langchain_core.runnables import RunnableConfig from langgraph.store.base import BaseStore @@ -100,6 +100,62 @@ def my_tool(x: int) -> str: ) +@dataclass() +class ToolCallRequest: + """Request passed to on_tool_call handler before tool execution. + + Attributes: + tool_call: The tool call dict containing name, args, and id. + tool: The BaseTool instance that will be invoked. 
+ + Note: + tool_call["args"] can be mutated directly to modify arguments. + """ + + tool_call: ToolCall + tool: BaseTool + + +@dataclass() +class ToolCallResponse: + """Response returned from on_tool_call handler after tool execution. + + The action field determines control flow: + - "continue": Handler completed successfully, use result + - "raise": Handler wants to propagate the exception + + Attributes: + action: Control flow directive ("continue" or "raise"). + result: ToolMessage or Command when action="continue". + exception: The exception when action="raise", or for logging when + action="continue" with an error ToolMessage. + """ + + action: Literal["continue", "raise"] + result: ToolMessage | Command | None = None + exception: Exception | None = None + + def __post_init__(self) -> None: + """Validate that required fields are present based on action.""" + if self.action == "continue" and self.result is None: + msg = "action='continue' requires a result" + raise ValueError(msg) + if self.action == "raise" and self.exception is None: + msg = "action='raise' requires an exception" + raise ValueError(msg) + + +ToolCallHandler = Callable[ + [ToolCallRequest, Any, Any], Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse] +] +"""Generator-based handler that intercepts tool execution. + +Receives a ToolCallRequest, state, and runtime; yields modified ToolCallRequests; +receives ToolCallResponses; and returns a final ToolCallResponse. Supports multiple +yields for retry logic. +""" + + def msg_content_output(output: Any) -> str | list[dict]: """Convert tool output to valid message content format. @@ -156,7 +212,7 @@ def __init__( super().__init__(self.message) -def _default_handle_tool_errors(e: Exception) -> str: +def _default_handle_tool_errors(e: ToolInvocationError) -> str: """Default error handler for tool errors. If the tool is a tool invocation error, return its message. 
@@ -300,8 +356,8 @@ class ToolNode(RunnableCallable): Output format depends on input type and tool behavior: **For Regular tools**: - - Dict input → ``{"messages": [ToolMessage(...)]}`` - - List input → ``[ToolMessage(...)]`` + - Dict input -> ``{"messages": [ToolMessage(...)]}`` + - List input -> ``[ToolMessage(...)]`` **For Command tools**: - Returns ``[Command(...)]`` or mixed list with regular tool outputs @@ -335,6 +391,12 @@ class ToolNode(RunnableCallable): - catches tool invocation errors (due to invalid arguments provided by the model) and returns a descriptive error message - ignores tool execution errors (they will be re-raised) + on_tool_call: Optional handler to intercept tool execution. Receives + ``ToolCallRequest``, yields potentially modified requests, receives + ``ToolCallResponse`` via ``.send()``, and returns final ``ToolCallResponse``. + Enables retries, argument modification, and custom error handling. + Defaults to ``None``. + messages_key: The key in the state dictionary that contains the message list. This same key will be used for the output ToolMessages. Defaults to "messages". @@ -378,6 +440,23 @@ def handle_errors(e: ValueError) -> str: tool_node = ToolNode([my_tool], handle_tool_errors=handle_errors) ``` + + Intercepting tool calls: + + ```python + from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse + + def retry_handler(request): + \"\"\"Retry failed tool calls up to 3 times.\"\"\" + for attempt in range(3): + response = yield request + if response.action == "continue": + return response + # Retry on error + return response # Final attempt + + tool_node = ToolNode([my_tool], on_tool_call=retry_handler) + ``` """ # noqa: E501 name: str = "tools" @@ -393,6 +472,7 @@ def __init__( | Callable[..., str] | type[Exception] | tuple[type[Exception], ...] 
= _default_handle_tool_errors, + on_tool_call: ToolCallHandler | None = None, messages_key: str = "messages", ) -> None: """Initialize the ToolNode with the provided tools and configuration. @@ -402,6 +482,7 @@ def __init__( name: Node name for graph identification. tags: Optional metadata tags. handle_tool_errors: Error handling configuration. + on_tool_call: Optional handler to intercept tool execution. messages_key: State key containing messages. """ super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False) @@ -409,6 +490,7 @@ def __init__( self._tool_to_state_args: dict[str, dict[str, str | None]] = {} self._tool_to_store_arg: dict[str, str | None] = {} self._handle_tool_errors = handle_tool_errors + self._on_tool_call = on_tool_call self._messages_key = messages_key for tool in tools: if not isinstance(tool, BaseTool): @@ -429,13 +511,24 @@ def _func( input: list[AnyMessage] | dict[str, Any] | BaseModel, config: RunnableConfig, *, + # Optional[BaseStore] should not change to BaseStore | None + # until we support injection of store using `BaseStore | None` annotation store: Optional[BaseStore], # noqa: UP045 ) -> Any: + try: + runtime = get_runtime() + except RuntimeError: + # Running outside of the LangGrah runtime context (e.g., unit-tests) + runtime = None tool_calls, input_type = self._parse_input(input, store) config_list = get_config_list(config, len(tool_calls)) input_types = [input_type] * len(tool_calls) + inputs = [input] * len(tool_calls) + runtimes = [runtime] * len(tool_calls) with get_executor_for_config(config) as executor: - outputs = [*executor.map(self._run_one, tool_calls, input_types, config_list)] + outputs = [ + *executor.map(self._run_one, tool_calls, input_types, config_list, inputs, runtimes) + ] return self._combine_tool_outputs(outputs, input_type) @@ -444,11 +537,18 @@ async def _afunc( input: list[AnyMessage] | dict[str, Any] | BaseModel, config: RunnableConfig, *, + # Optional[BaseStore] should not change to 
BaseStore | None + # until we support injection of store using `BaseStore | None` annotation store: Optional[BaseStore], # noqa: UP045 ) -> Any: + try: + runtime = get_runtime() + except RuntimeError: + # Running outside of the LangGrah runtime context (e.g., unit-tests) + runtime = None tool_calls, input_type = self._parse_input(input, store) outputs = await asyncio.gather( - *(self._arun_one(call, input_type, config) for call in tool_calls) + *(self._arun_one(call, input_type, config, input, runtime) for call in tool_calls) ) return self._combine_tool_outputs(outputs, input_type) @@ -495,20 +595,19 @@ def _combine_tool_outputs( combined_outputs.append(parent_command) return combined_outputs - def _run_one( - self, - call: ToolCall, - input_type: Literal["list", "dict", "tool_calls"], - config: RunnableConfig, - ) -> ToolMessage | Command: - """Run a single tool call synchronously.""" - if invalid_tool_message := self._validate_tool_call(call): - return invalid_tool_message + def _execute_tool_sync( + self, request: ToolCallRequest, input_type: Literal["list", "dict", "tool_calls"], config: RunnableConfig + ) -> ToolCallResponse: + """Execute tool and return response. - try: - call_args = {**call, "type": "tool_call"} - tool = self.tools_by_name[call["name"]] + Applies handle_tool_errors configuration. When on_tool_call is configured, + unhandled errors return action="raise" instead of raising immediately. 
+ """ + call = request.tool_call + tool = request.tool + call_args = {**call, "type": "tool_call"} + try: try: response = tool.invoke(call_args, config) except ValidationError as exc: @@ -541,40 +640,110 @@ def _run_one( # default behavior is catching all exceptions handled_types = (Exception,) - # Unhandled + # Check if error is handled if not self._handle_tool_errors or not isinstance(e, handled_types): + # Error is not handled + if self._on_tool_call is not None: + # If handler exists, return action="raise" so handler can decide + return ToolCallResponse(action="raise", exception=e) + # No handler - maintain backward compatibility by raising immediately raise - # Handled + + # Error is handled - create error ToolMessage content = _handle_tool_error(e, flag=self._handle_tool_errors) - return ToolMessage( + error_message = ToolMessage( content=content, name=call["name"], tool_call_id=call["id"], status="error", ) + return ToolCallResponse(action="continue", result=error_message, exception=e) + # Process successful response if isinstance(response, Command): - return self._validate_tool_command(response, call, input_type) + # Validate Command before returning to handler + validated_command = self._validate_tool_command(response, request.tool_call, input_type) + return ToolCallResponse(action="continue", result=validated_command) if isinstance(response, ToolMessage): response.content = cast("str | list", msg_content_output(response.content)) - return response + return ToolCallResponse(action="continue", result=response) + msg = f"Tool {call['name']} returned unexpected type: {type(response)}" raise TypeError(msg) - async def _arun_one( + def _run_one( self, call: ToolCall, input_type: Literal["list", "dict", "tool_calls"], config: RunnableConfig, + input: list[AnyMessage] | dict[str, Any] | BaseModel, + runtime: Any, ) -> ToolMessage | Command: - """Run a single tool call asynchronously.""" + """Run a single tool call synchronously.""" if invalid_tool_message := 
self._validate_tool_call(call): return invalid_tool_message - try: - call_args = {**call, "type": "tool_call"} - tool = self.tools_by_name[call["name"]] + tool = self.tools_by_name[call["name"]] + + # Create the tool request + tool_request = ToolCallRequest( + tool_call=call, + tool=tool, + ) + + if self._on_tool_call is None: + tool_response = self._execute_tool_sync(tool_request, input_type, config) + else: + # Generator protocol: start generator, send responses, receive requests + gen = self._on_tool_call(tool_request, input, runtime) + + try: + request = next(gen) + except StopIteration: + msg = "on_tool_call handler must yield at least once before returning" + raise ValueError(msg) + + while True: + tool_response = self._execute_tool_sync(request, input_type, config) + try: + request = gen.send(tool_response) + except StopIteration as e: + if e.value is None: + msg = ( + "on_tool_call handler must explicitly return a ToolCallResponse. " + "Ensure your handler ends with 'return response'." + ) + raise ValueError(msg) + tool_response = e.value + break + + # Apply action directive + if tool_response.action == "raise": + if tool_response.exception is None: + msg = "ToolCallResponse with action='raise' must have an exception" + raise ValueError(msg) + raise tool_response.exception + + result = tool_response.result + if result is None: + msg = "ToolCallResponse with action='continue' must have a result" + raise ValueError(msg) + + return result + + async def _execute_tool_async( + self, request: ToolCallRequest, input_type: Literal["list", "dict", "tool_calls"], config: RunnableConfig + ) -> ToolCallResponse: + """Execute tool asynchronously and return response. + Applies handle_tool_errors configuration. When on_tool_call is configured, + unhandled errors return action="raise" instead of raising immediately. 
+ """ + call = request.tool_call + tool = request.tool + call_args = {**call, "type": "tool_call"} + + try: try: response = await tool.ainvoke(call_args, config) except ValidationError as exc: @@ -607,27 +776,97 @@ async def _arun_one( # default behavior is catching all exceptions handled_types = (Exception,) - # Unhandled + # Check if error is handled if not self._handle_tool_errors or not isinstance(e, handled_types): + # Error is not handled + if self._on_tool_call is not None: + # If handler exists, return action="raise" so handler can decide + return ToolCallResponse(action="raise", exception=e) + # No handler - maintain backward compatibility by raising immediately raise - # Handled - content = _handle_tool_error(e, flag=self._handle_tool_errors) - return ToolMessage( + # Error is handled - create error ToolMessage + content = _handle_tool_error(e, flag=self._handle_tool_errors) + error_message = ToolMessage( content=content, name=call["name"], tool_call_id=call["id"], status="error", ) + return ToolCallResponse(action="continue", result=error_message, exception=e) + # Process successful response if isinstance(response, Command): - return self._validate_tool_command(response, call, input_type) + # Validate Command before returning to handler + validated_command = self._validate_tool_command(response, request.tool_call, input_type) + return ToolCallResponse(action="continue", result=validated_command) if isinstance(response, ToolMessage): response.content = cast("str | list", msg_content_output(response.content)) - return response + return ToolCallResponse(action="continue", result=response) + msg = f"Tool {call['name']} returned unexpected type: {type(response)}" raise TypeError(msg) + async def _arun_one( + self, + call: ToolCall, + input_type: Literal["list", "dict", "tool_calls"], + config: RunnableConfig, + input: list[AnyMessage] | dict[str, Any] | BaseModel, + runtime: Any, + ) -> ToolMessage | Command: + """Run a single tool call asynchronously.""" + 
if invalid_tool_message := self._validate_tool_call(call): + return invalid_tool_message + + tool = self.tools_by_name[call["name"]] + + # Create the tool request + tool_request = ToolCallRequest( + tool_call=call, + tool=tool, + ) + + if self._on_tool_call is None: + tool_response = await self._execute_tool_async(tool_request, input_type, config) + else: + # Generator protocol: handler is sync generator, tool execution is async + gen = self._on_tool_call(tool_request, input, runtime) + + try: + request = next(gen) + except StopIteration: + msg = "on_tool_call handler must yield at least once before returning" + raise ValueError(msg) + + while True: + tool_response = await self._execute_tool_async(request, input_type, config) + try: + request = gen.send(tool_response) + except StopIteration as e: + if e.value is None: + msg = ( + "on_tool_call handler must explicitly return a ToolCallResponse. " + "Ensure your handler ends with 'return response'." + ) + raise ValueError(msg) + tool_response = e.value + break + + # Apply action directive + if tool_response.action == "raise": + if tool_response.exception is None: + msg = "ToolCallResponse with action='raise' must have an exception" + raise ValueError(msg) + raise tool_response.exception + + result = tool_response.result + if result is None: + msg = "ToolCallResponse with action='continue' must have a result" + raise ValueError(msg) + + return result + def _parse_input( self, input: list[AnyMessage] | dict[str, Any] | BaseModel, diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py new file mode 100644 index 0000000000000..0698f5506048a --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_on_tool_call_middleware.py @@ -0,0 +1,396 @@ +"""Unit tests for on_tool_call middleware hook.""" + +from collections.abc import Generator +from typing import Any, Literal, Union 
+import typing + +from pydantic import BaseModel +import pytest +from langchain_core.language_models import LanguageModelInput +from langchain_core.language_models.fake_chat_models import GenericFakeChatModel +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool, tool + +from langchain.agents.middleware.types import AgentMiddleware +from langchain.agents.middleware_agent import create_agent +from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse + + +class FakeModel(GenericFakeChatModel): + """Fake chat model for testing.""" + + tool_style: Literal["openai", "anthropic"] = "openai" + + def bind_tools( + self, + tools: typing.Sequence[Union[dict[str, Any], type[BaseModel], typing.Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + if len(tools) == 0: + msg = "Must provide at least one tool" + raise ValueError(msg) + + tool_dicts = [] + for tool in tools: + if isinstance(tool, dict): + tool_dicts.append(tool) + continue + if not isinstance(tool, BaseTool): + msg = "Only BaseTool and dict is supported by FakeModel.bind_tools" + raise TypeError(msg) + + # NOTE: this is a simplified tool spec for testing purposes only + if self.tool_style == "openai": + tool_dicts.append( + { + "type": "function", + "function": { + "name": tool.name, + }, + } + ) + elif self.tool_style == "anthropic": + tool_dicts.append( + { + "name": tool.name, + } + ) + + return self.bind(tools=tool_dicts) + + +@tool +def add_tool(x: int, y: int) -> int: + """Add two numbers.""" + return x + y + + +@tool +def failing_tool(x: int) -> int: + """Tool that raises an error.""" + msg = "Intentional failure" + raise ValueError(msg) + + +def test_single_middleware_on_tool_call(): + """Test that a single middleware can intercept tool calls.""" + call_log = [] + + class LoggingMiddleware(AgentMiddleware): + """Middleware that logs 
tool calls.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: + call_log.append(f"before_{request.tool.name}") + response = yield request + call_log.append(f"after_{request.tool.name}") + return response + + model = FakeModel( + messages=iter( + [ + AIMessage( + content="", + tool_calls=[{"name": "add_tool", "args": {"x": 2, "y": 3}, "id": "1"}], + ), + AIMessage(content="Done"), + ] + ) + ) + + agent = create_agent( + model=model, + tools=[add_tool], + middleware=[LoggingMiddleware()], + ) + + result = agent.compile().invoke({"messages": [HumanMessage("Add 2 and 3")]}) + + assert "before_add_tool" in call_log + assert "after_add_tool" in call_log + assert call_log.index("before_add_tool") < call_log.index("after_add_tool") + + # Check that tool executed successfully + tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)] + assert len(tool_messages) == 1 + assert tool_messages[0].content == "5" + + +def test_multiple_middleware_chaining(): + """Test that multiple middleware chain correctly (outer wraps inner).""" + call_order = [] + + class OuterMiddleware(AgentMiddleware): + """Outer middleware in the chain.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: + call_order.append("outer_start") + response = yield request + call_order.append("outer_end") + return response + + class InnerMiddleware(AgentMiddleware): + """Inner middleware in the chain.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: + call_order.append("inner_start") + response = yield request + call_order.append("inner_end") + return response + + model = FakeModel( + messages=iter( + [ + AIMessage( + content="", + tool_calls=[{"name": "add_tool", "args": {"x": 1, "y": 1}, "id": "1"}], + ), + 
AIMessage(content="Done"), + ] + ) + ) + + agent = create_agent( + model=model, + tools=[add_tool], + middleware=[OuterMiddleware(), InnerMiddleware()], + ) + + agent.compile().invoke({"messages": [HumanMessage("Add 1 and 1")]}) + + # Verify order: outer_start -> inner_start -> tool -> inner_end -> outer_end + assert call_order == ["outer_start", "inner_start", "inner_end", "outer_end"] + + +def test_middleware_retry_logic(): + """Test that middleware can retry tool calls.""" + attempt_count = 0 + + class RetryMiddleware(AgentMiddleware): + """Middleware that retries on failure.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: + nonlocal attempt_count + max_retries = 2 + + for attempt in range(max_retries): + attempt_count += 1 + response = yield request + + if response.action == "continue": + return response + + if response.action == "raise" and attempt < max_retries - 1: + # Retry + continue + + # Convert error to success message + return ToolCallResponse( + action="continue", + result=ToolMessage( + content=f"Failed after {max_retries} attempts", + name=request.tool_call["name"], + tool_call_id=request.tool_call["id"], + status="error", + ), + ) + + raise AssertionError("Unreachable") + + model = FakeModel( + messages=iter( + [ + AIMessage( + content="", + tool_calls=[{"name": "failing_tool", "args": {"x": 1}, "id": "1"}], + ), + AIMessage(content="Done"), + ] + ) + ) + + agent = create_agent( + model=model, + tools=[failing_tool], + middleware=[RetryMiddleware()], + ) + + result = agent.compile().invoke({"messages": [HumanMessage("Test retry")]}) + + # Should have attempted twice + assert attempt_count == 2 + + # Check that we got an error message + tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)] + assert len(tool_messages) == 1 + assert "Failed after 2 attempts" in tool_messages[0].content + + +def 
test_middleware_request_modification(): + """Test that middleware can modify tool requests.""" + + class RequestModifierMiddleware(AgentMiddleware): + """Middleware that doubles the input.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: + # Modify the arguments + modified_tool_call = { + **request.tool_call, + "args": { + "x": request.tool_call["args"]["x"] * 2, + "y": request.tool_call["args"]["y"] * 2, + }, + } + modified_request = ToolCallRequest( + tool_call=modified_tool_call, + tool=request.tool, + ) + response = yield modified_request + return response + + model = FakeModel( + messages=iter( + [ + AIMessage( + content="", + tool_calls=[{"name": "add_tool", "args": {"x": 1, "y": 2}, "id": "1"}], + ), + AIMessage(content="Done"), + ] + ) + ) + + agent = create_agent( + model=model, + tools=[add_tool], + middleware=[RequestModifierMiddleware()], + ) + + result = agent.compile().invoke({"messages": [HumanMessage("Add 1 and 2")]}) + + # Original: 1 + 2 = 3, Modified: 2 + 4 = 6 + tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)] + assert len(tool_messages) == 1 + assert tool_messages[0].content == "6" + + +def test_multiple_middleware_with_retry(): + """Test complex scenario with multiple middleware and retry logic.""" + call_log = [] + + class MonitoringMiddleware(AgentMiddleware): + """Outer middleware for monitoring.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: + call_log.append("monitoring_start") + response = yield request + call_log.append("monitoring_end") + return response + + class RetryMiddleware(AgentMiddleware): + """Inner middleware for retries.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: + call_log.append("retry_start") + for 
attempt in range(2): + call_log.append(f"retry_attempt_{attempt + 1}") + response = yield request + + if response.action == "continue": + call_log.append("retry_success") + return response + + if attempt == 0: # Retry once + call_log.append("retry_retry") + continue + + call_log.append("retry_failed") + return response + + model = FakeModel( + messages=iter( + [ + AIMessage( + content="", + tool_calls=[{"name": "add_tool", "args": {"x": 5, "y": 7}, "id": "1"}], + ), + AIMessage(content="Done"), + ] + ) + ) + + agent = create_agent( + model=model, + tools=[add_tool], + middleware=[MonitoringMiddleware(), RetryMiddleware()], + ) + + agent.compile().invoke({"messages": [HumanMessage("Add 5 and 7")]}) + + # Verify the call sequence + assert call_log[0] == "monitoring_start" + assert call_log[1] == "retry_start" + assert "retry_attempt_1" in call_log + assert "retry_success" in call_log + assert call_log[-1] == "monitoring_end" + + +def test_mixed_middleware(): + """Test middleware with both before_model and on_tool_call hooks.""" + call_log = [] + + class MixedMiddleware(AgentMiddleware): + """Middleware with multiple hooks.""" + + def before_model(self, state, runtime): + call_log.append("before_model") + return None + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: + call_log.append("on_tool_call_start") + response = yield request + call_log.append("on_tool_call_end") + return response + + model = FakeModel( + messages=iter( + [ + AIMessage( + content="", + tool_calls=[{"name": "add_tool", "args": {"x": 10, "y": 20}, "id": "1"}], + ), + AIMessage(content="Done"), + ] + ) + ) + + agent = create_agent( + model=model, + tools=[add_tool], + middleware=[MixedMiddleware()], + ) + + agent.compile().invoke({"messages": [HumanMessage("Add 10 and 20")]}) + + # Both hooks should have been called + assert "before_model" in call_log + assert "on_tool_call_start" in call_log + assert 
"on_tool_call_end" in call_log + # before_model runs before on_tool_call + assert call_log.index("before_model") < call_log.index("on_tool_call_start") diff --git a/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py new file mode 100644 index 0000000000000..4111f9e9483bc --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py @@ -0,0 +1,383 @@ +"""Tests for on_tool_call handler functionality.""" + +from collections.abc import Generator +from typing import Any + +import pytest +from langchain_core.messages import AIMessage, ToolMessage +from langchain_core.tools import tool + +from langchain.tools import ToolNode +from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse + + +# Test tools +@tool +def success_tool(x: int) -> int: + """A tool that always succeeds.""" + return x * 2 + + +@tool +def error_tool(x: int) -> int: + """A tool that always raises ValueError.""" + msg = f"Error with value: {x}" + raise ValueError(msg) + + +@tool +def rate_limit_tool(x: int) -> int: + """A tool that simulates rate limit errors.""" + if not hasattr(rate_limit_tool, "_call_count"): + rate_limit_tool._call_count = 0 + rate_limit_tool._call_count += 1 + + if rate_limit_tool._call_count < 3: # Fail first 2 times + msg = "Rate limit exceeded" + raise ValueError(msg) + return x * 2 + + +def test_on_tool_call_passthrough() -> None: + """Test that a simple passthrough handler works.""" + + def passthrough_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: + """Simply pass through without modification.""" + response = yield request + return response + + tool_node = ToolNode([success_tool], on_tool_call=passthrough_handler) + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "success_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + + assert 
len(result["messages"]) == 1 + tool_message: ToolMessage = result["messages"][0] + assert tool_message.content == "10" + assert tool_message.status != "error" + + +def test_on_tool_call_retry_success() -> None: + """Test that retry handler can recover from transient errors.""" + # Reset counter + if hasattr(rate_limit_tool, "_call_count"): + rate_limit_tool._call_count = 0 + + def retry_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: + """Retry up to 3 times.""" + max_retries = 3 + + for attempt in range(max_retries): + response = yield request + + if response.action == "continue": + return response + + # Retry on error + if attempt < max_retries - 1: + continue + + # Final attempt failed - convert to error message + return ToolCallResponse( + action="continue", + result=ToolMessage( + content=f"Failed after {max_retries} attempts", + name=request.tool_call["name"], + tool_call_id=request.tool_call["id"], + status="error", + ), + ) + msg = "Unreachable code" + raise AssertionError(msg) + + tool_node = ToolNode([rate_limit_tool], on_tool_call=retry_handler, handle_tool_errors=False) + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "rate_limit_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + + assert len(result["messages"]) == 1 + tool_message: ToolMessage = result["messages"][0] + assert tool_message.content == "10" # Should succeed on 3rd attempt + assert tool_message.status != "error" + + +def test_on_tool_call_convert_error_to_message() -> None: + """Test that handler can convert raised errors to error messages.""" + + def error_to_message_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: + """Convert any error to a user-friendly message.""" + response = yield request + + if response.action == "raise": + return ToolCallResponse( + 
action="continue", + result=ToolMessage( + content=f"Tool failed: {response.exception}", + name=request.tool_call["name"], + tool_call_id=request.tool_call["id"], + status="error", + ), + exception=response.exception, + ) + + return response + + tool_node = ToolNode( + [error_tool], on_tool_call=error_to_message_handler, handle_tool_errors=False + ) + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + + assert len(result["messages"]) == 1 + tool_message: ToolMessage = result["messages"][0] + assert "Tool failed" in tool_message.content + assert "Error with value: 5" in tool_message.content + assert tool_message.status == "error" + + +def test_on_tool_call_let_error_raise() -> None: + """Test that handler can let errors propagate.""" + + def let_raise_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: + """Just return the response as-is, letting errors raise.""" + response = yield request + return response + + tool_node = ToolNode([error_tool], on_tool_call=let_raise_handler, handle_tool_errors=False) + + with pytest.raises(ValueError, match=r"Error with value: 5"): + tool_node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + + +def test_on_tool_call_with_handled_errors() -> None: + """Test interaction between on_tool_call and handle_tool_errors.""" + call_count = {"count": 0} + + def counting_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: + """Count how many times we're called.""" + call_count["count"] += 1 + response = yield request + return response + + # When handle_tool_errors=True, errors are converted to ToolMessages + # so handler sees action="continue" + tool_node = ToolNode([error_tool], 
on_tool_call=counting_handler, handle_tool_errors=True) + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + + assert call_count["count"] == 1 + assert len(result["messages"]) == 1 + tool_message: ToolMessage = result["messages"][0] + assert tool_message.status == "error" + assert "Please fix your mistakes" in tool_message.content + + +def test_on_tool_call_must_return_value() -> None: + """Test that handler must return a ToolCallResponse.""" + + def no_return_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: + """Handler that doesn't return anything.""" + _ = yield request + # Implicit return None + + tool_node = ToolNode([success_tool], on_tool_call=no_return_handler) + + with pytest.raises(ValueError, match=r"must explicitly return a ToolCallResponse"): + tool_node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "success_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + + +def test_on_tool_call_request_modification() -> None: + """Test that handler can modify the request before execution.""" + + def double_input_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: + """Double the input value.""" + # Modify the tool call args + modified_tool_call = { + **request.tool_call, + "args": {**request.tool_call["args"], "x": request.tool_call["args"]["x"] * 2}, + } + modified_request = ToolCallRequest( + tool_call=modified_tool_call, + tool=request.tool, + ) + response = yield modified_request + return response + + tool_node = ToolNode([success_tool], on_tool_call=double_input_handler) + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "success_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + + assert len(result["messages"]) == 1 + 
tool_message: ToolMessage = result["messages"][0] + # Input was 5, doubled to 10, then tool multiplies by 2 = 20 + assert tool_message.content == "20" + + +def test_on_tool_call_response_validation() -> None: + """Test that ToolCallResponse validates action and required fields.""" + # Test action="continue" requires result + with pytest.raises(ValueError, match=r"action='continue' requires a result"): + ToolCallResponse(action="continue") + + # Test action="raise" requires exception + with pytest.raises(ValueError, match=r"action='raise' requires an exception"): + ToolCallResponse(action="raise") + + # Valid responses should work + ToolCallResponse( + action="continue", + result=ToolMessage(content="test", tool_call_id="1", name="test"), + ) + ToolCallResponse(action="raise", exception=ValueError("test")) + + +def test_on_tool_call_without_handler_backward_compat() -> None: + """Test that tools work without on_tool_call handler (backward compatibility).""" + # Success case + tool_node = ToolNode([success_tool]) + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "success_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + assert result["messages"][0].content == "10" + + # Error case with handle_tool_errors=False + tool_node_error = ToolNode([error_tool], handle_tool_errors=False) + with pytest.raises(ValueError, match=r"Error with value: 5"): + tool_node_error.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + + # Error case with handle_tool_errors=True + tool_node_handled = ToolNode([error_tool], handle_tool_errors=True) + result = tool_node_handled.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + assert result["messages"][0].status == "error" + + +def test_on_tool_call_multiple_yields() -> None: + """Test that handler can yield multiple times for retries.""" 
+ attempts = {"count": 0} + + def multi_yield_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: + """Yield multiple times to track attempts.""" + max_attempts = 3 + + for _ in range(max_attempts): + attempts["count"] += 1 + response = yield request + + if response.action == "continue": + return response + + # All attempts failed + return response + + tool_node = ToolNode([error_tool], on_tool_call=multi_yield_handler, handle_tool_errors=False) + + with pytest.raises(ValueError, match=r"Error with value: 5"): + tool_node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}], + ) + ] + } + ) + + assert attempts["count"] == 3