diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index 5915b75164a74..df39772f6ee60 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -12,7 +12,9 @@ AgentMiddleware, AgentState, ModelRequest, + after_agent, after_model, + before_agent, before_model, dynamic_prompt, hook_config, @@ -33,7 +35,9 @@ "PlanningMiddleware", "SummarizationMiddleware", "ToolCallLimitMiddleware", + "after_agent", "after_model", + "before_agent", "before_model", "dynamic_prompt", "hook_config", diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index 6aea90ab9fc35..bf639e025e326 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -4,6 +4,7 @@ from collections.abc import Callable from dataclasses import dataclass, field +from enum import Enum from inspect import iscoroutinefunction from typing import ( TYPE_CHECKING, @@ -41,11 +42,18 @@ "AgentMiddleware", "AgentState", "ContextT", + "HookImplementation", + "MiddlewareHookInfo", "ModelRequest", "OmitFromSchema", "PublicAgentState", + "after_agent", + "after_model", + "before_agent", + "before_model", "dynamic_prompt", "hook_config", + "modify_model_request", ] JumpTo = Literal["tools", "model", "end"] @@ -54,6 +62,15 @@ ResponseT = TypeVar("ResponseT") +class HookImplementation(str, Enum): + """Tracks which implementation variants exist for a middleware hook.""" + + NONE = "none" + SYNC_ONLY = "sync" + ASYNC_ONLY = "async" + BOTH = "both" + + @dataclass class ModelRequest: """Model request information for the agent.""" @@ -93,7 +110,7 @@ class AgentState(TypedDict, Generic[ResponseT]): messages: Required[Annotated[list[AnyMessage], add_messages]] jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]] 
- structured_response: NotRequired[ResponseT] + structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]] thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]] run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]] @@ -112,6 +129,56 @@ class PublicAgentState(TypedDict, Generic[ResponseT]): StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True) +@dataclass +class MiddlewareHookInfo: + """Information about a specific middleware hook implementation. + + This class encapsulates metadata about how a middleware implements a particular hook, + including the actual hook functions and jump configuration. + """ + + middleware_name: str + """The name of the middleware that implements this hook.""" + + hook_name: str + """The name of the hook (e.g., 'before_model', 'after_agent').""" + + sync_fn: Callable[..., Any] | None + """The synchronous hook function, or None if not implemented.""" + + async_fn: Callable[..., Any] | None + """The asynchronous hook function, or None if not implemented.""" + + can_jump_to: list[JumpTo] + """Valid jump destinations for this hook.""" + + @property + def node_name(self) -> str: + """The graph node name for this hook.""" + return f"{self.middleware_name}.{self.hook_name}" + + @property + def has_sync(self) -> bool: + """Whether this hook has a sync implementation.""" + return self.sync_fn is not None + + @property + def has_async(self) -> bool: + """Whether this hook has an async implementation.""" + return self.async_fn is not None + + @property + def implementation(self) -> HookImplementation: + """Which variants (sync/async/both) are implemented.""" + if self.has_sync and self.has_async: + return HookImplementation.BOTH + if self.has_sync: + return HookImplementation.SYNC_ONLY + if self.has_async: + return HookImplementation.ASYNC_ONLY + return HookImplementation.NONE + + class AgentMiddleware(Generic[StateT, ContextT]): """Base middleware class for an agent. 
@@ -133,6 +200,14 @@ def name(self) -> str: """ return self.__class__.__name__ + def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None: + """Logic to run before the agent execution starts.""" + + async def abefore_agent( + self, state: StateT, runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: + """Async logic to run before the agent execution starts.""" + def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None: """Logic to run before the model is called.""" @@ -215,6 +290,102 @@ async def aretry_model_request( None, self.retry_model_request, error, request, state, runtime, attempt ) + def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None: + """Logic to run after the agent execution completes.""" + + async def aafter_agent( + self, state: StateT, runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: + """Async logic to run after the agent execution completes.""" + + def hook_info(self, hook_name: str) -> MiddlewareHookInfo | None: + """Get information about this middleware's implementation of a specific hook. + + Args: + hook_name: The name of the hook to inspect (e.g., 'before_model', 'after_agent'). + + Returns: + MiddlewareHookInfo if the hook is implemented, None otherwise. + + Example: + >>> middleware = MyMiddleware() + >>> info = middleware.hook_info("before_model") + >>> if info: + ... 
print(f"Has sync: {info.has_sync}, Has async: {info.has_async}") + """ + base_class = AgentMiddleware + middleware_class = self.__class__ + + # Check sync and async variants + sync_name = hook_name + async_name = f"a{hook_name}" + + base_sync_method = getattr(base_class, sync_name, None) + base_async_method = getattr(base_class, async_name, None) + + middleware_sync_method = getattr(middleware_class, sync_name, None) + middleware_async_method = getattr(middleware_class, async_name, None) + + has_custom_sync = middleware_sync_method is not base_sync_method + has_custom_async = middleware_async_method is not base_async_method + + if not has_custom_sync and not has_custom_async: + return None + + # Get the actual bound methods - only include customized implementations + sync_fn = getattr(self, sync_name) if has_custom_sync else None + async_fn = getattr(self, async_name) if has_custom_async else None + + # Get can_jump_to from either sync or async variant + can_jump_to: list[JumpTo] = [] + if has_custom_sync: + can_jump_to = getattr(middleware_sync_method, "__can_jump_to__", []) + elif has_custom_async: + can_jump_to = getattr(middleware_async_method, "__can_jump_to__", []) + + return MiddlewareHookInfo( + middleware_name=self.name, + hook_name=hook_name, + sync_fn=sync_fn, + async_fn=async_fn, + can_jump_to=can_jump_to, + ) + + def all_hook_info(self) -> dict[str, MiddlewareHookInfo]: + """Get information about all hooks implemented by this middleware. + + Returns: + Dictionary mapping hook names to their MiddlewareHookInfo. + + Example: + >>> middleware = MyMiddleware() + >>> for hook_name, info in middleware.all_hook_info().items(): + ... 
print(f"{hook_name}: sync={info.has_sync}, async={info.has_async}") + """ + hook_names = [ + "before_agent", + "before_model", + "modify_model_request", + "after_model", + "after_agent", + "retry_model_request", + ] + return {name: info for name in hook_names if (info := self.hook_info(name)) is not None} + + @property + def implemented_hooks(self) -> list[str]: + """List of hook names this middleware implements. + + Returns: + List of hook names that are overridden from the base class. + + Example: + >>> middleware = MyMiddleware() + >>> print(middleware.implemented_hooks) + ['before_model', 'after_model'] + """ + return list(self.all_hook_info().keys()) + class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]): """Callable with AgentState and Runtime as arguments.""" @@ -707,6 +878,279 @@ def wrapped( return decorator +@overload +def before_agent( + func: _CallableWithStateAndRuntime[StateT, ContextT], +) -> AgentMiddleware[StateT, ContextT]: ... + + +@overload +def before_agent( + func: None = None, + *, + state_schema: type[StateT] | None = None, + tools: list[BaseTool] | None = None, + can_jump_to: list[JumpTo] | None = None, + name: str | None = None, +) -> Callable[ + [_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT] +]: ... + + +def before_agent( + func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None, + *, + state_schema: type[StateT] | None = None, + tools: list[BaseTool] | None = None, + can_jump_to: list[JumpTo] | None = None, + name: str | None = None, +) -> ( + Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]] + | AgentMiddleware[StateT, ContextT] +): + """Decorator used to dynamically create a middleware with the before_agent hook. + + Args: + func: The function to be decorated. Must accept: + `state: StateT, runtime: Runtime[ContextT]` - State and runtime context + state_schema: Optional custom state schema type. 
If not provided, uses the default + AgentState schema. + tools: Optional list of additional tools to register with this middleware. + can_jump_to: Optional list of valid jump destinations for conditional edges. + Valid values are: "tools", "model", "end" + name: Optional name for the generated middleware class. If not provided, + uses the decorated function's name. + + Returns: + Either an AgentMiddleware instance (if func is provided directly) or a decorator function + that can be applied to a function it is wrapping. + + The decorated function should return: + - `dict[str, Any]` - State updates to merge into the agent state + - `Command` - A command to control flow (e.g., jump to different node) + - `None` - No state updates or flow control + + Examples: + Basic usage: + ```python + @before_agent + def log_before_agent(state: AgentState, runtime: Runtime) -> None: + print(f"Starting agent with {len(state['messages'])} messages") + ``` + + With conditional jumping: + ```python + @before_agent(can_jump_to=["end"]) + def conditional_before_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + if some_condition(state): + return {"jump_to": "end"} + return None + ``` + + With custom state schema: + ```python + @before_agent(state_schema=MyCustomState) + def custom_before_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]: + return {"custom_field": "initialized_value"} + ``` + """ + + def decorator( + func: _CallableWithStateAndRuntime[StateT, ContextT], + ) -> AgentMiddleware[StateT, ContextT]: + is_async = iscoroutinefunction(func) + + func_can_jump_to = ( + can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", []) + ) + + if is_async: + + async def async_wrapped( + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 + state: StateT, + runtime: Runtime[ContextT], + ) -> dict[str, Any] | Command | None: + return await func(state, runtime) # type: ignore[misc] + + # Preserve can_jump_to metadata on the wrapped
function + if func_can_jump_to: + async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined] + + middleware_name = name or cast( + "str", getattr(func, "__name__", "BeforeAgentMiddleware") + ) + + return type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "abefore_agent": async_wrapped, + }, + )() + + def wrapped( + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 + state: StateT, + runtime: Runtime[ContextT], + ) -> dict[str, Any] | Command | None: + return func(state, runtime) # type: ignore[return-value] + + # Preserve can_jump_to metadata on the wrapped function + if func_can_jump_to: + wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined] + + # Use function name as default if no name provided + middleware_name = name or cast("str", getattr(func, "__name__", "BeforeAgentMiddleware")) + + return type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "before_agent": wrapped, + }, + )() + + if func is not None: + return decorator(func) + return decorator + + +@overload +def after_agent( + func: _CallableWithStateAndRuntime[StateT, ContextT], +) -> AgentMiddleware[StateT, ContextT]: ... + + +@overload +def after_agent( + func: None = None, + *, + state_schema: type[StateT] | None = None, + tools: list[BaseTool] | None = None, + can_jump_to: list[JumpTo] | None = None, + name: str | None = None, +) -> Callable[ + [_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT] +]: ... 
+ + +def after_agent( + func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None, + *, + state_schema: type[StateT] | None = None, + tools: list[BaseTool] | None = None, + can_jump_to: list[JumpTo] | None = None, + name: str | None = None, +) -> ( + Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]] + | AgentMiddleware[StateT, ContextT] +): + """Decorator used to dynamically create a middleware with the after_agent hook. + + Args: + func: The function to be decorated. Must accept: + `state: StateT, runtime: Runtime[ContextT]` - State and runtime context + state_schema: Optional custom state schema type. If not provided, uses the default + AgentState schema. + tools: Optional list of additional tools to register with this middleware. + can_jump_to: Optional list of valid jump destinations for conditional edges. + Valid values are: "tools", "model", "end" + name: Optional name for the generated middleware class. If not provided, + uses the decorated function's name. + + Returns: + Either an AgentMiddleware instance (if func is provided) or a decorator function + that can be applied to a function. 
+ + The decorated function should return: + - `dict[str, Any]` - State updates to merge into the agent state + - `Command` - A command to control flow (e.g., jump to different node) + - `None` - No state updates or flow control + + Examples: + Basic usage for logging agent completion: + ```python + @after_agent + def log_completion(state: AgentState, runtime: Runtime) -> None: + print(f"Agent completed with {len(state['messages'])} messages") + ``` + + With custom state schema: + ```python + @after_agent(state_schema=MyCustomState, name="MyAfterAgentMiddleware") + def custom_after_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]: + return {"custom_field": "finalized_value"} + ``` + """ + + def decorator( + func: _CallableWithStateAndRuntime[StateT, ContextT], + ) -> AgentMiddleware[StateT, ContextT]: + is_async = iscoroutinefunction(func) + # Extract can_jump_to from decorator parameter or from function metadata + func_can_jump_to = ( + can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", []) + ) + + if is_async: + + async def async_wrapped( + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 + state: StateT, + runtime: Runtime[ContextT], + ) -> dict[str, Any] | Command | None: + return await func(state, runtime) # type: ignore[misc] + + # Preserve can_jump_to metadata on the wrapped function + if func_can_jump_to: + async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined] + + middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware")) + + return type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "aafter_agent": async_wrapped, + }, + )() + + def wrapped( + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 + state: StateT, + runtime: Runtime[ContextT], + ) -> dict[str, Any] | Command | None: + return func(state, runtime) # type: ignore[return-value] + + # Preserve can_jump_to metadata on the 
wrapped function + if func_can_jump_to: + wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined] + + # Use function name as default if no name provided + middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware")) + + return type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "after_agent": wrapped, + }, + )() + + if func is not None: + return decorator(func) + return decorator + + @overload def dynamic_prompt( func: _CallableReturningPromptString[StateT, ContextT], diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index b4f6370446b05..03b29fb76563f 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -1,13 +1,14 @@ """Middleware agent implementation.""" -import itertools from collections.abc import Callable, Sequence -from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints +from dataclasses import dataclass +from typing import Annotated, Any, Generic, cast, get_args, get_origin, get_type_hints from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage -from langchain_core.runnables import Runnable +from langchain_core.runnables import Runnable, run_in_executor from langchain_core.tools import BaseTool +from langgraph._internal._runnable import RunnableCallable from langgraph.constants import END, START from langgraph.graph.state import StateGraph from langgraph.runtime import Runtime @@ -19,6 +20,7 @@ AgentMiddleware, AgentState, JumpTo, + MiddlewareHookInfo, ModelRequest, OmitFromSchema, PublicAgentState, @@ -41,6 +43,167 @@ ResponseT = TypeVar("ResponseT") +# ============================================================================ +# Data Structures for Agent Graph 
Construction +# ============================================================================ + + +@dataclass +class MiddlewareHooks: + """Middleware hooks categorized by type for graph construction.""" + + before_agent: list[MiddlewareHookInfo] + """Hooks that run once before the agent starts.""" + + before_model: list[MiddlewareHookInfo] + """Hooks that run before each model call in the agent loop.""" + + modify_model_request: list[MiddlewareHookInfo] + """Hooks that modify the model request before calling the model.""" + + after_model: list[MiddlewareHookInfo] + """Hooks that run after each model call in the agent loop.""" + + after_agent: list[MiddlewareHookInfo] + """Hooks that run once after the agent completes.""" + + retry: list[MiddlewareHookInfo] + """Hooks that handle model invocation errors and optionally retry.""" + + @classmethod + def from_middleware_list( + cls, + middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]], + ) -> "MiddlewareHooks": + """Extract and categorize all hooks from middleware instances. + + Args: + middleware: Sequence of middleware instances to analyze. + + Returns: + MiddlewareHooks with all hooks organized by type. 
+ """ + hooks_by_type: dict[str, list[MiddlewareHookInfo]] = { + "before_agent": [], + "before_model": [], + "modify_model_request": [], + "after_model": [], + "after_agent": [], + "retry": [], + } + + # Map hook names to their category + hook_name_mapping = { + "before_agent": "before_agent", + "before_model": "before_model", + "modify_model_request": "modify_model_request", + "after_model": "after_model", + "after_agent": "after_agent", + "retry_model_request": "retry", + } + + for m in middleware: + for hook_name, category in hook_name_mapping.items(): + if hook_info := m.hook_info(hook_name): + hooks_by_type[category].append(hook_info) + + return cls( + before_agent=hooks_by_type["before_agent"], + before_model=hooks_by_type["before_model"], + modify_model_request=hooks_by_type["modify_model_request"], + after_model=hooks_by_type["after_model"], + after_agent=hooks_by_type["after_agent"], + retry=hooks_by_type["retry"], + ) + + +@dataclass +class AgentComponents: + """Core components and configuration for agent construction.""" + + model: BaseChatModel + """The language model to use for the agent.""" + + tool_node: ToolNode | None + """The tool execution node, or None if no tools are available.""" + + middleware_hooks: MiddlewareHooks + """Middleware hooks organized by type.""" + + structured_output_tools: dict[str, OutputToolBinding] + """Tools used for structured output parsing.""" + + default_tools: list[BaseTool | dict] + """Default tools available to the agent (regular tools + middleware tools + built-ins).""" + + initial_response_format: ResponseFormat | None + """The initial response format configuration.""" + + system_prompt: str | None + """The system prompt for the agent.""" + + +@dataclass +class GraphTopology: + """Key nodes in the graph topology defining the execution flow. 
+ + The agent graph has the following structure: + START -> entry_node -> [loop: loop_entry_node -> model -> loop_exit_node -> tools] + -> exit_node -> END + + - entry_node: Runs once at the start (before_agent hooks) + - loop_entry_node: Beginning of agent loop (before_model hooks) + - loop_exit_node: End of each loop iteration (after_model hooks) + - exit_node: Runs once at the end (after_agent hooks) or END + """ + + entry_node: str + """The first node executed (START -> entry_node).""" + + loop_entry_node: str + """Where the agent loop begins (where tools loop back to).""" + + loop_exit_node: str + """The last node in each loop iteration.""" + + exit_node: str + """The final node before END (or END itself).""" + + @classmethod + def compute(cls, hooks: MiddlewareHooks) -> "GraphTopology": + """Compute graph topology from middleware hook configuration. + + Args: + hooks: The categorized middleware hooks. + + Returns: + GraphTopology describing the flow through the graph. + """ + # Entry node (runs once at start): before_agent -> before_model -> model_request + if hooks.before_agent: + entry_node = hooks.before_agent[0].node_name + elif hooks.before_model: + entry_node = hooks.before_model[0].node_name + else: + entry_node = "model_request" + + # Loop entry node (beginning of agent loop, excludes before_agent) + loop_entry_node = hooks.before_model[0].node_name if hooks.before_model else "model_request" + + # Loop exit node (end of each iteration, excludes after_agent) + loop_exit_node = hooks.after_model[0].node_name if hooks.after_model else "model_request" + + # Exit node (runs once at end): after_agent or END + exit_node = hooks.after_agent[-1].node_name if hooks.after_agent else END + + return cls( + entry_node=entry_node, + loop_entry_node=loop_entry_node, + loop_exit_node=loop_exit_node, + exit_node=exit_node, + ) + + def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type: """Resolve schema by merging schemas and 
optionally respecting OmitFromSchema annotations. @@ -86,39 +249,112 @@ def _extract_metadata(type_: type) -> list: return [] -def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> list[JumpTo]: - """Get the can_jump_to list from either sync or async hook methods. +# ============================================================================ +# Setup and Initialization Functions +# ============================================================================ + + +def _setup_components( + model: str | BaseChatModel, + tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None, + middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]], + response_format: ResponseFormat[ResponseT] | type[ResponseT] | None, + system_prompt: str | None, +) -> AgentComponents: + """Setup and validate agent components. Args: - middleware: The middleware instance to inspect. - hook_name: The name of the hook ('before_model' or 'after_model'). + model: Model name or instance. + tools: Tools for the agent. + middleware: Middleware instances. + response_format: Response format configuration. + system_prompt: System prompt for the agent. Returns: - List of jump destinations, or empty list if not configured. + AgentComponents with all components configured and validated. 
""" - # Get the base class method for comparison - base_sync_method = getattr(AgentMiddleware, hook_name, None) - base_async_method = getattr(AgentMiddleware, f"a{hook_name}", None) - - # Try sync method first - only if it's overridden from base class - sync_method = getattr(middleware.__class__, hook_name, None) - if ( - sync_method - and sync_method is not base_sync_method - and hasattr(sync_method, "__can_jump_to__") - ): - return sync_method.__can_jump_to__ - - # Try async method - only if it's overridden from base class - async_method = getattr(middleware.__class__, f"a{hook_name}", None) - if ( - async_method - and async_method is not base_async_method - and hasattr(async_method, "__can_jump_to__") - ): - return async_method.__can_jump_to__ + # Initialize chat model + if isinstance(model, str): + model = init_chat_model(model) - return [] + # Handle tools being None or empty + if tools is None: + tools = [] + + # Convert response format and setup structured output tools + initial_response_format: ToolStrategy | ProviderStrategy | AutoStrategy | None + if response_format is None: + initial_response_format = None + elif isinstance(response_format, (ToolStrategy, ProviderStrategy, AutoStrategy)): + initial_response_format = response_format + else: + # Raw schema - wrap in AutoStrategy to enable auto-detection + initial_response_format = AutoStrategy(schema=response_format) + + # For AutoStrategy, convert to ToolStrategy to setup tools upfront + tool_strategy_for_setup: ToolStrategy | None = None + if isinstance(initial_response_format, AutoStrategy): + tool_strategy_for_setup = ToolStrategy(schema=initial_response_format.schema) + elif isinstance(initial_response_format, ToolStrategy): + tool_strategy_for_setup = initial_response_format + + structured_output_tools: dict[str, OutputToolBinding] = {} + if tool_strategy_for_setup: + for response_schema in tool_strategy_for_setup.schema_specs: + structured_tool_info = 
OutputToolBinding.from_schema_spec(response_schema) + structured_output_tools[structured_tool_info.tool.name] = structured_tool_info + + middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])] + + # Setup tools + tool_node: ToolNode | None = None + default_tools: list[BaseTool | dict[str, Any]] + + if isinstance(tools, list): + # Extract built-in provider tools (dict format) and regular tools (BaseTool) + built_in_tools = [t for t in tools if isinstance(t, dict)] + regular_tools = [t for t in tools if not isinstance(t, dict)] + + # Tools that require client-side execution + available_tools = middleware_tools + regular_tools + + # Only create ToolNode if we have client-side tools + tool_node = ToolNode(tools=available_tools) if available_tools else None + + # Default tools for ModelRequest initialization + default_tools = regular_tools + middleware_tools + built_in_tools + elif isinstance(tools, ToolNode): + tool_node = tools + if tool_node: + # Add middleware tools to existing ToolNode + available_tools = list(tool_node.tools_by_name.values()) + middleware_tools + tool_node = ToolNode(available_tools) + + # default_tools includes all client-side tools + default_tools = available_tools + else: + default_tools = middleware_tools + else: + # No tools provided, only middleware_tools available + default_tools = middleware_tools + + # Validate middleware + assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101 + "Please remove duplicate middleware instances." 
+ ) + + # Categorize middleware by hooks + middleware_hooks = MiddlewareHooks.from_middleware_list(middleware) + + return AgentComponents( + model=model, + tool_node=tool_node, + middleware_hooks=middleware_hooks, + structured_output_tools=structured_output_tools, + default_tools=default_tools, + initial_response_format=initial_response_format, + system_prompt=system_prompt, + ) def _supports_provider_strategy(model: str | BaseChatModel) -> bool: @@ -144,6 +380,58 @@ def _supports_provider_strategy(model: str | BaseChatModel) -> bool: ) +# ============================================================================ +# Node Building Functions +# ============================================================================ + + +def _create_hook_node(hook_info: MiddlewareHookInfo) -> RunnableCallable: + """Create a graph node for a middleware hook. + + Args: + hook_info: Information about the hook to create a node for. + + Returns: + RunnableCallable that supports both sync and async execution. + """ + return RunnableCallable(hook_info.sync_fn, hook_info.async_fn, trace=False) + + +def _add_middleware_nodes( + graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState], + components: AgentComponents, + state_schema: type, +) -> None: + """Add all middleware hook nodes to the graph. + + Args: + graph: The state graph to add nodes to. + components: Agent components with middleware hooks. + state_schema: The state schema for input validation. 
+ """ + hooks = components.middleware_hooks + + # Add before_agent nodes + for hook_info in hooks.before_agent: + node = _create_hook_node(hook_info) + graph.add_node(hook_info.node_name, node, input_schema=state_schema) + + # Add before_model nodes + for hook_info in hooks.before_model: + node = _create_hook_node(hook_info) + graph.add_node(hook_info.node_name, node, input_schema=state_schema) + + # Add after_model nodes + for hook_info in hooks.after_model: + node = _create_hook_node(hook_info) + graph.add_node(hook_info.node_name, node, input_schema=state_schema) + + # Add after_agent nodes + for hook_info in hooks.after_agent: + node = _create_hook_node(hook_info) + graph.add_node(hook_info.node_name, node, input_schema=state_schema) + + def _handle_structured_output_error( exception: Exception, response_format: ResponseFormat, @@ -174,6 +462,157 @@ def _handle_structured_output_error( return False, "" +# ============================================================================ +# Edge Building Functions +# ============================================================================ + + +def _connect_entry_edges( + graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState], + topology: GraphTopology, +) -> None: + """Connect the entry edge from START to the entry node. + + Args: + graph: The state graph to add edges to. + topology: Graph topology configuration. + """ + graph.add_edge(START, topology.entry_node) + + +def _connect_loop_edges( + graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState], + topology: GraphTopology, + components: AgentComponents, +) -> None: + """Connect conditional edges for the agent loop (tools <-> model). + + Args: + graph: The state graph to add edges to. + topology: Graph topology configuration. + components: Agent components with tool configuration. 
+ """ + tool_node = components.tool_node + structured_output_tools = components.structured_output_tools + + if tool_node is None: + # No tools - connect loop_exit directly to exit_node + if topology.loop_exit_node == "model_request": + graph.add_edge(topology.loop_exit_node, topology.exit_node) + else: + # We have after_model but no tools + _add_middleware_edge( + graph, + topology.loop_exit_node, + topology.exit_node, + topology.loop_entry_node, + can_jump_to=components.middleware_hooks.after_model[0].can_jump_to, + ) + return + + # Add conditional edge from tools back to model or exit + graph.add_conditional_edges( + "tools", + _make_tools_to_model_edge( + tool_node, topology.loop_entry_node, structured_output_tools, topology.exit_node + ), + [topology.loop_entry_node, topology.exit_node], + ) + + # Add conditional edge from model to tools or exit + graph.add_conditional_edges( + topology.loop_exit_node, + _make_model_to_tools_edge( + topology.loop_entry_node, structured_output_tools, tool_node, topology.exit_node + ), + [topology.loop_entry_node, "tools", topology.exit_node], + ) + + +def _connect_middleware_chains( + graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState], + components: AgentComponents, + topology: GraphTopology, +) -> None: + """Connect middleware hooks in chains. + + Args: + graph: The state graph to add edges to. + components: Agent components with middleware hooks. + topology: Graph topology configuration. 
+ """ + hooks = components.middleware_hooks + + # Connect before_agent chain + if hooks.before_agent: + for i in range(len(hooks.before_agent) - 1): + _add_middleware_edge( + graph, + hooks.before_agent[i].node_name, + hooks.before_agent[i + 1].node_name, + topology.loop_entry_node, + can_jump_to=hooks.before_agent[i].can_jump_to, + ) + # Connect last before_agent to loop_entry_node + _add_middleware_edge( + graph, + hooks.before_agent[-1].node_name, + topology.loop_entry_node, + topology.loop_entry_node, + can_jump_to=hooks.before_agent[-1].can_jump_to, + ) + + # Connect before_model chain + if hooks.before_model: + for i in range(len(hooks.before_model) - 1): + _add_middleware_edge( + graph, + hooks.before_model[i].node_name, + hooks.before_model[i + 1].node_name, + topology.loop_entry_node, + can_jump_to=hooks.before_model[i].can_jump_to, + ) + # Connect last before_model to model_request + _add_middleware_edge( + graph, + hooks.before_model[-1].node_name, + "model_request", + topology.loop_entry_node, + can_jump_to=hooks.before_model[-1].can_jump_to, + ) + + # Connect after_model chain (reverse order) + if hooks.after_model: + graph.add_edge("model_request", hooks.after_model[-1].node_name) + for i in range(len(hooks.after_model) - 1, 0, -1): + _add_middleware_edge( + graph, + hooks.after_model[i].node_name, + hooks.after_model[i - 1].node_name, + topology.loop_entry_node, + can_jump_to=hooks.after_model[i].can_jump_to, + ) + + # Connect after_agent chain (reverse order) + if hooks.after_agent: + for i in range(len(hooks.after_agent) - 1, 0, -1): + _add_middleware_edge( + graph, + hooks.after_agent[i].node_name, + hooks.after_agent[i - 1].node_name, + topology.loop_entry_node, + can_jump_to=hooks.after_agent[i].can_jump_to, + ) + # Connect first after_agent to END + _add_middleware_edge( + graph, + hooks.after_agent[0].node_name, + END, + topology.loop_entry_node, + can_jump_to=hooks.after_agent[0].can_jump_to, + ) + + def create_agent( # noqa: PLR0915 *, 
model: str | BaseChatModel, @@ -185,106 +624,23 @@ def create_agent( # noqa: PLR0915 ) -> StateGraph[ AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT] ]: - """Create a middleware agent graph.""" - # init chat model - if isinstance(model, str): - model = init_chat_model(model) - - # Handle tools being None or empty - if tools is None: - tools = [] - - # Convert response format and setup structured output tools - # Raw schemas are wrapped in AutoStrategy to preserve auto-detection intent. - # AutoStrategy is converted to ToolStrategy upfront to calculate tools during agent creation, - # but may be replaced with ProviderStrategy later based on model capabilities. - initial_response_format: ToolStrategy | ProviderStrategy | AutoStrategy | None - if response_format is None: - initial_response_format = None - elif isinstance(response_format, (ToolStrategy, ProviderStrategy)): - # Preserve explicitly requested strategies - initial_response_format = response_format - elif isinstance(response_format, AutoStrategy): - # AutoStrategy provided - preserve it for later auto-detection - initial_response_format = response_format - else: - # Raw schema - wrap in AutoStrategy to enable auto-detection - initial_response_format = AutoStrategy(schema=response_format) - - # For AutoStrategy, convert to ToolStrategy to setup tools upfront - # (may be replaced with ProviderStrategy later based on model) - tool_strategy_for_setup: ToolStrategy | None = None - if isinstance(initial_response_format, AutoStrategy): - tool_strategy_for_setup = ToolStrategy(schema=initial_response_format.schema) - elif isinstance(initial_response_format, ToolStrategy): - tool_strategy_for_setup = initial_response_format + """Create a middleware agent graph. 
- structured_output_tools: dict[str, OutputToolBinding] = {} - if tool_strategy_for_setup: - for response_schema in tool_strategy_for_setup.schema_specs: - structured_tool_info = OutputToolBinding.from_schema_spec(response_schema) - structured_output_tools[structured_tool_info.tool.name] = structured_tool_info - middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])] - - # Setup tools - tool_node: ToolNode | None = None - if isinstance(tools, list): - # Extract built-in provider tools (dict format) and regular tools (BaseTool) - built_in_tools = [t for t in tools if isinstance(t, dict)] - regular_tools = [t for t in tools if not isinstance(t, dict)] - - # Tools that require client-side execution (must be in ToolNode) - available_tools = middleware_tools + regular_tools - - # Only create ToolNode if we have client-side tools - tool_node = ToolNode(tools=available_tools) if available_tools else None - - # Default tools for ModelRequest initialization - # Include built-ins and regular tools (can be changed dynamically by middleware) - # Structured tools are NOT included - they're added dynamically based on response_format - default_tools = regular_tools + middleware_tools + built_in_tools - elif isinstance(tools, ToolNode): - tool_node = tools - if tool_node: - # Add middleware tools to existing ToolNode - available_tools = list(tool_node.tools_by_name.values()) + middleware_tools - tool_node = ToolNode(available_tools) - - # default_tools includes all client-side tools (no built-ins or structured tools) - default_tools = available_tools - else: - # No tools provided, only middleware_tools available - default_tools = middleware_tools + Args: + model: Model name or BaseChatModel instance. + tools: Tools for the agent to use. + system_prompt: System prompt for the agent. + middleware: Middleware instances to customize agent behavior. + response_format: Response format configuration for structured outputs. 
+ context_schema: Context schema for the graph runtime. - # validate middleware - assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101 - "Please remove duplicate middleware instances." - ) - middleware_w_before = [ - m - for m in middleware - if m.__class__.before_model is not AgentMiddleware.before_model - or m.__class__.abefore_model is not AgentMiddleware.abefore_model - ] - middleware_w_modify_model_request = [ - m - for m in middleware - if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request - or m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request - ] - middleware_w_after = [ - m - for m in middleware - if m.__class__.after_model is not AgentMiddleware.after_model - or m.__class__.aafter_model is not AgentMiddleware.aafter_model - ] - middleware_w_retry = [ - m - for m in middleware - if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request - or m.__class__.aretry_model_request is not AgentMiddleware.aretry_model_request - ] + Returns: + StateGraph configured with all nodes and edges. 
+ """ + # Phase 1: Setup and validate components + components = _setup_components(model, tools, middleware, response_format, system_prompt) + # Phase 2: Create schemas state_schemas = {m.state_schema for m in middleware} state_schemas.add(AgentState) @@ -292,7 +648,7 @@ def create_agent( # noqa: PLR0915 input_schema = _resolve_schema(state_schemas, "InputSchema", "input") output_schema = _resolve_schema(state_schemas, "OutputSchema", "output") - # create graph, add nodes + # Phase 3: Create graph graph: StateGraph[ AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT] ] = StateGraph( @@ -302,6 +658,15 @@ def create_agent( # noqa: PLR0915 context_schema=context_schema, ) + # Phase 4: Define model request handlers (need access to components via closure) + # These are inner functions because they need access to components + structured_output_tools = components.structured_output_tools + default_tools = components.default_tools + initial_response_format = components.initial_response_format + model_instance = components.model + middleware_w_modify_model_request = components.middleware_hooks.modify_model_request + middleware_w_retry = components.middleware_hooks.retry + def _handle_model_output( output: AIMessage, effective_response_format: ResponseFormat | None ) -> dict[str, Any]: @@ -511,29 +876,28 @@ def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]: """Sync model request handler with sequential middleware processing.""" request = ModelRequest( - model=model, + model=model_instance, tools=default_tools, - system_prompt=system_prompt, + system_prompt=components.system_prompt, response_format=initial_response_format, messages=state["messages"], tool_choice=None, ) # Apply modify_model_request middleware in sequence - for m in middleware_w_modify_model_request: - if m.__class__.modify_model_request is not 
AgentMiddleware.modify_model_request: - m.modify_model_request(request, state, runtime) + for hook_info in middleware_w_modify_model_request: + if hook_info.sync_fn: + hook_info.sync_fn(request, state, runtime) else: msg = ( f"No synchronous function provided for " - f'{m.__class__.__name__}.amodify_model_request".' + f"{hook_info.middleware_name}.amodify_model_request" "\nEither initialize with a synchronous function or invoke" " via the async API (ainvoke, astream, etc.)" ) raise TypeError(msg) # Retry loop for model invocation with error handling - # Hard limit of 100 attempts to prevent infinite loops from buggy middleware max_attempts = 100 for attempt in range(1, max_attempts + 1): try: @@ -551,18 +915,17 @@ def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, An } except Exception as error: # Try retry_model_request on each middleware - for m in middleware_w_retry: - if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request: - if retry_request := m.retry_model_request( + for hook_info in middleware_w_retry: + if hook_info.sync_fn: + if retry_request := hook_info.sync_fn( error, request, state, runtime, attempt ): - # Break on first middleware that wants to retry request = retry_request break else: msg = ( f"No synchronous function provided for " - f'{m.__class__.__name__}.aretry_model_request".' 
+ f"{hook_info.middleware_name}.aretry_model_request" "\nEither initialize with a synchronous function or invoke" " via the async API (ainvoke, astream, etc.)" ) @@ -577,20 +940,26 @@ def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, An async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]: """Async model request handler with sequential middleware processing.""" request = ModelRequest( - model=model, + model=model_instance, tools=default_tools, - system_prompt=system_prompt, + system_prompt=components.system_prompt, response_format=initial_response_format, messages=state["messages"], tool_choice=None, ) # Apply modify_model_request middleware in sequence - for m in middleware_w_modify_model_request: - await m.amodify_model_request(request, state, runtime) + for hook_info in middleware_w_modify_model_request: + if hook_info.async_fn: + await hook_info.async_fn(request, state, runtime) + elif hook_info.sync_fn: + # Fallback to sync if only sync is implemented + await run_in_executor(None, hook_info.sync_fn, request, state, runtime) + else: + msg = f"No function provided for {hook_info.middleware_name}.modify_model_request" + raise RuntimeError(msg) # Retry loop for model invocation with error handling - # Hard limit of 100 attempts to prevent infinite loops from buggy middleware max_attempts = 100 for attempt in range(1, max_attempts + 1): try: @@ -608,135 +977,41 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ } except Exception as error: # Try retry_model_request on each middleware - for m in middleware_w_retry: - if retry_request := await m.aretry_model_request( - error, request, state, runtime, attempt - ): - # Break on first middleware that wants to retry + for hook_info in middleware_w_retry: + retry_request = None + if hook_info.async_fn: + retry_request = await hook_info.async_fn( + error, request, state, runtime, attempt + ) + elif hook_info.sync_fn: + # 
Fallback to sync if only sync is implemented + retry_request = await run_in_executor( + None, hook_info.sync_fn, error, request, state, runtime, attempt + ) + + if retry_request: request = retry_request break else: - # If no middleware wants to retry, re-raise the error raise # If we exit the loop, max attempts exceeded msg = f"Maximum retry attempts ({max_attempts}) exceeded" raise RuntimeError(msg) - # Use sync or async based on model capabilities - from langgraph._internal._runnable import RunnableCallable - - graph.add_node("model_request", RunnableCallable(model_request, amodel_request)) - - # Only add tools node if we have tools - if tool_node is not None: - graph.add_node("tools", tool_node) - - # Add middleware nodes - for m in middleware: - if ( - m.__class__.before_model is not AgentMiddleware.before_model - or m.__class__.abefore_model is not AgentMiddleware.abefore_model - ): - # Use RunnableCallable to support both sync and async - # Pass None for sync if not overridden to avoid signature conflicts - sync_before = ( - m.before_model - if m.__class__.before_model is not AgentMiddleware.before_model - else None - ) - async_before = ( - m.abefore_model - if m.__class__.abefore_model is not AgentMiddleware.abefore_model - else None - ) - before_node = RunnableCallable(sync_before, async_before) - graph.add_node(f"{m.name}.before_model", before_node, input_schema=state_schema) - - if ( - m.__class__.after_model is not AgentMiddleware.after_model - or m.__class__.aafter_model is not AgentMiddleware.aafter_model - ): - # Use RunnableCallable to support both sync and async - # Pass None for sync if not overridden to avoid signature conflicts - sync_after = ( - m.after_model - if m.__class__.after_model is not AgentMiddleware.after_model - else None - ) - async_after = ( - m.aafter_model - if m.__class__.aafter_model is not AgentMiddleware.aafter_model - else None - ) - after_node = RunnableCallable(sync_after, async_after) - 
graph.add_node(f"{m.name}.after_model", after_node, input_schema=state_schema) - - # add start edge - first_node = ( - f"{middleware_w_before[0].name}.before_model" if middleware_w_before else "model_request" - ) - last_node = ( - f"{middleware_w_after[0].name}.after_model" if middleware_w_after else "model_request" - ) - graph.add_edge(START, first_node) - - # add conditional edges only if tools exist - if tool_node is not None: - graph.add_conditional_edges( - "tools", - _make_tools_to_model_edge(tool_node, first_node, structured_output_tools), - [first_node, END], - ) - graph.add_conditional_edges( - last_node, - _make_model_to_tools_edge(first_node, structured_output_tools, tool_node), - [first_node, "tools", END], - ) - elif last_node == "model_request": - # If no tools, just go to END from model - graph.add_edge(last_node, END) - else: - # If after_model, then need to check for can_jump_to - _add_middleware_edge( - graph, - f"{middleware_w_after[0].name}.after_model", - END, - first_node, - can_jump_to=_get_can_jump_to(middleware_w_after[0], "after_model"), - ) + # Phase 5: Add nodes to graph + graph.add_node("model_request", RunnableCallable(model_request, amodel_request, trace=False)) + if components.tool_node is not None: + graph.add_node("tools", components.tool_node) + _add_middleware_nodes(graph, components, state_schema) - # Add middleware edges (same as before) - if middleware_w_before: - for m1, m2 in itertools.pairwise(middleware_w_before): - _add_middleware_edge( - graph, - f"{m1.name}.before_model", - f"{m2.name}.before_model", - first_node, - can_jump_to=_get_can_jump_to(m1, "before_model"), - ) - # Go directly to model_request after the last before_model - _add_middleware_edge( - graph, - f"{middleware_w_before[-1].name}.before_model", - "model_request", - first_node, - can_jump_to=_get_can_jump_to(middleware_w_before[-1], "before_model"), - ) + # Phase 6: Compute graph topology + topology = GraphTopology.compute(components.middleware_hooks) - 
if middleware_w_after: - graph.add_edge("model_request", f"{middleware_w_after[-1].name}.after_model") - for idx in range(len(middleware_w_after) - 1, 0, -1): - m1 = middleware_w_after[idx] - m2 = middleware_w_after[idx - 1] - _add_middleware_edge( - graph, - f"{m1.name}.after_model", - f"{m2.name}.after_model", - first_node, - can_jump_to=_get_can_jump_to(m1, "after_model"), - ) + # Phase 7: Connect edges + _connect_entry_edges(graph, topology) + _connect_loop_edges(graph, topology, components) + _connect_middleware_chains(graph, components, topology) return graph @@ -768,7 +1043,10 @@ def _fetch_last_ai_and_tool_messages( def _make_model_to_tools_edge( - first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode + first_node: str, + structured_output_tools: dict[str, OutputToolBinding], + tool_node: ToolNode, + exit_node: str, ) -> Callable[[dict[str, Any]], str | list[Send] | None]: def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None: # 1. if there's an explicit jump_to in the state, use it @@ -778,10 +1056,10 @@ def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None: last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"]) tool_message_ids = [m.tool_call_id for m in tool_messages] - # 2. if the model hasn't called any tools, jump to END + # 2. 
if the model hasn't called any tools, exit the loop # this is the classic exit condition for an agent loop if len(last_ai_message.tool_calls) == 0: - return END + return exit_node pending_tool_calls = [ c @@ -804,7 +1082,10 @@ def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None: def _make_tools_to_model_edge( - tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding] + tool_node: ToolNode, + next_node: str, + structured_output_tools: dict[str, OutputToolBinding], + exit_node: str, ) -> Callable[[dict[str, Any]], str | None]: def tools_to_model(state: dict[str, Any]) -> str | None: last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"]) @@ -814,10 +1095,10 @@ def tools_to_model(state: dict[str, Any]) -> str | None: for c in last_ai_message.tool_calls if c["name"] in tool_node.tools_by_name ): - return END + return exit_node if any(t.name in structured_output_tools for t in tool_messages): - return END + return exit_node return next_node diff --git a/libs/langchain_v1/tests/unit_tests/agents/__snapshots__/test_middleware_agent.ambr b/libs/langchain_v1/tests/unit_tests/agents/__snapshots__/test_middleware_agent.ambr index 2581a1ea07426..6cd5e9d92f30c 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/__snapshots__/test_middleware_agent.ambr +++ b/libs/langchain_v1/tests/unit_tests/agents/__snapshots__/test_middleware_agent.ambr @@ -73,8 +73,8 @@ __start__([
__start__
]):::first model_request(model_request) NoopTen\2ebefore_model(NoopTen.before_model) - NoopTen\2eafter_model(NoopTen.after_model) NoopEleven\2ebefore_model(NoopEleven.before_model) + NoopTen\2eafter_model(NoopTen.after_model) NoopEleven\2eafter_model(NoopEleven.after_model) __end__([__end__
]):::last NoopEleven\2eafter_model --> NoopTen\2eafter_model; @@ -240,8 +240,8 @@ __start__([__start__
]):::first model_request(model_request) NoopSeven\2ebefore_model(NoopSeven.before_model) - NoopSeven\2eafter_model(NoopSeven.after_model) NoopEight\2ebefore_model(NoopEight.before_model) + NoopSeven\2eafter_model(NoopSeven.after_model) NoopEight\2eafter_model(NoopEight.after_model) __end__([__end__
]):::last NoopEight\2eafter_model --> NoopSeven\2eafter_model; @@ -267,10 +267,10 @@ __start__([__start__
]):::first model_request(model_request) NoopSeven\2ebefore_model(NoopSeven.before_model) - NoopSeven\2eafter_model(NoopSeven.after_model) NoopEight\2ebefore_model(NoopEight.before_model) - NoopEight\2eafter_model(NoopEight.after_model) NoopNine\2ebefore_model(NoopNine.before_model) + NoopSeven\2eafter_model(NoopSeven.after_model) + NoopEight\2eafter_model(NoopEight.after_model) NoopNine\2eafter_model(NoopNine.after_model) __end__([__end__
]):::last NoopEight\2eafter_model --> NoopSeven\2eafter_model; @@ -299,8 +299,8 @@ model_request(model_request) tools(tools) NoopSeven\2ebefore_model(NoopSeven.before_model) - NoopSeven\2eafter_model(NoopSeven.after_model) NoopEight\2ebefore_model(NoopEight.before_model) + NoopSeven\2eafter_model(NoopSeven.after_model) NoopEight\2eafter_model(NoopEight.after_model) __end__([__end__
]):::last NoopEight\2eafter_model --> NoopSeven\2eafter_model; @@ -332,8 +332,8 @@ model_request(model_request) tools(tools) NoopSeven\2ebefore_model(NoopSeven.before_model) - NoopSeven\2eafter_model(NoopSeven.after_model) NoopEight\2ebefore_model(NoopEight.before_model) + NoopSeven\2eafter_model(NoopSeven.after_model) NoopEight\2eafter_model(NoopEight.after_model) __end__([__end__
]):::last NoopEight\2eafter_model --> NoopSeven\2eafter_model; @@ -365,8 +365,8 @@ model_request(model_request) tools(tools) NoopSeven\2ebefore_model(NoopSeven.before_model) - NoopSeven\2eafter_model(NoopSeven.after_model) NoopEight\2ebefore_model(NoopEight.before_model) + NoopSeven\2eafter_model(NoopSeven.after_model) NoopEight\2eafter_model(NoopEight.after_model) __end__([__end__
]):::last NoopEight\2eafter_model --> NoopSeven\2eafter_model; @@ -398,8 +398,8 @@ model_request(model_request) tools(tools) NoopSeven\2ebefore_model(NoopSeven.before_model) - NoopSeven\2eafter_model(NoopSeven.after_model) NoopEight\2ebefore_model(NoopEight.before_model) + NoopSeven\2eafter_model(NoopSeven.after_model) NoopEight\2eafter_model(NoopEight.after_model) __end__([__end__
]):::last NoopEight\2eafter_model --> NoopSeven\2eafter_model; @@ -431,8 +431,8 @@ model_request(model_request) tools(tools) NoopSeven\2ebefore_model(NoopSeven.before_model) - NoopSeven\2eafter_model(NoopSeven.after_model) NoopEight\2ebefore_model(NoopEight.before_model) + NoopSeven\2eafter_model(NoopSeven.after_model) NoopEight\2eafter_model(NoopEight.after_model) __end__([__end__
]):::last NoopEight\2eafter_model --> NoopSeven\2eafter_model; diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/__init__.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py new file mode 100644 index 0000000000000..d13fec4b2272a --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py @@ -0,0 +1,388 @@ +"""Unit tests for before_agent and after_agent middleware hooks.""" + +from typing import Any + +from langchain.agents import create_agent +from langchain.agents.middleware import ( + AgentMiddleware, + AgentState, + after_agent, + before_agent, +) +from langchain_core.language_models.fake_chat_models import GenericFakeChatModel +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.tools import tool +from langgraph.runtime import Runtime + +from ..model import FakeToolCallingModel + + +@tool +def sample_tool(query: str) -> str: + """A sample tool for testing.""" + return f"Result for: {query}" + + +class TestBeforeAgentBasic: + """Test basic before_agent functionality.""" + + def test_sync_before_agent_execution(self) -> None: + """Test that before_agent hook is called synchronously.""" + execution_log = [] + + @before_agent + def log_before_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append("before_agent_called") + execution_log.append(f"message_count: {len(state['messages'])}") + return None + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")])) + + agent = create_agent(model=model, tools=[], middleware=[log_before_agent]) + + agent.invoke({"messages": [HumanMessage("Hi")]}) + + assert "before_agent_called" in execution_log + assert "message_count: 1" in execution_log + 
+ async def test_async_before_agent_execution(self) -> None: + """Test that before_agent hook is called asynchronously.""" + execution_log = [] + + @before_agent + async def async_log_before_agent( + state: AgentState, runtime: Runtime + ) -> dict[str, Any] | None: + execution_log.append("async_before_agent_called") + execution_log.append(f"message_count: {len(state['messages'])}") + return None + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")])) + + agent = create_agent(model=model, tools=[], middleware=[async_log_before_agent]) + + await agent.ainvoke({"messages": [HumanMessage("Hi")]}) + + assert "async_before_agent_called" in execution_log + assert "message_count: 1" in execution_log + + def test_before_agent_state_modification(self) -> None: + """Test that before_agent can modify state.""" + + @before_agent + def add_metadata(state: AgentState, runtime: Runtime) -> dict[str, Any]: + return {"messages": [HumanMessage("Injected by middleware")]} + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent(model=model, tools=[], middleware=[add_metadata]) + + result = agent.invoke({"messages": [HumanMessage("Original")]}) + + # Should have original + injected + AI response + assert len(result["messages"]) >= 2 + message_contents = [msg.content for msg in result["messages"]] + assert "Injected by middleware" in message_contents + + def test_before_agent_with_class_inheritance(self) -> None: + """Test before_agent using class inheritance.""" + execution_log = [] + + class CustomBeforeAgentMiddleware(AgentMiddleware): + def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append("class_before_agent_called") + return None + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent(model=model, tools=[], middleware=[CustomBeforeAgentMiddleware()]) + + agent.invoke({"messages": [HumanMessage("Test")]}) + 
+ assert "class_before_agent_called" in execution_log + + async def test_before_agent_with_async_class_inheritance(self) -> None: + """Test async before_agent using class inheritance.""" + execution_log = [] + + class CustomAsyncBeforeAgentMiddleware(AgentMiddleware): + async def abefore_agent( + self, state: AgentState, runtime: Runtime + ) -> dict[str, Any] | None: + execution_log.append("async_class_before_agent_called") + return None + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent(model=model, tools=[], middleware=[CustomAsyncBeforeAgentMiddleware()]) + + await agent.ainvoke({"messages": [HumanMessage("Test")]}) + + assert "async_class_before_agent_called" in execution_log + + +class TestAfterAgentBasic: + """Test basic after_agent functionality.""" + + def test_sync_after_agent_execution(self) -> None: + """Test that after_agent hook is called synchronously.""" + execution_log = [] + + @after_agent + def log_after_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append("after_agent_called") + execution_log.append(f"final_message_count: {len(state['messages'])}") + return None + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Final response")])) + + agent = create_agent(model=model, tools=[], middleware=[log_after_agent]) + + agent.invoke({"messages": [HumanMessage("Hi")]}) + + assert "after_agent_called" in execution_log + assert any("final_message_count:" in log for log in execution_log) + + async def test_async_after_agent_execution(self) -> None: + """Test that after_agent hook is called asynchronously.""" + execution_log = [] + + @after_agent + async def async_log_after_agent( + state: AgentState, runtime: Runtime + ) -> dict[str, Any] | None: + execution_log.append("async_after_agent_called") + return None + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent(model=model, tools=[], 
middleware=[async_log_after_agent]) + + await agent.ainvoke({"messages": [HumanMessage("Hi")]}) + + assert "async_after_agent_called" in execution_log + + def test_after_agent_state_modification(self) -> None: + """Test that after_agent can modify state.""" + + @after_agent + def add_final_message(state: AgentState, runtime: Runtime) -> dict[str, Any]: + return {"messages": [AIMessage("Added by after_agent")]} + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Model response")])) + + agent = create_agent(model=model, tools=[], middleware=[add_final_message]) + + result = agent.invoke({"messages": [HumanMessage("Test")]}) + + message_contents = [msg.content for msg in result["messages"]] + assert "Added by after_agent" in message_contents + + def test_after_agent_with_class_inheritance(self) -> None: + """Test after_agent using class inheritance.""" + execution_log = [] + + class CustomAfterAgentMiddleware(AgentMiddleware): + def after_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append("class_after_agent_called") + return None + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent(model=model, tools=[], middleware=[CustomAfterAgentMiddleware()]) + + agent.invoke({"messages": [HumanMessage("Test")]}) + + assert "class_after_agent_called" in execution_log + + async def test_after_agent_with_async_class_inheritance(self) -> None: + """Test async after_agent using class inheritance.""" + execution_log = [] + + class CustomAsyncAfterAgentMiddleware(AgentMiddleware): + async def aafter_agent( + self, state: AgentState, runtime: Runtime + ) -> dict[str, Any] | None: + execution_log.append("async_class_after_agent_called") + return None + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent(model=model, tools=[], middleware=[CustomAsyncAfterAgentMiddleware()]) + + await agent.ainvoke({"messages": 
[HumanMessage("Test")]}) + + assert "async_class_after_agent_called" in execution_log + + +class TestBeforeAndAfterAgentCombined: + """Test before_agent and after_agent hooks working together.""" + + def test_execution_order(self) -> None: + """Test that before_agent executes before after_agent.""" + execution_log = [] + + @before_agent + def log_before(state: AgentState, runtime: Runtime) -> None: + execution_log.append("before") + + @after_agent + def log_after(state: AgentState, runtime: Runtime) -> None: + execution_log.append("after") + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent(model=model, tools=[], middleware=[log_before, log_after]) + + agent.invoke({"messages": [HumanMessage("Test")]}) + + assert execution_log == ["before", "after"] + + async def test_async_execution_order(self) -> None: + """Test async execution order of before_agent and after_agent.""" + execution_log = [] + + @before_agent + async def async_log_before(state: AgentState, runtime: Runtime) -> None: + execution_log.append("async_before") + + @after_agent + async def async_log_after(state: AgentState, runtime: Runtime) -> None: + execution_log.append("async_after") + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent(model=model, tools=[], middleware=[async_log_before, async_log_after]) + + await agent.ainvoke({"messages": [HumanMessage("Test")]}) + + assert execution_log == ["async_before", "async_after"] + + def test_state_passthrough(self) -> None: + """Test that state modifications in before_agent are visible to after_agent.""" + collected_states = {} + + @before_agent + def modify_in_before(state: AgentState, runtime: Runtime) -> dict[str, Any]: + return {"messages": [HumanMessage("Modified by before_agent")]} + + @after_agent + def capture_in_after(state: AgentState, runtime: Runtime) -> None: + collected_states["messages"] = state["messages"] + + model = 
GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent(model=model, tools=[], middleware=[modify_in_before, capture_in_after]) + + agent.invoke({"messages": [HumanMessage("Original")]}) + + message_contents = [msg.content for msg in collected_states["messages"]] + assert "Modified by before_agent" in message_contents + + def test_multiple_middleware_instances(self) -> None: + """Test multiple before_agent and after_agent middleware instances.""" + execution_log = [] + + @before_agent + def before_one(state: AgentState, runtime: Runtime) -> None: + execution_log.append("before_1") + + @before_agent + def before_two(state: AgentState, runtime: Runtime) -> None: + execution_log.append("before_2") + + @after_agent + def after_one(state: AgentState, runtime: Runtime) -> None: + execution_log.append("after_1") + + @after_agent + def after_two(state: AgentState, runtime: Runtime) -> None: + execution_log.append("after_2") + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent( + model=model, tools=[], middleware=[before_one, before_two, after_one, after_two] + ) + + agent.invoke({"messages": [HumanMessage("Test")]}) + + assert "before_1" in execution_log + assert "before_2" in execution_log + assert "after_1" in execution_log + assert "after_2" in execution_log + + def test_agent_hooks_run_once_with_multiple_model_calls(self) -> None: + """Test that before_agent and after_agent run only once even with tool calls.""" + execution_log = [] + + @before_agent + def log_before_agent(state: AgentState, runtime: Runtime) -> None: + execution_log.append("before_agent") + + @after_agent + def log_after_agent(state: AgentState, runtime: Runtime) -> None: + execution_log.append("after_agent") + + # Model will call a tool once, then respond with final answer + model = FakeToolCallingModel( + tool_calls=[ + [{"name": "sample_tool", "args": {"query": "test"}, "id": "1"}], + [], # Second call 
returns no tool calls (final answer) + ] + ) + + agent = create_agent( + model=model, + tools=[sample_tool], + middleware=[log_before_agent, log_after_agent], + ) + + agent.invoke({"messages": [HumanMessage("Test")]}) + + # before_agent and after_agent should run exactly once + assert execution_log.count("before_agent") == 1 + assert execution_log.count("after_agent") == 1 + # before_agent should run first, after_agent should run last + assert execution_log[0] == "before_agent" + assert execution_log[-1] == "after_agent" + + +class TestDecoratorParameters: + """Test decorator parameters for before_agent and after_agent.""" + + def test_before_agent_with_custom_name(self) -> None: + """Test before_agent with custom middleware name.""" + + @before_agent(name="CustomBeforeAgentMiddleware") + def custom_named_before(state: AgentState, runtime: Runtime) -> None: + pass + + assert custom_named_before.name == "CustomBeforeAgentMiddleware" + + def test_after_agent_with_custom_name(self) -> None: + """Test after_agent with custom middleware name.""" + + @after_agent(name="CustomAfterAgentMiddleware") + def custom_named_after(state: AgentState, runtime: Runtime) -> None: + pass + + assert custom_named_after.name == "CustomAfterAgentMiddleware" + + def test_before_agent_default_name(self) -> None: + """Test that before_agent uses function name by default.""" + + @before_agent + def my_before_agent_function(state: AgentState, runtime: Runtime) -> None: + pass + + assert my_before_agent_function.name == "my_before_agent_function" + + def test_after_agent_default_name(self) -> None: + """Test that after_agent uses function name by default.""" + + @after_agent + def my_after_agent_function(state: AgentState, runtime: Runtime) -> None: + pass + + assert my_after_agent_function.name == "my_after_agent_function" diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_decorators.py b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_decorators.py index 
536d5fdc6e776..c393202b2199a 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_decorators.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_decorators.py @@ -8,8 +8,8 @@ from langchain_core.messages import HumanMessage, AIMessage from langchain_core.tools import tool from langgraph.runtime import Runtime -from langgraph.types import Command +from langchain.agents.middleware_agent import create_agent from langchain.agents.middleware.types import ( AgentMiddleware, AgentState, @@ -20,7 +20,6 @@ modify_model_request, hook_config, ) -from langchain.agents.middleware_agent import create_agent, _get_can_jump_to from .model import FakeToolCallingModel @@ -470,49 +469,6 @@ async def async_early_exit(state: AgentState, runtime: Runtime) -> dict[str, Any assert len(result["messages"]) > 1 -def test_get_can_jump_to_no_false_positives() -> None: - """Test that _get_can_jump_to doesn't return false positives for base class methods.""" - - # Middleware with no overridden methods should return empty list - class EmptyMiddleware(AgentMiddleware): - pass - - empty_middleware = EmptyMiddleware() - empty_middleware.tools = [] - - # Should not return any jump destinations for base class methods - assert _get_can_jump_to(empty_middleware, "before_model") == [] - assert _get_can_jump_to(empty_middleware, "after_model") == [] - - -def test_get_can_jump_to_only_overridden_methods() -> None: - """Test that _get_can_jump_to only checks overridden methods.""" - - # Middleware with only sync method overridden - class SyncOnlyMiddleware(AgentMiddleware): - @hook_config(can_jump_to=["end"]) - def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: - return None - - sync_middleware = SyncOnlyMiddleware() - sync_middleware.tools = [] - - # Should return can_jump_to from overridden sync method - assert _get_can_jump_to(sync_middleware, "before_model") == ["end"] - - # Middleware with only async method overridden - class 
AsyncOnlyMiddleware(AgentMiddleware): - @hook_config(can_jump_to=["model"]) - async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: - return None - - async_middleware = AsyncOnlyMiddleware() - async_middleware.tools = [] - - # Should return can_jump_to from overridden async method - assert _get_can_jump_to(async_middleware, "after_model") == ["model"] - - def test_async_middleware_with_can_jump_to_graph_snapshot(snapshot: SnapshotAssertion) -> None: """Test that async middleware with can_jump_to creates correct graph structure with conditional edges."""