diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index 5915b75164a74..df39772f6ee60 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -12,7 +12,9 @@ AgentMiddleware, AgentState, ModelRequest, + after_agent, after_model, + before_agent, before_model, dynamic_prompt, hook_config, @@ -33,7 +35,9 @@ "PlanningMiddleware", "SummarizationMiddleware", "ToolCallLimitMiddleware", + "after_agent", "after_model", + "before_agent", "before_model", "dynamic_prompt", "hook_config", diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index 6aea90ab9fc35..327bab5e5bcc7 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -44,8 +44,13 @@ "ModelRequest", "OmitFromSchema", "PublicAgentState", + "after_agent", + "after_model", + "before_agent", + "before_model", "dynamic_prompt", "hook_config", + "modify_model_request", ] JumpTo = Literal["tools", "model", "end"] @@ -93,7 +98,7 @@ class AgentState(TypedDict, Generic[ResponseT]): messages: Required[Annotated[list[AnyMessage], add_messages]] jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]] - structured_response: NotRequired[ResponseT] + structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]] thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]] run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]] @@ -133,6 +138,14 @@ def name(self) -> str: """ return self.__class__.__name__ + def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None: + """Logic to run before the agent execution starts.""" + + async def abefore_agent( + self, state: StateT, runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: + """Async logic to run before the agent execution starts.""" + def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None: """Logic to run before the model is called.""" @@ -215,6 +228,14 @@ async def aretry_model_request( None, self.retry_model_request, error, request, state, runtime, attempt ) + def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None: + """Logic to run after the agent execution completes.""" + + async def aafter_agent( + self, state: StateT, runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: + """Async logic to run after the agent execution completes.""" + class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]): """Callable with AgentState and Runtime as arguments.""" @@ -707,6 +728,279 @@ def wrapped( return decorator +@overload +def before_agent( + func: _CallableWithStateAndRuntime[StateT, ContextT], +) -> AgentMiddleware[StateT, ContextT]: ... + + +@overload +def before_agent( + func: None = None, + *, + state_schema: type[StateT] | None = None, + tools: list[BaseTool] | None = None, + can_jump_to: list[JumpTo] | None = None, + name: str | None = None, +) -> Callable[ + [_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT] +]: ... 
+
+
+def before_agent(
+    func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    can_jump_to: list[JumpTo] | None = None,
+    name: str | None = None,
+) -> (
+    Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
+    | AgentMiddleware[StateT, ContextT]
+):
+    """Decorator used to dynamically create a middleware with the before_agent hook.
+
+    Args:
+        func: The function to be decorated. Must accept:
+            `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
+        state_schema: Optional custom state schema type. If not provided, uses the default
+            AgentState schema.
+        tools: Optional list of additional tools to register with this middleware.
+        can_jump_to: Optional list of valid jump destinations for conditional edges.
+            Valid values are: "tools", "model", "end"
+        name: Optional name for the generated middleware class. If not provided,
+            uses the decorated function's name.
+
+    Returns:
+        Either an AgentMiddleware instance (if func is provided directly) or a decorator function
+        that can be applied to a function.
+
+    The decorated function should return:
+        - `dict[str, Any]` - State updates to merge into the agent state
+        - `Command` - A command to control flow (e.g., jump to different node)
+        - `None` - No state updates or flow control
+
+    Examples:
+        Basic usage:
+        ```python
+        @before_agent
+        def log_before_agent(state: AgentState, runtime: Runtime) -> None:
+            print(f"Starting agent with {len(state['messages'])} messages")
+        ```
+
+        With conditional jumping:
+        ```python
+        @before_agent(can_jump_to=["end"])
+        def conditional_before_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
+            if some_condition(state):
+                return {"jump_to": "end"}
+            return None
+        ```
+
+        With custom state schema:
+        ```python
+        @before_agent(state_schema=MyCustomState)
+        def custom_before_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
+            return {"custom_field": "initialized_value"}
+        ```
+    """
+
+    def decorator(
+        func: _CallableWithStateAndRuntime[StateT, ContextT],
+    ) -> AgentMiddleware[StateT, ContextT]:
+        is_async = iscoroutinefunction(func)
+
+        func_can_jump_to = (
+            can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
+        )
+
+        if is_async:
+
+            async def async_wrapped(
+                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                state: StateT,
+                runtime: Runtime[ContextT],
+            ) -> dict[str, Any] | Command | None:
+                return await func(state, runtime)  # type: ignore[misc]
+
+            # Preserve can_jump_to metadata on the wrapped function
+            if func_can_jump_to:
+                async_wrapped.__can_jump_to__ = func_can_jump_to  # type: ignore[attr-defined]
+
+            middleware_name = name or cast(
+                "str", getattr(func, "__name__", "BeforeAgentMiddleware")
+            )
+
+            return type(
+                middleware_name,
+                (AgentMiddleware,),
+                {
+                    "state_schema": state_schema or AgentState,
+                    "tools": tools or [],
+                    "abefore_agent": async_wrapped,
+                },
+            )()
+
+        def wrapped(
+            self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+            state: StateT,
+            runtime: Runtime[ContextT],
+        ) -> dict[str, Any] | Command | None:
+            return func(state, runtime)  # type: ignore[return-value]
+
+        # Preserve can_jump_to metadata on the wrapped function
+        if func_can_jump_to:
+            wrapped.__can_jump_to__ = func_can_jump_to  # type: ignore[attr-defined]
+
+        # Use function name as default if no name provided
+        middleware_name = name or cast("str", getattr(func, "__name__", "BeforeAgentMiddleware"))
+
+        return type(
+            middleware_name,
+            (AgentMiddleware,),
+            {
+                "state_schema": state_schema or AgentState,
+                "tools": tools or [],
+                "before_agent": wrapped,
+            },
+        )()
+
+    if func is not None:
+        return decorator(func)
+    return decorator
+
+
+@overload
+def after_agent(
+    func: _CallableWithStateAndRuntime[StateT, ContextT],
+) -> AgentMiddleware[StateT, ContextT]: ...
+
+
+@overload
+def after_agent(
+    func: None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    can_jump_to: list[JumpTo] | None = None,
+    name: str | None = None,
+) -> Callable[
+    [_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
+]: ...
+
+
+def after_agent(
+    func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    can_jump_to: list[JumpTo] | None = None,
+    name: str | None = None,
+) -> (
+    Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
+    | AgentMiddleware[StateT, ContextT]
+):
+    """Decorator used to dynamically create a middleware with the after_agent hook.
+
+    Args:
+        func: The function to be decorated. Must accept:
+            `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
+        state_schema: Optional custom state schema type. If not provided, uses the default
+            AgentState schema.
+        tools: Optional list of additional tools to register with this middleware.
+        can_jump_to: Optional list of valid jump destinations for conditional edges.
+            Valid values are: "tools", "model", "end"
+        name: Optional name for the generated middleware class. If not provided,
+            uses the decorated function's name.
+
+    Returns:
+        Either an AgentMiddleware instance (if func is provided) or a decorator function
+        that can be applied to a function.
+ + The decorated function should return: + - `dict[str, Any]` - State updates to merge into the agent state + - `Command` - A command to control flow (e.g., jump to different node) + - `None` - No state updates or flow control + + Examples: + Basic usage for logging agent completion: + ```python + @after_agent + def log_completion(state: AgentState, runtime: Runtime) -> None: + print(f"Agent completed with {len(state['messages'])} messages") + ``` + + With custom state schema: + ```python + @after_agent(state_schema=MyCustomState, name="MyAfterAgentMiddleware") + def custom_after_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]: + return {"custom_field": "finalized_value"} + ``` + """ + + def decorator( + func: _CallableWithStateAndRuntime[StateT, ContextT], + ) -> AgentMiddleware[StateT, ContextT]: + is_async = iscoroutinefunction(func) + # Extract can_jump_to from decorator parameter or from function metadata + func_can_jump_to = ( + can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", []) + ) + + if is_async: + + async def async_wrapped( + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 + state: StateT, + runtime: Runtime[ContextT], + ) -> dict[str, Any] | Command | None: + return await func(state, runtime) # type: ignore[misc] + + # Preserve can_jump_to metadata on the wrapped function + if func_can_jump_to: + async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined] + + middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware")) + + return type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "aafter_agent": async_wrapped, + }, + )() + + def wrapped( + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 + state: StateT, + runtime: Runtime[ContextT], + ) -> dict[str, Any] | Command | None: + return func(state, runtime) # type: ignore[return-value] + + # Preserve can_jump_to metadata on the wrapped function + if func_can_jump_to: + wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined] + + # Use function name as default if no name provided + middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware")) + + return type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "after_agent": wrapped, + }, + )() + + if func is not None: + return decorator(func) + return decorator + + @overload def dynamic_prompt( func: _CallableReturningPromptString[StateT, ContextT], diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index b4f6370446b05..6ff5d306b22c3 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -8,6 +8,7 @@ from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool +from langgraph._internal._runnable import RunnableCallable from langgraph.constants import END, START from langgraph.graph.state import StateGraph from langgraph.runtime import Runtime @@ -260,7 +261,13 @@ def create_agent( # noqa: PLR0915 assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101 "Please remove duplicate middleware instances." 
) - middleware_w_before = [ + middleware_w_before_agent = [ + m + for m in middleware + if m.__class__.before_agent is not AgentMiddleware.before_agent + or m.__class__.abefore_agent is not AgentMiddleware.abefore_agent + ] + middleware_w_before_model = [ m for m in middleware if m.__class__.before_model is not AgentMiddleware.before_model @@ -272,12 +279,18 @@ def create_agent( # noqa: PLR0915 if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request or m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request ] - middleware_w_after = [ + middleware_w_after_model = [ m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model or m.__class__.aafter_model is not AgentMiddleware.aafter_model ] + middleware_w_after_agent = [ + m + for m in middleware + if m.__class__.after_agent is not AgentMiddleware.after_agent + or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent + ] middleware_w_retry = [ m for m in middleware @@ -624,9 +637,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ raise RuntimeError(msg) # Use sync or async based on model capabilities - from langgraph._internal._runnable import RunnableCallable - - graph.add_node("model_request", RunnableCallable(model_request, amodel_request)) + graph.add_node("model_request", RunnableCallable(model_request, amodel_request, trace=False)) # Only add tools node if we have tools if tool_node is not None: @@ -634,6 +645,25 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ # Add middleware nodes for m in middleware: + if ( + m.__class__.before_agent is not AgentMiddleware.before_agent + or m.__class__.abefore_agent is not AgentMiddleware.abefore_agent + ): + # Use RunnableCallable to support both sync and async + # Pass None for sync if not overridden to avoid signature conflicts + sync_before_agent = ( + m.before_agent + if m.__class__.before_agent is not AgentMiddleware.before_agent + else None + ) + async_before_agent = ( + m.abefore_agent + if m.__class__.abefore_agent is not AgentMiddleware.abefore_agent + else None + ) + before_agent_node = RunnableCallable(sync_before_agent, async_before_agent, trace=False) + graph.add_node(f"{m.name}.before_agent", before_agent_node, input_schema=state_schema) + if ( m.__class__.before_model is not AgentMiddleware.before_model or m.__class__.abefore_model is not AgentMiddleware.abefore_model @@ -650,7 +680,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ if m.__class__.abefore_model is not AgentMiddleware.abefore_model else None ) - before_node = RunnableCallable(sync_before, async_before) + before_node = RunnableCallable(sync_before, async_before, trace=False) graph.add_node(f"{m.name}.before_model", before_node, input_schema=state_schema) if ( @@ -669,74 +699,163 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ if m.__class__.aafter_model is not AgentMiddleware.aafter_model else None ) - after_node = RunnableCallable(sync_after, async_after) + after_node = RunnableCallable(sync_after, async_after, trace=False) graph.add_node(f"{m.name}.after_model", after_node, input_schema=state_schema) - # add start edge - first_node = ( - f"{middleware_w_before[0].name}.before_model" if middleware_w_before else "model_request" - ) - last_node = ( - f"{middleware_w_after[0].name}.after_model" if middleware_w_after else "model_request" - ) - graph.add_edge(START, first_node) + if ( + 
m.__class__.after_agent is not AgentMiddleware.after_agent + or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent + ): + # Use RunnableCallable to support both sync and async + # Pass None for sync if not overridden to avoid signature conflicts + sync_after_agent = ( + m.after_agent + if m.__class__.after_agent is not AgentMiddleware.after_agent + else None + ) + async_after_agent = ( + m.aafter_agent + if m.__class__.aafter_agent is not AgentMiddleware.aafter_agent + else None + ) + after_agent_node = RunnableCallable(sync_after_agent, async_after_agent, trace=False) + graph.add_node(f"{m.name}.after_agent", after_agent_node, input_schema=state_schema) + + # Determine the entry node (runs once at start): before_agent -> before_model -> model_request + if middleware_w_before_agent: + entry_node = f"{middleware_w_before_agent[0].name}.before_agent" + elif middleware_w_before_model: + entry_node = f"{middleware_w_before_model[0].name}.before_model" + else: + entry_node = "model_request" + + # Determine the loop entry node (beginning of agent loop, excludes before_agent) + # This is where tools will loop back to for the next iteration + if middleware_w_before_model: + loop_entry_node = f"{middleware_w_before_model[0].name}.before_model" + else: + loop_entry_node = "model_request" + # Determine the loop exit node (end of each iteration, can run multiple times) + # This is after_model or model_request, but NOT after_agent + if middleware_w_after_model: + loop_exit_node = f"{middleware_w_after_model[0].name}.after_model" + else: + loop_exit_node = "model_request" + + # Determine the exit node (runs once at end): after_agent or END + if middleware_w_after_agent: + exit_node = f"{middleware_w_after_agent[-1].name}.after_agent" + else: + exit_node = END + + graph.add_edge(START, entry_node) # add conditional edges only if tools exist if tool_node is not None: graph.add_conditional_edges( "tools", - _make_tools_to_model_edge(tool_node, first_node, structured_output_tools), - [first_node, END], + _make_tools_to_model_edge( + tool_node, loop_entry_node, structured_output_tools, exit_node + ), + [loop_entry_node, exit_node], ) + graph.add_conditional_edges( - last_node, - _make_model_to_tools_edge(first_node, structured_output_tools, tool_node), - [first_node, "tools", END], + loop_exit_node, + _make_model_to_tools_edge( + loop_entry_node, structured_output_tools, tool_node, exit_node + ), + [loop_entry_node, "tools", exit_node], ) - elif last_node == "model_request": - # If no tools, just go to END from model - graph.add_edge(last_node, END) + elif loop_exit_node == "model_request": + # If no tools and no after_model, go directly to exit_node + graph.add_edge(loop_exit_node, exit_node) + # No tools but we have after_model - connect after_model to exit_node else: - # If after_model, then need to check for can_jump_to _add_middleware_edge( graph, - f"{middleware_w_after[0].name}.after_model", - END, - first_node, - can_jump_to=_get_can_jump_to(middleware_w_after[0], "after_model"), + f"{middleware_w_after_model[0].name}.after_model", + exit_node, + loop_entry_node, + can_jump_to=_get_can_jump_to(middleware_w_after_model[0], "after_model"), ) - # Add middleware edges (same as before) - if middleware_w_before: - for m1, m2 in itertools.pairwise(middleware_w_before): + # Add before_agent middleware edges + if middleware_w_before_agent: + for m1, m2 in itertools.pairwise(middleware_w_before_agent): + _add_middleware_edge( + graph, + f"{m1.name}.before_agent", + f"{m2.name}.before_agent", + 
loop_entry_node, + can_jump_to=_get_can_jump_to(m1, "before_agent"), + ) + # Connect last before_agent to loop_entry_node (before_model or model_request) + _add_middleware_edge( + graph, + f"{middleware_w_before_agent[-1].name}.before_agent", + loop_entry_node, + loop_entry_node, + can_jump_to=_get_can_jump_to(middleware_w_before_agent[-1], "before_agent"), + ) + + # Add before_model middleware edges + if middleware_w_before_model: + for m1, m2 in itertools.pairwise(middleware_w_before_model): _add_middleware_edge( graph, f"{m1.name}.before_model", f"{m2.name}.before_model", - first_node, + loop_entry_node, can_jump_to=_get_can_jump_to(m1, "before_model"), ) # Go directly to model_request after the last before_model _add_middleware_edge( graph, - f"{middleware_w_before[-1].name}.before_model", + f"{middleware_w_before_model[-1].name}.before_model", "model_request", - first_node, - can_jump_to=_get_can_jump_to(middleware_w_before[-1], "before_model"), + loop_entry_node, + can_jump_to=_get_can_jump_to(middleware_w_before_model[-1], "before_model"), ) - if middleware_w_after: - graph.add_edge("model_request", f"{middleware_w_after[-1].name}.after_model") - for idx in range(len(middleware_w_after) - 1, 0, -1): - m1 = middleware_w_after[idx] - m2 = middleware_w_after[idx - 1] + # Add after_model middleware edges + if middleware_w_after_model: + graph.add_edge("model_request", f"{middleware_w_after_model[-1].name}.after_model") + for idx in range(len(middleware_w_after_model) - 1, 0, -1): + m1 = middleware_w_after_model[idx] + m2 = middleware_w_after_model[idx - 1] _add_middleware_edge( graph, f"{m1.name}.after_model", f"{m2.name}.after_model", - first_node, + loop_entry_node, can_jump_to=_get_can_jump_to(m1, "after_model"), ) + # Note: Connection from after_model to after_agent/END is handled above + # in the conditional edges section + + # Add after_agent middleware edges + if middleware_w_after_agent: + # Chain after_agent middleware (runs once at the very end, before END) + for idx in range(len(middleware_w_after_agent) - 1, 0, -1): + m1 = middleware_w_after_agent[idx] + m2 = middleware_w_after_agent[idx - 1] + _add_middleware_edge( + graph, + f"{m1.name}.after_agent", + f"{m2.name}.after_agent", + loop_entry_node, + can_jump_to=_get_can_jump_to(m1, "after_agent"), + ) + + # Connect the last after_agent to END + _add_middleware_edge( + graph, + f"{middleware_w_after_agent[0].name}.after_agent", + END, + loop_entry_node, + can_jump_to=_get_can_jump_to(middleware_w_after_agent[0], "after_agent"), + ) return graph @@ -768,7 +887,10 @@ def _fetch_last_ai_and_tool_messages( def _make_model_to_tools_edge( - first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode + first_node: str, + structured_output_tools: dict[str, OutputToolBinding], + tool_node: ToolNode, + exit_node: str, ) -> Callable[[dict[str, Any]], str | list[Send] | None]: def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None: # 1. if there's an explicit jump_to in the state, use it @@ -778,10 +900,10 @@ def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None: last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"]) tool_message_ids = [m.tool_call_id for m in tool_messages] - # 2. if the model hasn't called any tools, jump to END + # 2. 
if the model hasn't called any tools, exit the loop # this is the classic exit condition for an agent loop if len(last_ai_message.tool_calls) == 0: - return END + return exit_node pending_tool_calls = [ c @@ -804,7 +926,10 @@ def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None: def _make_tools_to_model_edge( - tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding] + tool_node: ToolNode, + next_node: str, + structured_output_tools: dict[str, OutputToolBinding], + exit_node: str, ) -> Callable[[dict[str, Any]], str | None]: def tools_to_model(state: dict[str, Any]) -> str | None: last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"]) @@ -814,10 +939,10 @@ def tools_to_model(state: dict[str, Any]) -> str | None: for c in last_ai_message.tool_calls if c["name"] in tool_node.tools_by_name ): - return END + return exit_node if any(t.name in structured_output_tools for t in tool_messages): - return END + return exit_node return next_node diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/__init__.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py new file mode 100644 index 0000000000000..9736c395448db --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py @@ -0,0 +1,286 @@ +"""Unit tests for before_agent and after_agent middleware hooks.""" + +from typing import Any + +import pytest +from langchain.agents import create_agent +from langchain.agents.middleware import ( + AgentMiddleware, + AgentState, + after_agent, + after_model, + before_model, + before_agent, +) +from langchain_core.language_models.fake_chat_models import GenericFakeChatModel +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.tools import tool +from langgraph.runtime import Runtime + +from ..model import FakeToolCallingModel + + +@tool +def sample_tool(query: str) -> str: + """A sample tool for testing.""" + return f"Result for: {query}" + + +class TestAgentMiddlewareHooks: + """Test before_agent and after_agent middleware hooks.""" + + @pytest.mark.parametrize("is_async", [False, True]) + @pytest.mark.parametrize("hook_type", ["before", "after"]) + async def test_hook_execution(self, is_async: bool, hook_type: str) -> None: + """Test that agent hooks are called in both sync and async modes.""" + execution_log: list[str] = [] + + if is_async: + if hook_type == "before": + + @before_agent + async def log_hook(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append(f"{hook_type}_agent_called") + execution_log.append(f"message_count: {len(state['messages'])}") + return None + else: + + @after_agent + async def log_hook(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append(f"{hook_type}_agent_called") + execution_log.append(f"message_count: {len(state['messages'])}") + return None + else: + if hook_type == "before": + + @before_agent + def log_hook(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append(f"{hook_type}_agent_called") + execution_log.append(f"message_count: {len(state['messages'])}") + return None + else: + + @after_agent + def log_hook(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + 
execution_log.append(f"{hook_type}_agent_called") + execution_log.append(f"message_count: {len(state['messages'])}") + return None + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + agent = create_agent(model=model, tools=[], middleware=[log_hook]) + + if is_async: + await agent.ainvoke({"messages": [HumanMessage("Hi")]}) + else: + agent.invoke({"messages": [HumanMessage("Hi")]}) + + assert f"{hook_type}_agent_called" in execution_log + assert any("message_count:" in log for log in execution_log) + + @pytest.mark.parametrize("is_async", [False, True]) + @pytest.mark.parametrize("hook_type", ["before", "after"]) + async def test_hook_with_class_inheritance(self, is_async: bool, hook_type: str) -> None: + """Test agent hooks using class inheritance in both sync and async modes.""" + execution_log: list[str] = [] + + if is_async: + + class CustomMiddleware(AgentMiddleware): + async def abefore_agent( + self, state: AgentState, runtime: Runtime + ) -> dict[str, Any] | None: + if hook_type == "before": + execution_log.append("hook_called") + return None + + async def aafter_agent( + self, state: AgentState, runtime: Runtime + ) -> dict[str, Any] | None: + if hook_type == "after": + execution_log.append("hook_called") + return None + else: + + class CustomMiddleware(AgentMiddleware): + def before_agent( + self, state: AgentState, runtime: Runtime + ) -> dict[str, Any] | None: + if hook_type == "before": + execution_log.append("hook_called") + return None + + def after_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + if hook_type == "after": + execution_log.append("hook_called") + return None + + middleware = CustomMiddleware() + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + agent = create_agent(model=model, tools=[], middleware=[middleware]) + + if is_async: + await agent.ainvoke({"messages": [HumanMessage("Test")]}) + else: + agent.invoke({"messages": [HumanMessage("Test")]}) + + assert "hook_called" in execution_log + + +class TestAgentHooksCombined: + """Test before_agent and after_agent hooks working together.""" + + @pytest.mark.parametrize("is_async", [False, True]) + async def test_execution_order(self, is_async: bool) -> None: + """Test that before_agent executes before after_agent in both sync and async modes.""" + execution_log: list[str] = [] + + if is_async: + + @before_agent + async def log_before(state: AgentState, runtime: Runtime) -> None: + execution_log.append("before") + + @after_agent + async def log_after(state: AgentState, runtime: Runtime) -> None: + execution_log.append("after") + else: + + @before_agent + def log_before(state: AgentState, runtime: Runtime) -> None: + execution_log.append("before") + + @after_agent + def log_after(state: AgentState, runtime: Runtime) -> None: + execution_log.append("after") + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + agent = create_agent(model=model, tools=[], middleware=[log_before, log_after]) + + if is_async: + await agent.ainvoke({"messages": [HumanMessage("Test")]}) + else: + agent.invoke({"messages": [HumanMessage("Test")]}) + + assert execution_log == ["before", "after"] + + def test_state_passthrough(self) -> None: + """Test that state modifications in before_agent are visible to after_agent.""" + + @before_agent + def modify_in_before(state: AgentState, runtime: Runtime) -> dict[str, Any]: + return {"messages": [HumanMessage("Added by before_agent")]} + + model = 
GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + agent = create_agent(model=model, tools=[], middleware=[modify_in_before]) + result = agent.invoke({"messages": [HumanMessage("Original")]}) + + message_contents = [msg.content for msg in result["messages"]] + assert message_contents[1] == "Added by before_agent" + + def test_multiple_middleware_instances(self) -> None: + """Test multiple before_agent and after_agent middleware instances.""" + execution_log = [] + + @before_agent + def before_one(state: AgentState, runtime: Runtime) -> None: + execution_log.append("before_1") + + @before_agent + def before_two(state: AgentState, runtime: Runtime) -> None: + execution_log.append("before_2") + + @after_agent + def after_one(state: AgentState, runtime: Runtime) -> None: + execution_log.append("after_1") + + @after_agent + def after_two(state: AgentState, runtime: Runtime) -> None: + execution_log.append("after_2") + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + agent = create_agent( + model=model, tools=[], middleware=[before_one, before_two, after_one, after_two] + ) + agent.invoke({"messages": [HumanMessage("Test")]}) + + assert execution_log == ["before_1", "before_2", "after_2", "after_1"] + + def test_agent_hooks_run_once_with_multiple_model_calls(self) -> None: + """Test that before_agent and after_agent run only once per thread. + + This test verifies that agent-level hooks (before_agent, after_agent) execute + exactly once per agent invocation, regardless of how many tool calling loops occur. + This is different from model-level hooks (before_model, after_model) which run + on every model invocation within the tool calling loop. + """ + execution_log = [] + + @before_agent + def log_before_agent(state: AgentState, runtime: Runtime) -> None: + execution_log.append("before_agent") + + @before_model + def log_before_model(state: AgentState, runtime: Runtime) -> None: + execution_log.append("before_model") + + @after_agent + def log_after_agent(state: AgentState, runtime: Runtime) -> None: + execution_log.append("after_agent") + + @after_model + def log_after_model(state: AgentState, runtime: Runtime) -> None: + execution_log.append("after_model") + + # Model will call a tool twice, then respond with final answer + # This creates 3 model invocations total, but agent hooks should still run once + model = FakeToolCallingModel( + tool_calls=[ + [{"name": "sample_tool", "args": {"query": "first"}, "id": "1"}], + [{"name": "sample_tool", "args": {"query": "second"}, "id": "2"}], + [], # Third call returns no tool calls (final answer) + ] + ) + + agent = create_agent( + model=model, + tools=[sample_tool], + middleware=[log_before_agent, log_before_model, log_after_model, log_after_agent], + ) + + agent.invoke( + {"messages": [HumanMessage("Test")]}, config={"configurable": {"thread_id": "abc"}} + ) + + assert execution_log == [ + "before_agent", + "before_model", + "after_model", + "before_model", + "after_model", + "before_model", + "after_model", + "after_agent", + ] + + agent.invoke( + {"messages": [HumanMessage("Test")]}, config={"configurable": {"thread_id": "abc"}} + ) + + assert execution_log == [ + "before_agent", + "before_model", + "after_model", + "before_model", + "after_model", + "before_model", + "after_model", + "after_agent", + "before_agent", + "before_model", + "after_model", + "before_model", + "after_model", + "before_model", + "after_model", + "after_agent", + ]
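

A minimal, self-contained sketch of the new hooks, assembled from the APIs the new tests exercise (`GenericFakeChatModel` stands in for a real chat model; the hook names `greet` and `summarize` are illustrative):

```python
from typing import Any

from langchain.agents import create_agent
from langchain.agents.middleware import AgentState, after_agent, before_agent
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.runtime import Runtime


@before_agent
def greet(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    # Runs once per invocation, before the first before_model/model_request step.
    print(f"starting with {len(state['messages'])} message(s)")
    return None


@after_agent
def summarize(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    # Runs once per invocation, after the model/tool loop exits.
    print(f"finished with {len(state['messages'])} message(s)")
    return None


model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
agent = create_agent(model=model, tools=[], middleware=[greet, summarize])
agent.invoke({"messages": [HumanMessage("Hi")]})
# Unlike before_model/after_model, these hooks do not re-run when tools loop
# back to the model; they bracket the whole agent run exactly once.
```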