From 983d84ade8ca3fdf33025ebfd7b303a658f4151f Mon Sep 17 00:00:00 2001 From: Sydney Runkle Date: Sun, 5 Oct 2025 17:02:28 -0400 Subject: [PATCH 1/8] first pass at before and after agent --- .../langchain/agents/middleware/__init__.py | 4 + .../langchain/agents/middleware/types.py | 294 +++++++++++++++ .../langchain/agents/middleware_agent.py | 141 ++++++- .../middleware/test_before_after_agent.py | 351 ++++++++++++++++++ 4 files changed, 780 insertions(+), 10 deletions(-) create mode 100644 libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index 5915b75164a74..df39772f6ee60 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -12,7 +12,9 @@ AgentMiddleware, AgentState, ModelRequest, + after_agent, after_model, + before_agent, before_model, dynamic_prompt, hook_config, @@ -33,7 +35,9 @@ "PlanningMiddleware", "SummarizationMiddleware", "ToolCallLimitMiddleware", + "after_agent", "after_model", + "before_agent", "before_model", "dynamic_prompt", "hook_config", diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index 6aea90ab9fc35..b10eaab7b3c0f 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -44,8 +44,13 @@ "ModelRequest", "OmitFromSchema", "PublicAgentState", + "after_agent", + "after_model", + "before_agent", + "before_model", "dynamic_prompt", "hook_config", + "modify_model_request", ] JumpTo = Literal["tools", "model", "end"] @@ -133,6 +138,14 @@ def name(self) -> str: """ return self.__class__.__name__ + def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None: + """Logic to run before the agent execution starts.""" + + async def abefore_agent( + self, state: StateT, runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: + """Async logic to run before the agent execution starts.""" + def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None: """Logic to run before the model is called.""" @@ -215,6 +228,14 @@ async def aretry_model_request( None, self.retry_model_request, error, request, state, runtime, attempt ) + def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None: + """Logic to run after the agent execution completes.""" + + async def aafter_agent( + self, state: StateT, runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: + """Async logic to run after the agent execution completes.""" + class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]): """Callable with AgentState and Runtime as arguments.""" @@ -707,6 +728,279 @@ def wrapped( return decorator +@overload +def before_agent( + func: _CallableWithStateAndRuntime[StateT, ContextT], +) -> AgentMiddleware[StateT, ContextT]: ... + + +@overload +def before_agent( + func: None = None, + *, + state_schema: type[StateT] | None = None, + tools: list[BaseTool] | None = None, + can_jump_to: list[JumpTo] | None = None, + name: str | None = None, +) -> Callable[ + [_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT] +]: ... + + +def before_agent( + func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None, + *, + state_schema: type[StateT] | None = None, + tools: list[BaseTool] | None = None, + can_jump_to: list[JumpTo] | None = None, + name: str | None = None, +) -> ( + Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]] + | AgentMiddleware[StateT, ContextT] +): + """Decorator used to dynamically create a middleware with the before_agent hook. + + Args: + func: The function to be decorated. Must accept: + `state: StateT, runtime: Runtime[ContextT]` - State and runtime context + state_schema: Optional custom state schema type. If not provided, uses the default + AgentState schema. + tools: Optional list of additional tools to register with this middleware. + can_jump_to: Optional list of valid jump destinations for conditional edges. + Valid values are: "tools", "model", "end" + name: Optional name for the generated middleware class. If not provided, + uses the decorated function's name. + + Returns: + Either an AgentMiddleware instance (if func is provided directly) or a decorator function + that can be applied to a function its wrapping. + + The decorated function should return: + - `dict[str, Any]` - State updates to merge into the agent state + - `Command` - A command to control flow (e.g., jump to different node) + - `None` - No state updates or flow control + + Examples: + Basic usage: + ```python + @before_agent + def log_before_agent(state: AgentState, runtime: Runtime) -> None: + print(f"Starting agent with {len(state['messages'])} messages") + ``` + + With conditional jumping: + ```python + @before_agent(can_jump_to=["end"]) + def conditional_before_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + if some_condition(state): + return {"jump_to": "end"} + return None + ``` + + With custom state schema: + ```python + @before_agent(state_schema=MyCustomState) + def custom_before_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]: + return {"custom_field": "initialized_value"} + ``` + """ + + def decorator( + func: _CallableWithStateAndRuntime[StateT, ContextT], + ) -> AgentMiddleware[StateT, ContextT]: + is_async = iscoroutinefunction(func) + + func_can_jump_to = ( + can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", []) + ) + + if is_async: + + async def async_wrapped( + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 + state: StateT, + runtime: Runtime[ContextT], + ) -> dict[str, Any] | Command | None: + return await func(state, runtime) # type: ignore[misc] + + # Preserve can_jump_to metadata on the wrapped function + if func_can_jump_to: + async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined] + + middleware_name = name or cast( + "str", getattr(func, "__name__", "BeforeAgentMiddleware") + ) + + return type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "abefore_agent": async_wrapped, + }, + )() + + def wrapped( + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 + state: StateT, + runtime: Runtime[ContextT], + ) -> dict[str, Any] | Command | None: + return func(state, runtime) # type: ignore[return-value] + + # Preserve can_jump_to metadata on the wrapped function + if func_can_jump_to: + wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined] + + # Use function name as default if no name provided + middleware_name = name or cast("str", getattr(func, "__name__", "BeforeAgentMiddleware")) + + return type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "before_agent": wrapped, + }, + )() + + if func is not None: + return decorator(func) + return decorator + + +@overload +def after_agent( + func: _CallableWithStateAndRuntime[StateT, ContextT], +) -> AgentMiddleware[StateT, ContextT]: ... + + +@overload +def after_agent( + func: None = None, + *, + state_schema: type[StateT] | None = None, + tools: list[BaseTool] | None = None, + can_jump_to: list[JumpTo] | None = None, + name: str | None = None, +) -> Callable[ + [_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT] +]: ... + + +def after_agent( + func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None, + *, + state_schema: type[StateT] | None = None, + tools: list[BaseTool] | None = None, + can_jump_to: list[JumpTo] | None = None, + name: str | None = None, +) -> ( + Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]] + | AgentMiddleware[StateT, ContextT] +): + """Decorator used to dynamically create a middleware with the after_agent hook. + + Args: + func: The function to be decorated. Must accept: + `state: StateT, runtime: Runtime[ContextT]` - State and runtime context + state_schema: Optional custom state schema type. If not provided, uses the default + AgentState schema. + tools: Optional list of additional tools to register with this middleware. + can_jump_to: Optional list of valid jump destinations for conditional edges. + Valid values are: "tools", "model", "end" + name: Optional name for the generated middleware class. If not provided, + uses the decorated function's name. + + Returns: + Either an AgentMiddleware instance (if func is provided) or a decorator function + that can be applied to a function. + + The decorated function should return: + - `dict[str, Any]` - State updates to merge into the agent state + - `Command` - A command to control flow (e.g., jump to different node) + - `None` - No state updates or flow control + + Examples: + Basic usage for logging agent completion: + ```python + @after_agent + def log_completion(state: AgentState, runtime: Runtime) -> None: + print(f"Agent completed with {len(state['messages'])} messages") + ``` + + With custom state schema: + ```python + @after_agent(state_schema=MyCustomState, name="MyAfterAgentMiddleware") + def custom_after_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]: + return {"custom_field": "finalized_value"} + ``` + """ + + def decorator( + func: _CallableWithStateAndRuntime[StateT, ContextT], + ) -> AgentMiddleware[StateT, ContextT]: + is_async = iscoroutinefunction(func) + # Extract can_jump_to from decorator parameter or from function metadata + func_can_jump_to = ( + can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", []) + ) + + if is_async: + + async def async_wrapped( + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 + state: StateT, + runtime: Runtime[ContextT], + ) -> dict[str, Any] | Command | None: + return await func(state, runtime) # type: ignore[misc] + + # Preserve can_jump_to metadata on the wrapped function + if func_can_jump_to: + async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined] + + middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware")) + + return type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "aafter_agent": async_wrapped, + }, + )() + + def wrapped( + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 + state: StateT, + runtime: Runtime[ContextT], + ) -> dict[str, Any] | Command | None: + return func(state, runtime) # type: ignore[return-value] + + # Preserve can_jump_to metadata on the wrapped function + if func_can_jump_to: + wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined] + + # Use function name as default if no name provided + middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware")) + + return type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "after_agent": wrapped, + }, + )() + + if func is not None: + return decorator(func) + return decorator + + @overload def dynamic_prompt( func: _CallableReturningPromptString[StateT, ContextT], diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index b4f6370446b05..7211445888bbd 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -260,6 +260,12 @@ def create_agent( # noqa: PLR0915 assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101 "Please remove duplicate middleware instances." ) + middleware_w_before_agent = [ + m + for m in middleware + if m.__class__.before_agent is not AgentMiddleware.before_agent + or m.__class__.abefore_agent is not AgentMiddleware.abefore_agent + ] middleware_w_before = [ m for m in middleware @@ -278,6 +284,12 @@ def create_agent( # noqa: PLR0915 if m.__class__.after_model is not AgentMiddleware.after_model or m.__class__.aafter_model is not AgentMiddleware.aafter_model ] + middleware_w_after_agent = [ + m + for m in middleware + if m.__class__.after_agent is not AgentMiddleware.after_agent + or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent + ] middleware_w_retry = [ m for m in middleware @@ -634,6 +646,25 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ # Add middleware nodes for m in middleware: + if ( + m.__class__.before_agent is not AgentMiddleware.before_agent + or m.__class__.abefore_agent is not AgentMiddleware.abefore_agent + ): + # Use RunnableCallable to support both sync and async + # Pass None for sync if not overridden to avoid signature conflicts + sync_before_agent = ( + m.before_agent + if m.__class__.before_agent is not AgentMiddleware.before_agent + else None + ) + async_before_agent = ( + m.abefore_agent + if m.__class__.abefore_agent is not AgentMiddleware.abefore_agent + else None + ) + before_agent_node = RunnableCallable(sync_before_agent, async_before_agent) + graph.add_node(f"{m.name}.before_agent", before_agent_node, input_schema=state_schema) + if ( m.__class__.before_model is not AgentMiddleware.before_model or m.__class__.abefore_model is not AgentMiddleware.abefore_model @@ -672,13 +703,42 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ after_node = RunnableCallable(sync_after, async_after) graph.add_node(f"{m.name}.after_model", after_node, input_schema=state_schema) + if ( + m.__class__.after_agent is not AgentMiddleware.after_agent + or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent + ): + # Use RunnableCallable to support both sync and async + # Pass None for sync if not overridden to avoid signature conflicts + sync_after_agent = ( + m.after_agent + if m.__class__.after_agent is not AgentMiddleware.after_agent + else None + ) + async_after_agent = ( + m.aafter_agent + if m.__class__.aafter_agent is not AgentMiddleware.aafter_agent + else None + ) + after_agent_node = RunnableCallable(sync_after_agent, async_after_agent) + graph.add_node(f"{m.name}.after_agent", after_agent_node, input_schema=state_schema) + # add start edge - first_node = ( - f"{middleware_w_before[0].name}.before_model" if middleware_w_before else "model_request" - ) - last_node = ( - f"{middleware_w_after[0].name}.after_model" if middleware_w_after else "model_request" - ) + # Determine the first node in the chain: before_agent -> before_model -> model_request + if middleware_w_before_agent: + first_node = f"{middleware_w_before_agent[0].name}.before_agent" + elif middleware_w_before: + first_node = f"{middleware_w_before[0].name}.before_model" + else: + first_node = "model_request" + + # Determine the last node in the chain: model_request -> after_model -> after_agent + if middleware_w_after_agent: + last_node = f"{middleware_w_after_agent[0].name}.after_agent" + elif middleware_w_after: + last_node = f"{middleware_w_after[0].name}.after_model" + else: + last_node = "model_request" + graph.add_edge(START, first_node) # add conditional edges only if tools exist @@ -697,16 +757,49 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ # If no tools, just go to END from model graph.add_edge(last_node, END) else: - # If after_model, then need to check for can_jump_to + # Determine the final node to connect to END + if middleware_w_after_agent: + final_node_name = f"{middleware_w_after_agent[0].name}.after_agent" + final_hook_name = "after_agent" + final_middleware = middleware_w_after_agent[0] + else: + final_node_name = f"{middleware_w_after[0].name}.after_model" + final_hook_name = "after_model" + final_middleware = middleware_w_after[0] + _add_middleware_edge( graph, - f"{middleware_w_after[0].name}.after_model", + final_node_name, END, first_node, - can_jump_to=_get_can_jump_to(middleware_w_after[0], "after_model"), + can_jump_to=_get_can_jump_to(final_middleware, final_hook_name), + ) + + # Add before_agent middleware edges + if middleware_w_before_agent: + for m1, m2 in itertools.pairwise(middleware_w_before_agent): + _add_middleware_edge( + graph, + f"{m1.name}.before_agent", + f"{m2.name}.before_agent", + first_node, + can_jump_to=_get_can_jump_to(m1, "before_agent"), + ) + # Connect last before_agent to either first before_model or model_request + next_node = ( + f"{middleware_w_before[0].name}.before_model" + if middleware_w_before + else "model_request" + ) + _add_middleware_edge( + graph, + f"{middleware_w_before_agent[-1].name}.before_agent", + next_node, + first_node, + can_jump_to=_get_can_jump_to(middleware_w_before_agent[-1], "before_agent"), ) - # Add middleware edges (same as before) + # Add before_model middleware edges if middleware_w_before: for m1, m2 in itertools.pairwise(middleware_w_before): _add_middleware_edge( @@ -725,6 +818,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ can_jump_to=_get_can_jump_to(middleware_w_before[-1], "before_model"), ) + # Add after_model middleware edges if middleware_w_after: graph.add_edge("model_request", f"{middleware_w_after[-1].name}.after_model") for idx in range(len(middleware_w_after) - 1, 0, -1): @@ -737,6 +831,33 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ first_node, can_jump_to=_get_can_jump_to(m1, "after_model"), ) + # Connect first (last in reverse) after_model to either first after_agent or END + if middleware_w_after_agent: + _add_middleware_edge( + graph, + f"{middleware_w_after[0].name}.after_model", + f"{middleware_w_after_agent[-1].name}.after_agent", + first_node, + can_jump_to=_get_can_jump_to(middleware_w_after[0], "after_model"), + ) + + # Add after_agent middleware edges + if middleware_w_after_agent: + # If we have after_model, it already connects to the first after_agent + # Otherwise, connect model_request to the first after_agent + if not middleware_w_after: + graph.add_edge("model_request", f"{middleware_w_after_agent[-1].name}.after_agent") + + for idx in range(len(middleware_w_after_agent) - 1, 0, -1): + m1 = middleware_w_after_agent[idx] + m2 = middleware_w_after_agent[idx - 1] + _add_middleware_edge( + graph, + f"{m1.name}.after_agent", + f"{m2.name}.after_agent", + first_node, + can_jump_to=_get_can_jump_to(m1, "after_agent"), + ) return graph diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py new file mode 100644 index 0000000000000..f27dede81f144 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py @@ -0,0 +1,351 @@ +"""Unit tests for before_agent and after_agent middleware hooks.""" + +from typing import Any + +from langchain.agents import create_agent +from langchain.agents.middleware import ( + AgentMiddleware, + AgentState, + after_agent, + before_agent, +) +from langchain_core.language_models.fake_chat_models import GenericFakeChatModel +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.tools import tool +from langgraph.runtime import Runtime + + +@tool +def sample_tool(query: str) -> str: + """A sample tool for testing.""" + return f"Result for: {query}" + + +class TestBeforeAgentBasic: + """Test basic before_agent functionality.""" + + def test_sync_before_agent_execution(self) -> None: + """Test that before_agent hook is called synchronously.""" + execution_log = [] + + @before_agent + def log_before_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append("before_agent_called") + execution_log.append(f"message_count: {len(state['messages'])}") + return None + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")])) + + agent = create_agent(model=model, tools=[], middleware=[log_before_agent]) + + agent.invoke({"messages": [HumanMessage("Hi")]}) + + assert "before_agent_called" in execution_log + assert "message_count: 1" in execution_log + + async def test_async_before_agent_execution(self) -> None: + """Test that before_agent hook is called asynchronously.""" + execution_log = [] + + @before_agent + async def async_log_before_agent( + state: AgentState, runtime: Runtime + ) -> dict[str, Any] | None: + execution_log.append("async_before_agent_called") + execution_log.append(f"message_count: {len(state['messages'])}") + return None + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")])) + + agent = create_agent(model=model, tools=[], middleware=[async_log_before_agent]) + + await agent.ainvoke({"messages": [HumanMessage("Hi")]}) + + assert "async_before_agent_called" in execution_log + assert "message_count: 1" in execution_log + + def test_before_agent_state_modification(self) -> None: + """Test that before_agent can modify state.""" + + @before_agent + def add_metadata(state: AgentState, runtime: Runtime) -> dict[str, Any]: + return {"messages": [HumanMessage("Injected by middleware")]} + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent(model=model, tools=[], middleware=[add_metadata]) + + result = agent.invoke({"messages": [HumanMessage("Original")]}) + + # Should have original + injected + AI response + assert len(result["messages"]) >= 2 + message_contents = [msg.content for msg in result["messages"]] + assert "Injected by middleware" in message_contents + + def test_before_agent_with_class_inheritance(self) -> None: + """Test before_agent using class inheritance.""" + execution_log = [] + + class CustomBeforeAgentMiddleware(AgentMiddleware): + def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append("class_before_agent_called") + return None + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent(model=model, tools=[], middleware=[CustomBeforeAgentMiddleware()]) + + agent.invoke({"messages": [HumanMessage("Test")]}) + + assert "class_before_agent_called" in execution_log + + async def test_before_agent_with_async_class_inheritance(self) -> None: + """Test async before_agent using class inheritance.""" + execution_log = [] + + class CustomAsyncBeforeAgentMiddleware(AgentMiddleware): + async def abefore_agent( + self, state: AgentState, runtime: Runtime + ) -> dict[str, Any] | None: + execution_log.append("async_class_before_agent_called") + return None + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent(model=model, tools=[], middleware=[CustomAsyncBeforeAgentMiddleware()]) + + await agent.ainvoke({"messages": [HumanMessage("Test")]}) + + assert "async_class_before_agent_called" in execution_log + + +class TestAfterAgentBasic: + """Test basic after_agent functionality.""" + + def test_sync_after_agent_execution(self) -> None: + """Test that after_agent hook is called synchronously.""" + execution_log = [] + + @after_agent + def log_after_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append("after_agent_called") + execution_log.append(f"final_message_count: {len(state['messages'])}") + return None + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Final response")])) + + agent = create_agent(model=model, tools=[], middleware=[log_after_agent]) + + agent.invoke({"messages": [HumanMessage("Hi")]}) + + assert "after_agent_called" in execution_log + assert any("final_message_count:" in log for log in execution_log) + + async def test_async_after_agent_execution(self) -> None: + """Test that after_agent hook is called asynchronously.""" + execution_log = [] + + @after_agent + async def async_log_after_agent( + state: AgentState, runtime: Runtime + ) -> dict[str, Any] | None: + execution_log.append("async_after_agent_called") + return None + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent(model=model, tools=[], middleware=[async_log_after_agent]) + + await agent.ainvoke({"messages": [HumanMessage("Hi")]}) + + assert "async_after_agent_called" in execution_log + + def test_after_agent_state_modification(self) -> None: + """Test that after_agent can modify state.""" + + @after_agent + def add_final_message(state: AgentState, runtime: Runtime) -> dict[str, Any]: + return {"messages": [AIMessage("Added by after_agent")]} + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Model response")])) + + agent = create_agent(model=model, tools=[], middleware=[add_final_message]) + + result = agent.invoke({"messages": [HumanMessage("Test")]}) + + message_contents = [msg.content for msg in result["messages"]] + assert "Added by after_agent" in message_contents + + def test_after_agent_with_class_inheritance(self) -> None: + """Test after_agent using class inheritance.""" + execution_log = [] + + class CustomAfterAgentMiddleware(AgentMiddleware): + def after_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append("class_after_agent_called") + return None + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent(model=model, tools=[], middleware=[CustomAfterAgentMiddleware()]) + + agent.invoke({"messages": [HumanMessage("Test")]}) + + assert "class_after_agent_called" in execution_log + + async def test_after_agent_with_async_class_inheritance(self) -> None: + """Test async after_agent using class inheritance.""" + execution_log = [] + + class CustomAsyncAfterAgentMiddleware(AgentMiddleware): + async def aafter_agent( + self, state: AgentState, runtime: Runtime + ) -> dict[str, Any] | None: + execution_log.append("async_class_after_agent_called") + return None + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent(model=model, tools=[], middleware=[CustomAsyncAfterAgentMiddleware()]) + + await agent.ainvoke({"messages": [HumanMessage("Test")]}) + + assert "async_class_after_agent_called" in execution_log + + +class TestBeforeAndAfterAgentCombined: + """Test before_agent and after_agent hooks working together.""" + + def test_execution_order(self) -> None: + """Test that before_agent executes before after_agent.""" + execution_log = [] + + @before_agent + def log_before(state: AgentState, runtime: Runtime) -> None: + execution_log.append("before") + + @after_agent + def log_after(state: AgentState, runtime: Runtime) -> None: + execution_log.append("after") + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent(model=model, tools=[], middleware=[log_before, log_after]) + + agent.invoke({"messages": [HumanMessage("Test")]}) + + assert execution_log == ["before", "after"] + + async def test_async_execution_order(self) -> None: + """Test async execution order of before_agent and after_agent.""" + execution_log = [] + + @before_agent + async def async_log_before(state: AgentState, runtime: Runtime) -> None: + execution_log.append("async_before") + + @after_agent + async def async_log_after(state: AgentState, runtime: Runtime) -> None: + execution_log.append("async_after") + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent(model=model, tools=[], middleware=[async_log_before, async_log_after]) + + await agent.ainvoke({"messages": [HumanMessage("Test")]}) + + assert execution_log == ["async_before", "async_after"] + + def test_state_passthrough(self) -> None: + """Test that state modifications in before_agent are visible to after_agent.""" + collected_states = {} + + @before_agent + def modify_in_before(state: AgentState, runtime: Runtime) -> dict[str, Any]: + return {"messages": [HumanMessage("Modified by before_agent")]} + + @after_agent + def capture_in_after(state: AgentState, runtime: Runtime) -> None: + collected_states["messages"] = state["messages"] + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent(model=model, tools=[], middleware=[modify_in_before, capture_in_after]) + + agent.invoke({"messages": [HumanMessage("Original")]}) + + message_contents = [msg.content for msg in collected_states["messages"]] + assert "Modified by before_agent" in message_contents + + def test_multiple_middleware_instances(self) -> None: + """Test multiple before_agent and after_agent middleware instances.""" + execution_log = [] + + @before_agent + def before_one(state: AgentState, runtime: Runtime) -> None: + execution_log.append("before_1") + + @before_agent + def before_two(state: AgentState, runtime: Runtime) -> None: + execution_log.append("before_2") + + @after_agent + def after_one(state: AgentState, runtime: Runtime) -> None: + execution_log.append("after_1") + + @after_agent + def after_two(state: AgentState, runtime: Runtime) -> None: + execution_log.append("after_2") + + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + + agent = create_agent( + model=model, tools=[], middleware=[before_one, before_two, after_one, after_two] + ) + + agent.invoke({"messages": [HumanMessage("Test")]}) + + assert "before_1" in execution_log + assert "before_2" in execution_log + assert "after_1" in execution_log + assert "after_2" in execution_log + + +class TestDecoratorParameters: + """Test decorator parameters for before_agent and after_agent.""" + + def test_before_agent_with_custom_name(self) -> None: + """Test before_agent with custom middleware name.""" + + @before_agent(name="CustomBeforeAgentMiddleware") + def custom_named_before(state: AgentState, runtime: Runtime) -> None: + pass + + assert custom_named_before.name == "CustomBeforeAgentMiddleware" + + def test_after_agent_with_custom_name(self) -> None: + """Test after_agent with custom middleware name.""" + + @after_agent(name="CustomAfterAgentMiddleware") + def custom_named_after(state: AgentState, runtime: Runtime) -> None: + pass + + assert custom_named_after.name == "CustomAfterAgentMiddleware" + + def test_before_agent_default_name(self) -> None: + """Test that before_agent uses function name by default.""" + + @before_agent + def my_before_agent_function(state: AgentState, runtime: Runtime) -> None: + pass + + assert my_before_agent_function.name == "my_before_agent_function" + + def test_after_agent_default_name(self) -> None: + """Test that after_agent uses function name by default.""" + + @after_agent + def my_after_agent_function(state: AgentState, runtime: Runtime) -> None: + pass + + assert my_after_agent_function.name == "my_after_agent_function" From 6e07fc79822e5d21b4938051f02d893e389cbf74 Mon Sep 17 00:00:00 2001 From: Sydney Runkle Date: Sun, 5 Oct 2025 17:38:14 -0400 Subject: [PATCH 2/8] renaming --- .../langchain/agents/middleware_agent.py | 151 +++++++++--------- .../unit_tests/agents/middleware/__init__.py | 0 .../middleware/test_before_after_agent.py | 37 +++++ 3 files changed, 113 insertions(+), 75 deletions(-) create mode 100644 libs/langchain_v1/tests/unit_tests/agents/middleware/__init__.py diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index 7211445888bbd..f824fd9ae9c73 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -266,7 +266,7 @@ def create_agent( # noqa: PLR0915 if m.__class__.before_agent is not AgentMiddleware.before_agent or m.__class__.abefore_agent is not AgentMiddleware.abefore_agent ] - middleware_w_before = [ + middleware_w_before_model = [ m for m in middleware if m.__class__.before_model is not AgentMiddleware.before_model @@ -278,7 +278,7 @@ def create_agent( # noqa: PLR0915 if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request or m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request ] - middleware_w_after = [ + middleware_w_after_model = [ m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model @@ -722,57 +722,62 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ after_agent_node = RunnableCallable(sync_after_agent, async_after_agent) graph.add_node(f"{m.name}.after_agent", after_agent_node, input_schema=state_schema) - # add start edge - # Determine the first node in the chain: before_agent -> before_model -> model_request + # Determine the entry node (runs once at start): before_agent -> before_model -> model_request if middleware_w_before_agent: - first_node = f"{middleware_w_before_agent[0].name}.before_agent" - elif middleware_w_before: - first_node = f"{middleware_w_before[0].name}.before_model" + entry_node = f"{middleware_w_before_agent[0].name}.before_agent" + elif middleware_w_before_model: + entry_node = f"{middleware_w_before_model[0].name}.before_model" else: - first_node = "model_request" + entry_node = "model_request" - # Determine the last node in the chain: model_request -> after_model -> after_agent + # Determine the loop entry node (beginning of agent loop, excludes before_agent) + # This is where tools will loop back to for the next iteration + if middleware_w_before_model: + loop_entry_node = f"{middleware_w_before_model[0].name}.before_model" + else: + loop_entry_node = "model_request" + + # Determine the loop exit node (end of each iteration, can run multiple times) + # This is after_model or model_request, but NOT after_agent + if middleware_w_after_model: + loop_exit_node = f"{middleware_w_after_model[0].name}.after_model" + else: + loop_exit_node = "model_request" + + # Determine the exit node (runs once at end): after_agent or END if middleware_w_after_agent: - last_node = f"{middleware_w_after_agent[0].name}.after_agent" - elif middleware_w_after: - last_node = f"{middleware_w_after[0].name}.after_model" + exit_node = f"{middleware_w_after_agent[-1].name}.after_agent" else: - last_node = "model_request" + exit_node = END - graph.add_edge(START, first_node) + graph.add_edge(START, entry_node) # add conditional edges only if tools exist if tool_node is not None: graph.add_conditional_edges( "tools", - _make_tools_to_model_edge(tool_node, first_node, structured_output_tools), - [first_node, END], + _make_tools_to_model_edge(tool_node, loop_entry_node, structured_output_tools), + [loop_entry_node, END], ) + graph.add_conditional_edges( - last_node, - _make_model_to_tools_edge(first_node, structured_output_tools, tool_node), - [first_node, "tools", END], + loop_exit_node, + _make_model_to_tools_edge( + loop_entry_node, structured_output_tools, tool_node, exit_node + ), + [loop_entry_node, "tools", exit_node], ) - elif last_node == "model_request": - # If no tools, just go to END from model - graph.add_edge(last_node, END) + elif loop_exit_node == "model_request": + # If no tools and no after_model, go directly to exit_node + graph.add_edge(loop_exit_node, exit_node) + # No tools but we have after_model - connect after_model to exit_node else: - # Determine the final node to connect to END - if middleware_w_after_agent: - final_node_name = f"{middleware_w_after_agent[0].name}.after_agent" - final_hook_name = "after_agent" - final_middleware = middleware_w_after_agent[0] - else: - final_node_name = f"{middleware_w_after[0].name}.after_model" - final_hook_name = "after_model" - final_middleware = middleware_w_after[0] - _add_middleware_edge( graph, - final_node_name, - END, - first_node, - can_jump_to=_get_can_jump_to(final_middleware, final_hook_name), + f"{middleware_w_after_model[0].name}.after_model", + exit_node, + loop_entry_node, + can_jump_to=_get_can_jump_to(middleware_w_after_model[0], "after_model"), ) # Add before_agent middleware edges @@ -782,72 +787,56 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ graph, f"{m1.name}.before_agent", f"{m2.name}.before_agent", - first_node, + loop_entry_node, can_jump_to=_get_can_jump_to(m1, "before_agent"), ) - # Connect last before_agent to either first before_model or model_request - next_node = ( - f"{middleware_w_before[0].name}.before_model" - if middleware_w_before - else "model_request" - ) + # Connect last before_agent to loop_entry_node (before_model or model_request) _add_middleware_edge( graph, f"{middleware_w_before_agent[-1].name}.before_agent", - next_node, - first_node, + loop_entry_node, + loop_entry_node, can_jump_to=_get_can_jump_to(middleware_w_before_agent[-1], "before_agent"), ) # Add before_model middleware edges - if middleware_w_before: - for m1, m2 in itertools.pairwise(middleware_w_before): + if middleware_w_before_model: + for m1, m2 in itertools.pairwise(middleware_w_before_model): _add_middleware_edge( graph, f"{m1.name}.before_model", f"{m2.name}.before_model", - first_node, + loop_entry_node, can_jump_to=_get_can_jump_to(m1, "before_model"), ) # Go directly to model_request after the last before_model _add_middleware_edge( graph, - f"{middleware_w_before[-1].name}.before_model", + f"{middleware_w_before_model[-1].name}.before_model", "model_request", - first_node, - can_jump_to=_get_can_jump_to(middleware_w_before[-1], "before_model"), + loop_entry_node, + can_jump_to=_get_can_jump_to(middleware_w_before_model[-1], "before_model"), ) # Add after_model middleware edges - if middleware_w_after: - graph.add_edge("model_request", f"{middleware_w_after[-1].name}.after_model") - for idx in range(len(middleware_w_after) - 1, 0, -1): - m1 = middleware_w_after[idx] - m2 = middleware_w_after[idx - 1] + if middleware_w_after_model: + graph.add_edge("model_request", f"{middleware_w_after_model[-1].name}.after_model") + for idx in range(len(middleware_w_after_model) - 1, 0, -1): + m1 = middleware_w_after_model[idx] + m2 = middleware_w_after_model[idx - 1] _add_middleware_edge( graph, f"{m1.name}.after_model", f"{m2.name}.after_model", - first_node, + loop_entry_node, can_jump_to=_get_can_jump_to(m1, "after_model"), ) - # Connect first (last in reverse) after_model to either first after_agent or END - if middleware_w_after_agent: - _add_middleware_edge( - graph, - f"{middleware_w_after[0].name}.after_model", - f"{middleware_w_after_agent[-1].name}.after_agent", - first_node, - can_jump_to=_get_can_jump_to(middleware_w_after[0], "after_model"), - ) + # Note: Connection from after_model to after_agent/END is handled above + # in the conditional edges section # Add after_agent middleware edges if middleware_w_after_agent: - # If we have after_model, it already connects to the first after_agent - # Otherwise, connect model_request to the first after_agent - if not middleware_w_after: - graph.add_edge("model_request", f"{middleware_w_after_agent[-1].name}.after_agent") - + # Chain after_agent middleware (runs once at the very end, before END) for idx in range(len(middleware_w_after_agent) - 1, 0, -1): m1 = middleware_w_after_agent[idx] m2 = middleware_w_after_agent[idx - 1] @@ -855,10 +844,19 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ graph, f"{m1.name}.after_agent", f"{m2.name}.after_agent", - first_node, + loop_entry_node, can_jump_to=_get_can_jump_to(m1, "after_agent"), ) + # Connect the last after_agent to END + _add_middleware_edge( + graph, + f"{middleware_w_after_agent[0].name}.after_agent", + END, + loop_entry_node, + can_jump_to=_get_can_jump_to(middleware_w_after_agent[0], "after_agent"), + ) + return graph @@ -889,7 +887,10 @@ def _fetch_last_ai_and_tool_messages( def _make_model_to_tools_edge( - first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode + first_node: str, + structured_output_tools: dict[str, OutputToolBinding], + tool_node: ToolNode, + exit_destination: str = END, ) -> Callable[[dict[str, Any]], str | list[Send] | None]: def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None: # 1. if there's an explicit jump_to in the state, use it @@ -899,10 +900,10 @@ def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None: last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"]) tool_message_ids = [m.tool_call_id for m in tool_messages] - # 2. if the model hasn't called any tools, jump to END + # 2. if the model hasn't called any tools, exit the loop # this is the classic exit condition for an agent loop if len(last_ai_message.tool_calls) == 0: - return END + return exit_destination pending_tool_calls = [ c diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/__init__.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py index f27dede81f144..d13fec4b2272a 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py @@ -14,6 +14,8 @@ from langchain_core.tools import tool from langgraph.runtime import Runtime +from ..model import FakeToolCallingModel + @tool def sample_tool(query: str) -> str: @@ -310,6 +312,41 @@ def after_two(state: AgentState, runtime: Runtime) -> None: assert "after_1" in execution_log assert "after_2" in execution_log + def test_agent_hooks_run_once_with_multiple_model_calls(self) -> None: + """Test that before_agent and after_agent run only once even with tool calls.""" + execution_log = [] + + @before_agent + def log_before_agent(state: AgentState, runtime: Runtime) -> None: + execution_log.append("before_agent") + + @after_agent + def log_after_agent(state: AgentState, runtime: Runtime) -> None: + execution_log.append("after_agent") + + # Model will call a tool once, then respond with final answer + model = FakeToolCallingModel( + tool_calls=[ + [{"name": "sample_tool", "args": {"query": "test"}, "id": "1"}], + [], # Second call returns no tool calls (final answer) + ] + ) + + agent = create_agent( + model=model, + tools=[sample_tool], + middleware=[log_before_agent, log_after_agent], + ) + + agent.invoke({"messages": [HumanMessage("Test")]}) + + # before_agent and after_agent should run exactly once + assert execution_log.count("before_agent") == 1 + assert execution_log.count("after_agent") == 1 + # before_agent should run first, after_agent should run last + assert execution_log[0] == "before_agent" + assert execution_log[-1] == "after_agent" + class TestDecoratorParameters: """Test decorator parameters for before_agent and after_agent.""" From 22fb405b45993dc11e5908ba5065bd5c1b45b8e9 Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Sun, 5 Oct 2025 21:16:03 -0400 Subject: [PATCH 3/8] feat(langchain_v1): improving tracing (#33283) Forgo tracing `RunnableCallable` outer calls + ensure we're not showing structured response on input fields --- .../langchain/agents/middleware/types.py | 2 +- .../langchain/agents/middleware_agent.py | 33 ++++++++++--------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index b10eaab7b3c0f..327bab5e5bcc7 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -98,7 +98,7 @@ class AgentState(TypedDict, Generic[ResponseT]): messages: Required[Annotated[list[AnyMessage], add_messages]] jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]] - structured_response: NotRequired[ResponseT] + structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]] thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]] run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]] diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index f824fd9ae9c73..6ff5d306b22c3 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -8,6 +8,7 @@ from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool +from langgraph._internal._runnable import RunnableCallable from langgraph.constants import END, START from langgraph.graph.state import StateGraph from langgraph.runtime import Runtime @@ -636,9 +637,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ raise RuntimeError(msg) # Use sync or async based on model capabilities - from langgraph._internal._runnable import RunnableCallable - - graph.add_node("model_request", RunnableCallable(model_request, amodel_request)) + graph.add_node("model_request", RunnableCallable(model_request, amodel_request, trace=False)) # Only add tools node if we have tools if tool_node is not None: @@ -662,7 +661,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ if m.__class__.abefore_agent is not AgentMiddleware.abefore_agent else None ) - before_agent_node = RunnableCallable(sync_before_agent, async_before_agent) + before_agent_node = RunnableCallable(sync_before_agent, async_before_agent, trace=False) graph.add_node(f"{m.name}.before_agent", before_agent_node, input_schema=state_schema) if ( @@ -681,7 +680,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ if m.__class__.abefore_model is not AgentMiddleware.abefore_model else None ) - before_node = RunnableCallable(sync_before, async_before) + before_node = RunnableCallable(sync_before, async_before, trace=False) graph.add_node(f"{m.name}.before_model", before_node, input_schema=state_schema) if ( @@ -700,7 +699,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ if m.__class__.aafter_model is not AgentMiddleware.aafter_model else None ) - after_node = RunnableCallable(sync_after, async_after) + after_node = RunnableCallable(sync_after, async_after, trace=False) graph.add_node(f"{m.name}.after_model", after_node, input_schema=state_schema) if ( @@ -719,7 +718,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ if m.__class__.aafter_agent is not AgentMiddleware.aafter_agent else None ) - after_agent_node = RunnableCallable(sync_after_agent, async_after_agent) + after_agent_node = RunnableCallable(sync_after_agent, async_after_agent, trace=False) graph.add_node(f"{m.name}.after_agent", after_agent_node, input_schema=state_schema) # Determine the entry node (runs once at start): before_agent -> before_model -> model_request @@ -751,13 +750,14 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ exit_node = END graph.add_edge(START, entry_node) - # add conditional edges only if tools exist if tool_node is not None: graph.add_conditional_edges( "tools", - _make_tools_to_model_edge(tool_node, loop_entry_node, structured_output_tools), - [loop_entry_node, END], + _make_tools_to_model_edge( + tool_node, loop_entry_node, structured_output_tools, exit_node + ), + [loop_entry_node, exit_node], ) graph.add_conditional_edges( @@ -890,7 +890,7 @@ def _make_model_to_tools_edge( first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode, - exit_destination: str = END, + exit_node: str, ) -> Callable[[dict[str, Any]], str | list[Send] | None]: def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None: # 1. if there's an explicit jump_to in the state, use it @@ -903,7 +903,7 @@ def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None: # 2. if the model hasn't called any tools, exit the loop # this is the classic exit condition for an agent loop if len(last_ai_message.tool_calls) == 0: - return exit_destination + return exit_node pending_tool_calls = [ c @@ -926,7 +926,10 @@ def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None: def _make_tools_to_model_edge( - tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding] + tool_node: ToolNode, + next_node: str, + structured_output_tools: dict[str, OutputToolBinding], + exit_node: str, ) -> Callable[[dict[str, Any]], str | None]: def tools_to_model(state: dict[str, Any]) -> str | None: last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"]) @@ -936,10 +939,10 @@ def tools_to_model(state: dict[str, Any]) -> str | None: for c in last_ai_message.tool_calls if c["name"] in tool_node.tools_by_name ): - return END + return exit_node if any(t.name in structured_output_tools for t in tool_messages): - return END + return exit_node return next_node From c8156b79772c81bb1f27e300ac9ea0c7c6531ca4 Mon Sep 17 00:00:00 2001 From: Sydney Runkle Date: Mon, 6 Oct 2025 09:24:40 -0400 Subject: [PATCH 4/8] consolidating tests --- .../middleware/test_before_after_agent.py | 296 +++++++++--------- 1 file changed, 146 insertions(+), 150 deletions(-) diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py index d13fec4b2272a..2cbad45de2320 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py @@ -2,6 +2,7 @@ from typing import Any +import pytest from langchain.agents import create_agent from langchain.agents.middleware import ( AgentMiddleware, @@ -26,44 +27,41 @@ def sample_tool(query: str) -> str: class TestBeforeAgentBasic: """Test basic before_agent functionality.""" - def test_sync_before_agent_execution(self) -> None: - """Test that before_agent hook is called synchronously.""" - execution_log = [] - - @before_agent - def log_before_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: - execution_log.append("before_agent_called") - execution_log.append(f"message_count: {len(state['messages'])}") - return None - - model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")])) + @pytest.mark.parametrize("is_async", [False, True]) + async def test_before_agent_execution(self, is_async: bool) -> None: + """Test that before_agent hook is called in both sync and async modes.""" + execution_log: list[str] = [] - agent = create_agent(model=model, tools=[], middleware=[log_before_agent]) + if is_async: - agent.invoke({"messages": [HumanMessage("Hi")]}) + @before_agent + async def log_before_agent( + state: AgentState, runtime: Runtime + ) -> dict[str, Any] | None: + execution_log.append("before_agent_called") + execution_log.append(f"message_count: {len(state['messages'])}") + return None - assert "before_agent_called" in execution_log - assert "message_count: 1" in execution_log + middleware = log_before_agent + else: - async def test_async_before_agent_execution(self) -> None: - """Test that before_agent hook is called asynchronously.""" - execution_log = [] + @before_agent + def log_before_agent_sync(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append("before_agent_called") + execution_log.append(f"message_count: {len(state['messages'])}") + return None - @before_agent - async def async_log_before_agent( - state: AgentState, runtime: Runtime - ) -> dict[str, Any] | None: - execution_log.append("async_before_agent_called") - execution_log.append(f"message_count: {len(state['messages'])}") - return None + middleware = log_before_agent_sync model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")])) + agent = create_agent(model=model, tools=[], middleware=[middleware]) - agent = create_agent(model=model, tools=[], middleware=[async_log_before_agent]) - - await agent.ainvoke({"messages": [HumanMessage("Hi")]}) + if is_async: + await agent.ainvoke({"messages": [HumanMessage("Hi")]}) + else: + agent.invoke({"messages": [HumanMessage("Hi")]}) - assert "async_before_agent_called" in execution_log + assert "before_agent_called" in execution_log assert "message_count: 1" in execution_log def test_before_agent_state_modification(self) -> None: @@ -74,9 +72,7 @@ def add_metadata(state: AgentState, runtime: Runtime) -> dict[str, Any]: return {"messages": [HumanMessage("Injected by middleware")]} model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) - agent = create_agent(model=model, tools=[], middleware=[add_metadata]) - result = agent.invoke({"messages": [HumanMessage("Original")]}) # Should have original + injected + AI response @@ -84,83 +80,80 @@ def add_metadata(state: AgentState, runtime: Runtime) -> dict[str, Any]: message_contents = [msg.content for msg in result["messages"]] assert "Injected by middleware" in message_contents - def test_before_agent_with_class_inheritance(self) -> None: - """Test before_agent using class inheritance.""" - execution_log = [] - - class CustomBeforeAgentMiddleware(AgentMiddleware): - def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: - execution_log.append("class_before_agent_called") - return None - - model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + @pytest.mark.parametrize("is_async", [False, True]) + async def test_before_agent_with_class_inheritance(self, is_async: bool) -> None: + """Test before_agent using class inheritance in both sync and async modes.""" + execution_log: list[str] = [] - agent = create_agent(model=model, tools=[], middleware=[CustomBeforeAgentMiddleware()]) + if is_async: - agent.invoke({"messages": [HumanMessage("Test")]}) + class CustomAsyncBeforeAgentMiddleware(AgentMiddleware): + async def abefore_agent( + self, state: AgentState, runtime: Runtime + ) -> dict[str, Any] | None: + execution_log.append("before_agent_called") + return None - assert "class_before_agent_called" in execution_log + middleware = CustomAsyncBeforeAgentMiddleware() + else: - async def test_before_agent_with_async_class_inheritance(self) -> None: - """Test async before_agent using class inheritance.""" - execution_log = [] + class CustomBeforeAgentMiddleware(AgentMiddleware): + def before_agent( + self, state: AgentState, runtime: Runtime + ) -> dict[str, Any] | None: + execution_log.append("before_agent_called") + return None - class CustomAsyncBeforeAgentMiddleware(AgentMiddleware): - async def abefore_agent( - self, state: AgentState, runtime: Runtime - ) -> dict[str, Any] | None: - execution_log.append("async_class_before_agent_called") - return None + middleware = CustomBeforeAgentMiddleware() model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + agent = create_agent(model=model, tools=[], middleware=[middleware]) - agent = create_agent(model=model, tools=[], middleware=[CustomAsyncBeforeAgentMiddleware()]) - - await agent.ainvoke({"messages": [HumanMessage("Test")]}) + if is_async: + await agent.ainvoke({"messages": [HumanMessage("Test")]}) + else: + agent.invoke({"messages": [HumanMessage("Test")]}) - assert "async_class_before_agent_called" in execution_log + assert "before_agent_called" in execution_log class TestAfterAgentBasic: """Test basic after_agent functionality.""" - def test_sync_after_agent_execution(self) -> None: - """Test that after_agent hook is called synchronously.""" - execution_log = [] + @pytest.mark.parametrize("is_async", [False, True]) + async def test_after_agent_execution(self, is_async: bool) -> None: + """Test that after_agent hook is called in both sync and async modes.""" + execution_log: list[str] = [] - @after_agent - def log_after_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: - execution_log.append("after_agent_called") - execution_log.append(f"final_message_count: {len(state['messages'])}") - return None + if is_async: - model = GenericFakeChatModel(messages=iter([AIMessage(content="Final response")])) - - agent = create_agent(model=model, tools=[], middleware=[log_after_agent]) - - agent.invoke({"messages": [HumanMessage("Hi")]}) - - assert "after_agent_called" in execution_log - assert any("final_message_count:" in log for log in execution_log) + @after_agent + async def log_after_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append("after_agent_called") + execution_log.append(f"final_message_count: {len(state['messages'])}") + return None - async def test_async_after_agent_execution(self) -> None: - """Test that after_agent hook is called asynchronously.""" - execution_log = [] + middleware = log_after_agent + else: - @after_agent - async def async_log_after_agent( - state: AgentState, runtime: Runtime - ) -> dict[str, Any] | None: - execution_log.append("async_after_agent_called") - return None + @after_agent + def log_after_agent_sync(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append("after_agent_called") + execution_log.append(f"final_message_count: {len(state['messages'])}") + return None - model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + middleware = log_after_agent_sync - agent = create_agent(model=model, tools=[], middleware=[async_log_after_agent]) + model = GenericFakeChatModel(messages=iter([AIMessage(content="Final response")])) + agent = create_agent(model=model, tools=[], middleware=[middleware]) - await agent.ainvoke({"messages": [HumanMessage("Hi")]}) + if is_async: + await agent.ainvoke({"messages": [HumanMessage("Hi")]}) + else: + agent.invoke({"messages": [HumanMessage("Hi")]}) - assert "async_after_agent_called" in execution_log + assert "after_agent_called" in execution_log + assert any("final_message_count:" in log for log in execution_log) def test_after_agent_state_modification(self) -> None: """Test that after_agent can modify state.""" @@ -170,93 +163,84 @@ def add_final_message(state: AgentState, runtime: Runtime) -> dict[str, Any]: return {"messages": [AIMessage("Added by after_agent")]} model = GenericFakeChatModel(messages=iter([AIMessage(content="Model response")])) - agent = create_agent(model=model, tools=[], middleware=[add_final_message]) - result = agent.invoke({"messages": [HumanMessage("Test")]}) message_contents = [msg.content for msg in result["messages"]] assert "Added by after_agent" in message_contents - def test_after_agent_with_class_inheritance(self) -> None: - """Test after_agent using class inheritance.""" - execution_log = [] + @pytest.mark.parametrize("is_async", [False, True]) + async def test_after_agent_with_class_inheritance(self, is_async: bool) -> None: + """Test after_agent using class inheritance in both sync and async modes.""" + execution_log: list[str] = [] - class CustomAfterAgentMiddleware(AgentMiddleware): - def after_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: - execution_log.append("class_after_agent_called") - return None + if is_async: - model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + class CustomAsyncAfterAgentMiddleware(AgentMiddleware): + async def aafter_agent( + self, state: AgentState, runtime: Runtime + ) -> dict[str, Any] | None: + execution_log.append("after_agent_called") + return None - agent = create_agent(model=model, tools=[], middleware=[CustomAfterAgentMiddleware()]) + middleware = CustomAsyncAfterAgentMiddleware() + else: - agent.invoke({"messages": [HumanMessage("Test")]}) - - assert "class_after_agent_called" in execution_log + class CustomAfterAgentMiddleware(AgentMiddleware): + def after_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append("after_agent_called") + return None - async def test_after_agent_with_async_class_inheritance(self) -> None: - """Test async after_agent using class inheritance.""" - execution_log = [] - - class CustomAsyncAfterAgentMiddleware(AgentMiddleware): - async def aafter_agent( - self, state: AgentState, runtime: Runtime - ) -> dict[str, Any] | None: - execution_log.append("async_class_after_agent_called") - return None + middleware = CustomAfterAgentMiddleware() model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + agent = create_agent(model=model, tools=[], middleware=[middleware]) - agent = create_agent(model=model, tools=[], middleware=[CustomAsyncAfterAgentMiddleware()]) - - await agent.ainvoke({"messages": [HumanMessage("Test")]}) + if is_async: + await agent.ainvoke({"messages": [HumanMessage("Test")]}) + else: + agent.invoke({"messages": [HumanMessage("Test")]}) - assert "async_class_after_agent_called" in execution_log + assert "after_agent_called" in execution_log class TestBeforeAndAfterAgentCombined: """Test before_agent and after_agent hooks working together.""" - def test_execution_order(self) -> None: - """Test that before_agent executes before after_agent.""" - execution_log = [] + @pytest.mark.parametrize("is_async", [False, True]) + async def test_execution_order(self, is_async: bool) -> None: + """Test that before_agent executes before after_agent in both sync and async modes.""" + execution_log: list[str] = [] - @before_agent - def log_before(state: AgentState, runtime: Runtime) -> None: - execution_log.append("before") + if is_async: - @after_agent - def log_after(state: AgentState, runtime: Runtime) -> None: - execution_log.append("after") + @before_agent + async def log_before(state: AgentState, runtime: Runtime) -> None: + execution_log.append("before") - model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) - - agent = create_agent(model=model, tools=[], middleware=[log_before, log_after]) - - agent.invoke({"messages": [HumanMessage("Test")]}) - - assert execution_log == ["before", "after"] + @after_agent + async def log_after(state: AgentState, runtime: Runtime) -> None: + execution_log.append("after") - async def test_async_execution_order(self) -> None: - """Test async execution order of before_agent and after_agent.""" - execution_log = [] + else: - @before_agent - async def async_log_before(state: AgentState, runtime: Runtime) -> None: - execution_log.append("async_before") + @before_agent + def log_before(state: AgentState, runtime: Runtime) -> None: + execution_log.append("before") - @after_agent - async def async_log_after(state: AgentState, runtime: Runtime) -> None: - execution_log.append("async_after") + @after_agent + def log_after(state: AgentState, runtime: Runtime) -> None: + execution_log.append("after") model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + agent = create_agent(model=model, tools=[], middleware=[log_before, log_after]) - agent = create_agent(model=model, tools=[], middleware=[async_log_before, async_log_after]) - - await agent.ainvoke({"messages": [HumanMessage("Test")]}) + if is_async: + await agent.ainvoke({"messages": [HumanMessage("Test")]}) + else: + agent.invoke({"messages": [HumanMessage("Test")]}) - assert execution_log == ["async_before", "async_after"] + assert execution_log == ["before", "after"] def test_state_passthrough(self) -> None: """Test that state modifications in before_agent are visible to after_agent.""" @@ -313,7 +297,13 @@ def after_two(state: AgentState, runtime: Runtime) -> None: assert "after_2" in execution_log def test_agent_hooks_run_once_with_multiple_model_calls(self) -> None: - """Test that before_agent and after_agent run only once even with tool calls.""" + """Test that before_agent and after_agent run only once per thread. + + This test verifies that agent-level hooks (before_agent, after_agent) execute + exactly once per agent invocation, regardless of how many tool calling loops occur. + This is different from model-level hooks (before_model, after_model) which run + on every model invocation within the tool calling loop. + """ execution_log = [] @before_agent @@ -324,11 +314,13 @@ def log_before_agent(state: AgentState, runtime: Runtime) -> None: def log_after_agent(state: AgentState, runtime: Runtime) -> None: execution_log.append("after_agent") - # Model will call a tool once, then respond with final answer + # Model will call a tool twice, then respond with final answer + # This creates 3 model invocations total, but agent hooks should still run once model = FakeToolCallingModel( tool_calls=[ - [{"name": "sample_tool", "args": {"query": "test"}, "id": "1"}], - [], # Second call returns no tool calls (final answer) + [{"name": "sample_tool", "args": {"query": "first"}, "id": "1"}], + [{"name": "sample_tool", "args": {"query": "second"}, "id": "2"}], + [], # Third call returns no tool calls (final answer) ] ) @@ -340,9 +332,13 @@ def log_after_agent(state: AgentState, runtime: Runtime) -> None: agent.invoke({"messages": [HumanMessage("Test")]}) - # before_agent and after_agent should run exactly once - assert execution_log.count("before_agent") == 1 - assert execution_log.count("after_agent") == 1 + # before_agent and after_agent should run exactly once, despite 3 model calls + assert execution_log.count("before_agent") == 1, ( + f"before_agent ran {execution_log.count('before_agent')} times, expected 1" + ) + assert execution_log.count("after_agent") == 1, ( + f"after_agent ran {execution_log.count('after_agent')} times, expected 1" + ) # before_agent should run first, after_agent should run last assert execution_log[0] == "before_agent" assert execution_log[-1] == "after_agent" From 0d60188dd570e1023ffab9af879720eed06277c2 Mon Sep 17 00:00:00 2001 From: Sydney Runkle Date: Mon, 6 Oct 2025 09:28:08 -0400 Subject: [PATCH 5/8] before agent tests --- .../middleware/test_before_after_agent.py | 34 ++++++++++++++----- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py index 2cbad45de2320..03cc094c994d6 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py @@ -8,6 +8,8 @@ AgentMiddleware, AgentState, after_agent, + after_model, + before_model, before_agent, ) from langchain_core.language_models.fake_chat_models import GenericFakeChatModel @@ -244,23 +246,18 @@ def log_after(state: AgentState, runtime: Runtime) -> None: def test_state_passthrough(self) -> None: """Test that state modifications in before_agent are visible to after_agent.""" - collected_states = {} @before_agent def modify_in_before(state: AgentState, runtime: Runtime) -> dict[str, Any]: return {"messages": [HumanMessage("Modified by before_agent")]} - @after_agent - def capture_in_after(state: AgentState, runtime: Runtime) -> None: - collected_states["messages"] = state["messages"] - model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) - agent = create_agent(model=model, tools=[], middleware=[modify_in_before, capture_in_after]) + agent = create_agent(model=model, tools=[], middleware=[modify_in_before]) - agent.invoke({"messages": [HumanMessage("Original")]}) + result = agent.invoke({"messages": [HumanMessage("Original")]}) - message_contents = [msg.content for msg in collected_states["messages"]] + message_contents = [msg.content for msg in result["messages"]] assert "Modified by before_agent" in message_contents def test_multiple_middleware_instances(self) -> None: @@ -310,10 +307,18 @@ def test_agent_hooks_run_once_with_multiple_model_calls(self) -> None: def log_before_agent(state: AgentState, runtime: Runtime) -> None: execution_log.append("before_agent") + @before_model + def log_before_model(state: AgentState, runtime: Runtime) -> None: + execution_log.append("before_model") + @after_agent def log_after_agent(state: AgentState, runtime: Runtime) -> None: execution_log.append("after_agent") + @after_model + def log_after_model(state: AgentState, runtime: Runtime) -> None: + execution_log.append("after_model") + # Model will call a tool twice, then respond with final answer # This creates 3 model invocations total, but agent hooks should still run once model = FakeToolCallingModel( @@ -327,7 +332,7 @@ def log_after_agent(state: AgentState, runtime: Runtime) -> None: agent = create_agent( model=model, tools=[sample_tool], - middleware=[log_before_agent, log_after_agent], + middleware=[log_before_agent, log_before_model, log_after_model, log_after_agent], ) agent.invoke({"messages": [HumanMessage("Test")]}) @@ -343,6 +348,17 @@ def log_after_agent(state: AgentState, runtime: Runtime) -> None: assert execution_log[0] == "before_agent" assert execution_log[-1] == "after_agent" + assert execution_log == [ + "before_agent", + "before_model", + "after_model", + "before_model", + "after_model", + "before_model", + "after_model", + "after_agent", + ] + class TestDecoratorParameters: """Test decorator parameters for before_agent and after_agent.""" From dbf9040113e357848dbd2d268b30479910226398 Mon Sep 17 00:00:00 2001 From: Sydney Runkle Date: Mon, 6 Oct 2025 09:29:51 -0400 Subject: [PATCH 6/8] multiple threads --- .../middleware/test_before_after_agent.py | 32 ++++++++++++------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py index 03cc094c994d6..c166b47433e5d 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py @@ -335,18 +335,20 @@ def log_after_model(state: AgentState, runtime: Runtime) -> None: middleware=[log_before_agent, log_before_model, log_after_model, log_after_agent], ) - agent.invoke({"messages": [HumanMessage("Test")]}) + agent.invoke({"messages": [HumanMessage("Test")]}, config={"configurable": {"thread_id": "abc"}}) - # before_agent and after_agent should run exactly once, despite 3 model calls - assert execution_log.count("before_agent") == 1, ( - f"before_agent ran {execution_log.count('before_agent')} times, expected 1" - ) - assert execution_log.count("after_agent") == 1, ( - f"after_agent ran {execution_log.count('after_agent')} times, expected 1" - ) - # before_agent should run first, after_agent should run last - assert execution_log[0] == "before_agent" - assert execution_log[-1] == "after_agent" + assert execution_log == [ + "before_agent", + "before_model", + "after_model", + "before_model", + "after_model", + "before_model", + "after_model", + "after_agent", + ] + + agent.invoke({"messages": [HumanMessage("Test")]}, config={"configurable": {"thread_id": "abc"}}) assert execution_log == [ "before_agent", @@ -357,6 +359,14 @@ def log_after_model(state: AgentState, runtime: Runtime) -> None: "before_model", "after_model", "after_agent", + "before_agent", + "before_model", + "after_model", + "before_model", + "after_model", + "before_model", + "after_model", + "after_agent", ] From e10407a8d0d8ba9d1f439a8bbd77ddd7a3cdd226 Mon Sep 17 00:00:00 2001 From: Sydney Runkle Date: Mon, 6 Oct 2025 09:30:51 -0400 Subject: [PATCH 7/8] formatting --- .../agents/middleware/test_before_after_agent.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py index c166b47433e5d..35ecb4d7f1f14 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py @@ -335,7 +335,9 @@ def log_after_model(state: AgentState, runtime: Runtime) -> None: middleware=[log_before_agent, log_before_model, log_after_model, log_after_agent], ) - agent.invoke({"messages": [HumanMessage("Test")]}, config={"configurable": {"thread_id": "abc"}}) + agent.invoke( + {"messages": [HumanMessage("Test")]}, config={"configurable": {"thread_id": "abc"}} + ) assert execution_log == [ "before_agent", @@ -348,7 +350,9 @@ def log_after_model(state: AgentState, runtime: Runtime) -> None: "after_agent", ] - agent.invoke({"messages": [HumanMessage("Test")]}, config={"configurable": {"thread_id": "abc"}}) + agent.invoke( + {"messages": [HumanMessage("Test")]}, config={"configurable": {"thread_id": "abc"}} + ) assert execution_log == [ "before_agent", From b6384f464dbe14fd3e00800e7c9e83b8b1e01709 Mon Sep 17 00:00:00 2001 From: Sydney Runkle Date: Mon, 6 Oct 2025 09:48:53 -0400 Subject: [PATCH 8/8] final cleanup on tests --- .../middleware/test_before_after_agent.py | 238 ++++-------------- 1 file changed, 55 insertions(+), 183 deletions(-) diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py index 35ecb4d7f1f14..9736c395448db 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_before_after_agent.py @@ -26,175 +26,95 @@ def sample_tool(query: str) -> str: return f"Result for: {query}" -class TestBeforeAgentBasic: - """Test basic before_agent functionality.""" +class TestAgentMiddlewareHooks: + """Test before_agent and after_agent middleware hooks.""" @pytest.mark.parametrize("is_async", [False, True]) - async def test_before_agent_execution(self, is_async: bool) -> None: - """Test that before_agent hook is called in both sync and async modes.""" + @pytest.mark.parametrize("hook_type", ["before", "after"]) + async def test_hook_execution(self, is_async: bool, hook_type: str) -> None: + """Test that agent hooks are called in both sync and async modes.""" execution_log: list[str] = [] if is_async: + if hook_type == "before": - @before_agent - async def log_before_agent( - state: AgentState, runtime: Runtime - ) -> dict[str, Any] | None: - execution_log.append("before_agent_called") - execution_log.append(f"message_count: {len(state['messages'])}") - return None - - middleware = log_before_agent + @before_agent + async def log_hook(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append(f"{hook_type}_agent_called") + execution_log.append(f"message_count: {len(state['messages'])}") + return None + else: + + @after_agent + async def log_hook(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append(f"{hook_type}_agent_called") + execution_log.append(f"message_count: {len(state['messages'])}") + return None else: + if hook_type == "before": - @before_agent - def log_before_agent_sync(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: - execution_log.append("before_agent_called") - execution_log.append(f"message_count: {len(state['messages'])}") - return None + @before_agent + def log_hook(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append(f"{hook_type}_agent_called") + execution_log.append(f"message_count: {len(state['messages'])}") + return None + else: - middleware = log_before_agent_sync + @after_agent + def log_hook(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + execution_log.append(f"{hook_type}_agent_called") + execution_log.append(f"message_count: {len(state['messages'])}") + return None - model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")])) - agent = create_agent(model=model, tools=[], middleware=[middleware]) + model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) + agent = create_agent(model=model, tools=[], middleware=[log_hook]) if is_async: await agent.ainvoke({"messages": [HumanMessage("Hi")]}) else: agent.invoke({"messages": [HumanMessage("Hi")]}) - assert "before_agent_called" in execution_log - assert "message_count: 1" in execution_log - - def test_before_agent_state_modification(self) -> None: - """Test that before_agent can modify state.""" - - @before_agent - def add_metadata(state: AgentState, runtime: Runtime) -> dict[str, Any]: - return {"messages": [HumanMessage("Injected by middleware")]} - - model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) - agent = create_agent(model=model, tools=[], middleware=[add_metadata]) - result = agent.invoke({"messages": [HumanMessage("Original")]}) - - # Should have original + injected + AI response - assert len(result["messages"]) >= 2 - message_contents = [msg.content for msg in result["messages"]] - assert "Injected by middleware" in message_contents + assert f"{hook_type}_agent_called" in execution_log + assert any("message_count:" in log for log in execution_log) @pytest.mark.parametrize("is_async", [False, True]) - async def test_before_agent_with_class_inheritance(self, is_async: bool) -> None: - """Test before_agent using class inheritance in both sync and async modes.""" + @pytest.mark.parametrize("hook_type", ["before", "after"]) + async def test_hook_with_class_inheritance(self, is_async: bool, hook_type: str) -> None: + """Test agent hooks using class inheritance in both sync and async modes.""" execution_log: list[str] = [] if is_async: - class CustomAsyncBeforeAgentMiddleware(AgentMiddleware): + class CustomMiddleware(AgentMiddleware): async def abefore_agent( self, state: AgentState, runtime: Runtime ) -> dict[str, Any] | None: - execution_log.append("before_agent_called") + if hook_type == "before": + execution_log.append("hook_called") return None - middleware = CustomAsyncBeforeAgentMiddleware() - else: - - class CustomBeforeAgentMiddleware(AgentMiddleware): - def before_agent( + async def aafter_agent( self, state: AgentState, runtime: Runtime ) -> dict[str, Any] | None: - execution_log.append("before_agent_called") + if hook_type == "after": + execution_log.append("hook_called") return None - - middleware = CustomBeforeAgentMiddleware() - - model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) - agent = create_agent(model=model, tools=[], middleware=[middleware]) - - if is_async: - await agent.ainvoke({"messages": [HumanMessage("Test")]}) - else: - agent.invoke({"messages": [HumanMessage("Test")]}) - - assert "before_agent_called" in execution_log - - -class TestAfterAgentBasic: - """Test basic after_agent functionality.""" - - @pytest.mark.parametrize("is_async", [False, True]) - async def test_after_agent_execution(self, is_async: bool) -> None: - """Test that after_agent hook is called in both sync and async modes.""" - execution_log: list[str] = [] - - if is_async: - - @after_agent - async def log_after_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: - execution_log.append("after_agent_called") - execution_log.append(f"final_message_count: {len(state['messages'])}") - return None - - middleware = log_after_agent else: - @after_agent - def log_after_agent_sync(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: - execution_log.append("after_agent_called") - execution_log.append(f"final_message_count: {len(state['messages'])}") - return None - - middleware = log_after_agent_sync - - model = GenericFakeChatModel(messages=iter([AIMessage(content="Final response")])) - agent = create_agent(model=model, tools=[], middleware=[middleware]) - - if is_async: - await agent.ainvoke({"messages": [HumanMessage("Hi")]}) - else: - agent.invoke({"messages": [HumanMessage("Hi")]}) - - assert "after_agent_called" in execution_log - assert any("final_message_count:" in log for log in execution_log) - - def test_after_agent_state_modification(self) -> None: - """Test that after_agent can modify state.""" - - @after_agent - def add_final_message(state: AgentState, runtime: Runtime) -> dict[str, Any]: - return {"messages": [AIMessage("Added by after_agent")]} - - model = GenericFakeChatModel(messages=iter([AIMessage(content="Model response")])) - agent = create_agent(model=model, tools=[], middleware=[add_final_message]) - result = agent.invoke({"messages": [HumanMessage("Test")]}) - - message_contents = [msg.content for msg in result["messages"]] - assert "Added by after_agent" in message_contents - - @pytest.mark.parametrize("is_async", [False, True]) - async def test_after_agent_with_class_inheritance(self, is_async: bool) -> None: - """Test after_agent using class inheritance in both sync and async modes.""" - execution_log: list[str] = [] - - if is_async: - - class CustomAsyncAfterAgentMiddleware(AgentMiddleware): - async def aafter_agent( + class CustomMiddleware(AgentMiddleware): + def before_agent( self, state: AgentState, runtime: Runtime ) -> dict[str, Any] | None: - execution_log.append("after_agent_called") + if hook_type == "before": + execution_log.append("hook_called") return None - middleware = CustomAsyncAfterAgentMiddleware() - else: - - class CustomAfterAgentMiddleware(AgentMiddleware): def after_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: - execution_log.append("after_agent_called") + if hook_type == "after": + execution_log.append("hook_called") return None - middleware = CustomAfterAgentMiddleware() - + middleware = CustomMiddleware() model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) agent = create_agent(model=model, tools=[], middleware=[middleware]) @@ -203,10 +123,10 @@ def after_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | N else: agent.invoke({"messages": [HumanMessage("Test")]}) - assert "after_agent_called" in execution_log + assert "hook_called" in execution_log -class TestBeforeAndAfterAgentCombined: +class TestAgentHooksCombined: """Test before_agent and after_agent hooks working together.""" @pytest.mark.parametrize("is_async", [False, True]) @@ -223,7 +143,6 @@ async def log_before(state: AgentState, runtime: Runtime) -> None: @after_agent async def log_after(state: AgentState, runtime: Runtime) -> None: execution_log.append("after") - else: @before_agent @@ -249,16 +168,14 @@ def test_state_passthrough(self) -> None: @before_agent def modify_in_before(state: AgentState, runtime: Runtime) -> dict[str, Any]: - return {"messages": [HumanMessage("Modified by before_agent")]} + return {"messages": [HumanMessage("Added by before_agent")]} model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) - agent = create_agent(model=model, tools=[], middleware=[modify_in_before]) - result = agent.invoke({"messages": [HumanMessage("Original")]}) message_contents = [msg.content for msg in result["messages"]] - assert "Modified by before_agent" in message_contents + assert message_contents[1] == "Added by before_agent" def test_multiple_middleware_instances(self) -> None: """Test multiple before_agent and after_agent middleware instances.""" @@ -281,17 +198,12 @@ def after_two(state: AgentState, runtime: Runtime) -> None: execution_log.append("after_2") model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")])) - agent = create_agent( model=model, tools=[], middleware=[before_one, before_two, after_one, after_two] ) - agent.invoke({"messages": [HumanMessage("Test")]}) - assert "before_1" in execution_log - assert "before_2" in execution_log - assert "after_1" in execution_log - assert "after_2" in execution_log + assert execution_log == ["before_1", "before_2", "after_2", "after_1"] def test_agent_hooks_run_once_with_multiple_model_calls(self) -> None: """Test that before_agent and after_agent run only once per thread. @@ -372,43 +284,3 @@ def log_after_model(state: AgentState, runtime: Runtime) -> None: "after_model", "after_agent", ] - - -class TestDecoratorParameters: - """Test decorator parameters for before_agent and after_agent.""" - - def test_before_agent_with_custom_name(self) -> None: - """Test before_agent with custom middleware name.""" - - @before_agent(name="CustomBeforeAgentMiddleware") - def custom_named_before(state: AgentState, runtime: Runtime) -> None: - pass - - assert custom_named_before.name == "CustomBeforeAgentMiddleware" - - def test_after_agent_with_custom_name(self) -> None: - """Test after_agent with custom middleware name.""" - - @after_agent(name="CustomAfterAgentMiddleware") - def custom_named_after(state: AgentState, runtime: Runtime) -> None: - pass - - assert custom_named_after.name == "CustomAfterAgentMiddleware" - - def test_before_agent_default_name(self) -> None: - """Test that before_agent uses function name by default.""" - - @before_agent - def my_before_agent_function(state: AgentState, runtime: Runtime) -> None: - pass - - assert my_before_agent_function.name == "my_before_agent_function" - - def test_after_agent_default_name(self) -> None: - """Test that after_agent uses function name by default.""" - - @after_agent - def my_after_agent_function(state: AgentState, runtime: Runtime) -> None: - pass - - assert my_after_agent_function.name == "my_after_agent_function"