11 changes: 9 additions & 2 deletions libs/langchain_v1/langchain/agents/factory.py
@@ -18,6 +18,7 @@

from langchain.agents.middleware.types import (
AgentMiddleware,
AgentRuntime,
AgentState,
JumpTo,
ModelRequest,
@@ -1018,6 +1019,9 @@ def _execute_model_sync(request: ModelRequest) -> ModelResponse:

def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Sync model request handler with sequential middleware processing."""
# Create flat AgentRuntime with all runtime properties
agent_runtime = AgentRuntime.from_runtime(name or "agent", runtime)

request = ModelRequest(
model=model,
tools=default_tools,
@@ -1026,7 +1030,7 @@ def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
messages=state["messages"],
tool_choice=None,
state=state,
-runtime=runtime,
+runtime=agent_runtime,
)

if wrap_model_call_handler is None:
@@ -1071,6 +1075,9 @@ async def _execute_model_async(request: ModelRequest) -> ModelResponse:

async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Async model request handler with sequential middleware processing."""
# Create flat AgentRuntime with all runtime properties
agent_runtime = AgentRuntime.from_runtime(name or "agent", runtime)

request = ModelRequest(
model=model,
tools=default_tools,
@@ -1079,7 +1086,7 @@ async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
messages=state["messages"],
tool_choice=None,
state=state,
-runtime=runtime,
+runtime=agent_runtime,
)

if awrap_model_call_handler is None:
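For orientation, here is a minimal sketch of the pattern this factory change introduces: the raw LangGraph `Runtime` is flattened into an `AgentRuntime` before the `ModelRequest` is built, so middleware sees one flat object. The stand-in runtime class and the agent name below are hypothetical; only `AgentRuntime.from_runtime` comes from this PR.

```python
from langchain.agents.middleware import AgentRuntime


# Hypothetical duck-typed stand-in for the Runtime object LangGraph passes
# into a graph node; from_runtime only reads these four attributes.
class FakeRuntime:
    context = {"user_id": "u-123"}
    store = None
    stream_writer = print
    previous = None


# Mirrors the factory code above, where `name or "agent"` becomes agent_name.
agent_runtime = AgentRuntime.from_runtime("support-agent", FakeRuntime())
assert agent_runtime.agent_name == "support-agent"
assert agent_runtime.context["user_id"] == "u-123"
```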
2 changes: 2 additions & 0 deletions libs/langchain_v1/langchain/agents/middleware/__init__.py
@@ -32,6 +32,7 @@
from .tool_selection import LLMToolSelectorMiddleware
from .types import (
AgentMiddleware,
AgentRuntime,
AgentState,
ModelRequest,
ModelResponse,
@@ -47,6 +48,7 @@

__all__ = [
"AgentMiddleware",
"AgentRuntime",
"AgentState",
"ClearToolUsesEdit",
"CodexSandboxExecutionPolicy",
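With this export in place, `AgentRuntime` is importable from the public middleware namespace next to the existing types, assuming a build that includes this PR:

```python
# AgentRuntime now sits alongside the other public middleware types, so user
# code can avoid reaching into the deeper langchain.agents.middleware.types path.
from langchain.agents.middleware import AgentMiddleware, AgentRuntime, AgentState
```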
94 changes: 90 additions & 4 deletions libs/langchain_v1/langchain/agents/middleware/types.py
@@ -27,7 +27,8 @@
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ToolMessage # noqa: TC002
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.graph.message import add_messages
-from langgraph.types import Command # noqa: TC002
+from langgraph.store.base import BaseStore # noqa: TC002
+from langgraph.types import Command, StreamWriter # noqa: TC002
from langgraph.typing import ContextT
from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack

@@ -60,6 +61,75 @@
ResponseT = TypeVar("ResponseT")


@dataclass
class AgentRuntime(Generic[ContextT]):
"""Runtime context for agent execution, extending LangGraph's Runtime.

This class provides agent-specific execution context to middleware, including
the name of the currently executing graph and all Runtime properties flattened
for convenient access.

The AgentRuntime follows the same pattern as ToolRuntime, providing a flat
structure with all runtime properties directly accessible.

Attributes:
agent_name: The name of the currently executing graph/agent. This is the
name passed to `create_agent(name=...)`; it defaults to "agent" when no name is given.
context: Static context for the graph run (e.g., `user_id`, `db_conn`).
store: Store for persistence and memory, if configured.
stream_writer: Function for writing to the custom stream.
previous: The previous return value for the given thread (functional API only).

Example:
```python
from collections.abc import Callable

from langchain.agents.middleware import wrap_model_call, AgentRuntime
from langchain.agents.middleware.types import ModelRequest, ModelResponse


@wrap_model_call
def log_agent_name(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
'''Log which agent is making the model call.'''
agent_name = request.runtime.agent_name
print(f"Agent '{agent_name}' is calling the model")

# Access runtime context directly (flattened)
user_id = request.runtime.context.get("user_id")
print(f"User: {user_id}")

return handler(request)
```
"""

agent_name: str
"""The name of the currently executing graph/agent."""

context: ContextT = field(default=None) # type: ignore[assignment]
"""Static context for the graph run, like `user_id`, `db_conn`, etc."""

store: BaseStore | None = field(default=None)
"""Store for the graph run, enabling persistence and memory."""

stream_writer: StreamWriter = field(default=None) # type: ignore[assignment]
"""Function that writes to the custom stream."""

previous: Any = field(default=None)
"""The previous return value for the given thread."""

@classmethod
def from_runtime(cls, name: str, runtime: Runtime[ContextT]) -> AgentRuntime[ContextT]:
"""Create an AgentRuntime from a Runtime."""
return AgentRuntime[ContextT](
agent_name=name,
context=runtime.context,
store=runtime.store,
stream_writer=runtime.stream_writer,
previous=runtime.previous,
)


class _ModelRequestOverrides(TypedDict, total=False):
"""Possible overrides for ModelRequest.override() method."""

@@ -74,7 +144,23 @@ class _ModelRequestOverrides(TypedDict, total=False):

@dataclass
class ModelRequest:
"""Model request information for the agent."""
"""Model request information for the agent.

This dataclass contains all the information needed for a model invocation,
including the model, messages, tools, and runtime context.

Attributes:
model: The chat model to invoke.
system_prompt: Optional system prompt to prepend to messages.
messages: List of conversation messages (excluding system prompt).
tool_choice: Tool selection configuration for the model.
tools: Available tools for the model to use.
response_format: Structured output format specification.
state: Complete agent state at the time of model invocation.
runtime: Agent runtime context including agent name and underlying
LangGraph Runtime with context, store, and stream_writer.
model_settings: Additional model-specific settings.
"""

model: BaseChatModel
system_prompt: str | None
@@ -83,7 +169,7 @@ class ModelRequest:
tools: list[BaseTool | dict]
response_format: ResponseFormat | None
state: AgentState
-runtime: Runtime[ContextT] # type: ignore[valid-type]
+runtime: AgentRuntime[ContextT] # type: ignore[valid-type]
model_settings: dict[str, Any] = field(default_factory=dict)

def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
@@ -932,7 +1018,7 @@ def before_agent(
```python
@before_agent
def log_before_agent(state: AgentState, runtime: Runtime) -> None:
print(f"Starting agent with {len(state['messages'])} messages")
print(f"Starting agent '{runtime.agent_name}' with {len(state['messages'])} messages")
```

With conditional jumping:
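As a usage note for the documented `ModelRequest` fields, a sketch of a `wrap_model_call` middleware that stamps the agent name into the system prompt via `override()`. It assumes `system_prompt` is among the keys `_ModelRequestOverrides` permits (the TypedDict body is collapsed in this view).

```python
from collections.abc import Callable

from langchain.agents.middleware import wrap_model_call
from langchain.agents.middleware.types import ModelRequest, ModelResponse


@wrap_model_call
def stamp_agent_name(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    """Prefix the system prompt with the executing agent's name."""
    prompt = f"[{request.runtime.agent_name}] {request.system_prompt or ''}".strip()
    # override() returns a modified copy; the incoming request is not mutated.
    return handler(request.override(system_prompt=prompt))
```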
108 changes: 108 additions & 0 deletions libs/langchain_v1/tests/unit_tests/agents/test_agent_runtime.py
@@ -0,0 +1,108 @@
"""Tests for AgentRuntime access via wrap_model_call middleware."""

import pytest
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import tool

from langchain.agents import create_agent
from langchain.agents.middleware import wrap_model_call
from langchain.agents.middleware.types import ModelRequest
from langchain.tools import ToolRuntime

from .model import FakeToolCallingModel


@pytest.fixture
def fake_chat_model():
"""Fixture providing a fake chat model for testing."""
return GenericFakeChatModel(messages=iter([AIMessage(content="test response")]))


def test_agent_name_accessible_in_middleware(fake_chat_model):
"""Test that agent name can be accessed via middleware."""
captured_agent_name = None

@wrap_model_call
def capture_agent_name(request: ModelRequest, handler):
nonlocal captured_agent_name
captured_agent_name = request.runtime.agent_name
return handler(request)

agent = create_agent(
fake_chat_model,
tools=[],
middleware=[capture_agent_name],
name="TestAgent",
)

agent.invoke({"messages": [HumanMessage("Hello")]})
assert captured_agent_name == "TestAgent"


def test_nested_agent_name_accessible_in_tool():
"""Test that nested agent's name is accessible when agent is used in a tool."""
# Track which agent names were captured
captured_agent_names = []

@wrap_model_call
def capture_agent_name(request: ModelRequest, handler):
captured_agent_names.append(request.runtime.agent_name)
return handler(request)

# Create a nested agent that will be called from within a tool
nested_agent = create_agent(
FakeToolCallingModel(),
tools=[],
middleware=[capture_agent_name],
name="NestedAgent",
)

# Create a tool that invokes the nested agent
@tool
def call_nested_agent(query: str, runtime: ToolRuntime) -> str:
"""Tool that calls a nested agent."""
result = nested_agent.invoke({"messages": [HumanMessage(query)]})
return result["messages"][-1].content

# Create outer agent that uses the tool
outer_agent = create_agent(
FakeToolCallingModel(
tool_calls=[
[{"name": "call_nested_agent", "args": {"query": "test"}, "id": "1"}],
[],
]
),
tools=[call_nested_agent],
middleware=[capture_agent_name],
name="OuterAgent",
)

# Invoke the outer agent, which should call the tool, which calls the nested agent
outer_agent.invoke({"messages": [HumanMessage("Hello")]})

# Both agents should have captured their names
assert "OuterAgent" in captured_agent_names
assert "NestedAgent" in captured_agent_names


async def test_agent_name_accessible_in_async_middleware():
"""Test that agent name can be accessed in async middleware."""
captured_agent_name = None

@wrap_model_call
async def capture_agent_name_async(request: ModelRequest, handler):
nonlocal captured_agent_name
captured_agent_name = request.runtime.agent_name
return await handler(request)

fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="async response")]))
agent = create_agent(
fake_model,
tools=[],
middleware=[capture_agent_name_async],
name="AsyncAgent",
)

await agent.ainvoke({"messages": [HumanMessage("Hello")]})
assert captured_agent_name == "AsyncAgent"
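The same capture pattern reaches the other flattened properties. A sketch for `context`, assuming LangGraph's `context=` invoke kwarg is forwarded by the compiled agent graph and that the context schema is a plain dict (both assumptions, not shown in this diff):

```python
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage

from langchain.agents import create_agent
from langchain.agents.middleware import wrap_model_call
from langchain.agents.middleware.types import ModelRequest

captured: dict = {}


@wrap_model_call
def capture_context(request: ModelRequest, handler):
    # request.runtime.context mirrors Runtime.context, flattened onto AgentRuntime.
    captured["context"] = request.runtime.context
    return handler(request)


model = GenericFakeChatModel(messages=iter([AIMessage(content="ok")]))
agent = create_agent(model, tools=[], middleware=[capture_context], name="CtxAgent")
agent.invoke({"messages": [HumanMessage("hi")]}, context={"user_id": "u-1"})
assert captured["context"] == {"user_id": "u-1"}
```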
@@ -1351,7 +1351,7 @@ class CustomState(AgentState):
class CustomMiddleware(AgentMiddleware[CustomState]):
state_schema: type[CustomState] = CustomState

-def before_model(self, state: CustomState) -> dict[str, Any]:
+def before_model(self, state: CustomState, runtime) -> dict[str, Any]:
assert "omit_input" not in state
assert "omit_output" in state
assert "private_state" not in state
@@ -1456,11 +1456,11 @@ def test_injected_state_in_middleware_agent() -> None:

def test_jump_to_is_ephemeral() -> None:
class MyMiddleware(AgentMiddleware):
-def before_model(self, state: AgentState) -> dict[str, Any]:
+def before_model(self, state: AgentState, runtime) -> dict[str, Any]:
assert "jump_to" not in state
return {"jump_to": "model"}

-def after_model(self, state: AgentState) -> dict[str, Any]:
+def after_model(self, state: AgentState, runtime) -> dict[str, Any]:
assert "jump_to" not in state
return {"jump_to": "model"}

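For reference, a minimal sketch of the hook signature these test updates track: `before_model` and `after_model` now receive the runtime as a second parameter. The untyped `runtime` mirrors the tests above, since this diff does not show the annotation the hooks receive.

```python
from typing import Any

from langchain.agents.middleware import AgentMiddleware, AgentState


class LoggingMiddleware(AgentMiddleware):
    """Example middleware on the two-argument hook signature."""

    def before_model(self, state: AgentState, runtime) -> dict[str, Any] | None:
        # State updates are returned, never mutated in place; None means no update.
        print(f"model call with {len(state['messages'])} messages")
        return None

    def after_model(self, state: AgentState, runtime) -> dict[str, Any] | None:
        print("model call finished")
        return None
```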