diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py index 2f9962759fc7a..f708555aa124a 100644 --- a/libs/langchain_v1/langchain/agents/factory.py +++ b/libs/langchain_v1/langchain/agents/factory.py @@ -18,6 +18,7 @@ from langchain.agents.middleware.types import ( AgentMiddleware, + AgentRuntime, AgentState, JumpTo, ModelRequest, @@ -1018,6 +1019,9 @@ def _execute_model_sync(request: ModelRequest) -> ModelResponse: def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]: """Sync model request handler with sequential middleware processing.""" + # Create flat AgentRuntime with all runtime properties + agent_runtime = AgentRuntime.from_runtime(name or "agent", runtime) + request = ModelRequest( model=model, tools=default_tools, @@ -1026,7 +1030,7 @@ def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]: messages=state["messages"], tool_choice=None, state=state, - runtime=runtime, + runtime=agent_runtime, ) if wrap_model_call_handler is None: @@ -1071,6 +1075,9 @@ async def _execute_model_async(request: ModelRequest) -> ModelResponse: async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]: """Async model request handler with sequential middleware processing.""" + # Create flat AgentRuntime with all runtime properties + agent_runtime = AgentRuntime.from_runtime(name or "agent", runtime) + request = ModelRequest( model=model, tools=default_tools, @@ -1079,7 +1086,7 @@ async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str messages=state["messages"], tool_choice=None, state=state, - runtime=runtime, + runtime=agent_runtime, ) if awrap_model_call_handler is None: diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index 8ed35aafcd5c4..3b66c08ac1bf6 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -32,6 +32,7 @@ from .tool_selection import LLMToolSelectorMiddleware from .types import ( AgentMiddleware, + AgentRuntime, AgentState, ModelRequest, ModelResponse, @@ -47,6 +48,7 @@ __all__ = [ "AgentMiddleware", + "AgentRuntime", "AgentState", "ClearToolUsesEdit", "CodexSandboxExecutionPolicy", diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index a100c65ce179f..db4434d41d4ab 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -27,7 +27,8 @@ from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ToolMessage # noqa: TC002 from langgraph.channels.ephemeral_value import EphemeralValue from langgraph.graph.message import add_messages -from langgraph.types import Command # noqa: TC002 +from langgraph.store.base import BaseStore # noqa: TC002 +from langgraph.types import Command, StreamWriter # noqa: TC002 from langgraph.typing import ContextT from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack @@ -60,6 +61,75 @@ ResponseT = TypeVar("ResponseT") +@dataclass +class AgentRuntime(Generic[ContextT]): + """Runtime context for agent execution, extending LangGraph's Runtime. + + This class provides agent-specific execution context to middleware, including + the name of the currently executing graph and all Runtime properties flattened + for convenient access. + + The AgentRuntime follows the same pattern as ToolRuntime, providing a flat + structure with all runtime properties directly accessible. + + Attributes: + agent_name: The name of the currently executing graph/agent. This is the + name passed to `create_agent(name=...)` or defaults to "LangGraph". + context: Static context for the graph run (e.g., `user_id`, `db_conn`). + store: Store for persistence and memory, if configured. + stream_writer: Function for writing to the custom stream. + previous: The previous return value for the given thread (functional API only). + + Example: + ```python + from langchain.agents.middleware import wrap_model_call, AgentRuntime + from langchain.agents.middleware.types import ModelRequest, ModelResponse + + + @wrap_model_call + def log_agent_name( + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + '''Log which agent is making the model call.''' + agent_name = request.runtime.agent_name + print(f"Agent '{agent_name}' is calling the model") + + # Access runtime context directly (flattened) + user_id = request.runtime.context.get("user_id") + print(f"User: {user_id}") + + return handler(request) + ``` + """ + + agent_name: str + """The name of the currently executing graph/agent.""" + + context: ContextT = field(default=None) # type: ignore[assignment] + """Static context for the graph run, like `user_id`, `db_conn`, etc.""" + + store: BaseStore | None = field(default=None) + """Store for the graph run, enabling persistence and memory.""" + + stream_writer: StreamWriter = field(default=None) # type: ignore[assignment] + """Function that writes to the custom stream.""" + + previous: Any = field(default=None) + """The previous return value for the given thread.""" + + @classmethod + def from_runtime(cls, name: str, runtime: Runtime[ContextT]) -> AgentRuntime[ContextT]: + """Create an AgentRuntime from a Runtime.""" + return AgentRuntime[ContextT]( + agent_name=name, + context=runtime.context, + store=runtime.store, + stream_writer=runtime.stream_writer, + previous=runtime.previous, + ) + + class _ModelRequestOverrides(TypedDict, total=False): """Possible overrides for ModelRequest.override() method.""" @@ -74,7 +144,23 @@ class _ModelRequestOverrides(TypedDict, total=False): @dataclass class ModelRequest: - """Model request information for the agent.""" + """Model request information for the agent. + + This dataclass contains all the information needed for a model invocation, + including the model, messages, tools, and runtime context. + + Attributes: + model: The chat model to invoke. + system_prompt: Optional system prompt to prepend to messages. + messages: List of conversation messages (excluding system prompt). + tool_choice: Tool selection configuration for the model. + tools: Available tools for the model to use. + response_format: Structured output format specification. + state: Complete agent state at the time of model invocation. + runtime: Agent runtime context including agent name and underlying + LangGraph Runtime with context, store, and stream_writer. + model_settings: Additional model-specific settings. + """ model: BaseChatModel system_prompt: str | None @@ -83,7 +169,7 @@ class ModelRequest: tools: list[BaseTool | dict] response_format: ResponseFormat | None state: AgentState - runtime: Runtime[ContextT] # type: ignore[valid-type] + runtime: AgentRuntime[ContextT] # type: ignore[valid-type] model_settings: dict[str, Any] = field(default_factory=dict) def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest: @@ -932,7 +1018,7 @@ def before_agent( ```python @before_agent def log_before_agent(state: AgentState, runtime: Runtime) -> None: - print(f"Starting agent with {len(state['messages'])} messages") + print(f"Starting agent '{runtime.agent_name}' with {len(state['messages'])} messages") ``` With conditional jumping: diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_agent_runtime.py b/libs/langchain_v1/tests/unit_tests/agents/test_agent_runtime.py new file mode 100644 index 0000000000000..fd77cf63b39fc --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/test_agent_runtime.py @@ -0,0 +1,108 @@ +"""Tests for AgentRuntime access via wrap_model_call middleware.""" + +import pytest +from langchain_core.language_models.fake_chat_models import GenericFakeChatModel +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.tools import tool + +from langchain.agents import create_agent +from langchain.agents.middleware import wrap_model_call +from langchain.agents.middleware.types import ModelRequest +from langchain.tools import ToolRuntime + +from .model import FakeToolCallingModel + + +@pytest.fixture +def fake_chat_model(): + """Fixture providing a fake chat model for testing.""" + return GenericFakeChatModel(messages=iter([AIMessage(content="test response")])) + + +def test_agent_name_accessible_in_middleware(fake_chat_model): + """Test that agent name can be accessed via middleware.""" + captured_agent_name = None + + @wrap_model_call + def capture_agent_name(request: ModelRequest, handler): + nonlocal captured_agent_name + captured_agent_name = request.runtime.agent_name + return handler(request) + + agent = create_agent( + fake_chat_model, + tools=[], + middleware=[capture_agent_name], + name="TestAgent", + ) + + agent.invoke({"messages": [HumanMessage("Hello")]}) + assert captured_agent_name == "TestAgent" + + +def test_nested_agent_name_accessible_in_tool(): + """Test that nested agent's name is accessible when agent is used in a tool.""" + # Track which agent names were captured + captured_agent_names = [] + + @wrap_model_call + def capture_agent_name(request: ModelRequest, handler): + captured_agent_names.append(request.runtime.agent_name) + return handler(request) + + # Create a nested agent that will be called from within a tool + nested_agent = create_agent( + FakeToolCallingModel(), + tools=[], + middleware=[capture_agent_name], + name="NestedAgent", + ) + + # Create a tool that invokes the nested agent + @tool + def call_nested_agent(query: str, runtime: ToolRuntime) -> str: + """Tool that calls a nested agent.""" + result = nested_agent.invoke({"messages": [HumanMessage(query)]}) + return result["messages"][-1].content + + # Create outer agent that uses the tool + outer_agent = create_agent( + FakeToolCallingModel( + tool_calls=[ + [{"name": "call_nested_agent", "args": {"query": "test"}, "id": "1"}], + [], + ] + ), + tools=[call_nested_agent], + middleware=[capture_agent_name], + name="OuterAgent", + ) + + # Invoke the outer agent, which should call the tool, which calls the nested agent + outer_agent.invoke({"messages": [HumanMessage("Hello")]}) + + # Both agents should have captured their names + assert "OuterAgent" in captured_agent_names + assert "NestedAgent" in captured_agent_names + + +async def test_agent_name_accessible_in_async_middleware(): + """Test that agent name can be accessed in async middleware.""" + captured_agent_name = None + + @wrap_model_call + async def capture_agent_name_async(request: ModelRequest, handler): + nonlocal captured_agent_name + captured_agent_name = request.runtime.agent_name + return await handler(request) + + fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="async response")])) + agent = create_agent( + fake_model, + tools=[], + middleware=[capture_agent_name_async], + name="AsyncAgent", + ) + + await agent.ainvoke({"messages": [HumanMessage("Hello")]}) + assert captured_agent_name == "AsyncAgent" diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py index 02fa96e6b65af..9fffc56347168 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py @@ -1351,7 +1351,7 @@ class CustomState(AgentState): class CustomMiddleware(AgentMiddleware[CustomState]): state_schema: type[CustomState] = CustomState - def before_model(self, state: CustomState) -> dict[str, Any]: + def before_model(self, state: CustomState, runtime) -> dict[str, Any]: assert "omit_input" not in state assert "omit_output" in state assert "private_state" not in state @@ -1456,11 +1456,11 @@ def test_injected_state_in_middleware_agent() -> None: def test_jump_to_is_ephemeral() -> None: class MyMiddleware(AgentMiddleware): - def before_model(self, state: AgentState) -> dict[str, Any]: + def before_model(self, state: AgentState, runtime) -> dict[str, Any]: assert "jump_to" not in state return {"jump_to": "model"} - def after_model(self, state: AgentState) -> dict[str, Any]: + def after_model(self, state: AgentState, runtime) -> dict[str, Any]: assert "jump_to" not in state return {"jump_to": "model"}