Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions libs/langchain_v1/langchain/agents/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
AgentMiddleware,
AgentState,
ModelRequest,
after_agent,
after_model,
before_agent,
before_model,
dynamic_prompt,
hook_config,
Expand All @@ -33,7 +35,9 @@
"PlanningMiddleware",
"SummarizationMiddleware",
"ToolCallLimitMiddleware",
"after_agent",
"after_model",
"before_agent",
"before_model",
"dynamic_prompt",
"hook_config",
Expand Down
296 changes: 295 additions & 1 deletion libs/langchain_v1/langchain/agents/middleware/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,13 @@
"ModelRequest",
"OmitFromSchema",
"PublicAgentState",
"after_agent",
"after_model",
"before_agent",
"before_model",
"dynamic_prompt",
"hook_config",
"modify_model_request",
]

JumpTo = Literal["tools", "model", "end"]
Expand Down Expand Up @@ -93,7 +98,7 @@ class AgentState(TypedDict, Generic[ResponseT]):

messages: Required[Annotated[list[AnyMessage], add_messages]]
jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
structured_response: NotRequired[ResponseT]
structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]]
thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]

Expand Down Expand Up @@ -133,6 +138,14 @@ def name(self) -> str:
"""
return self.__class__.__name__

def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
"""Logic to run before the agent execution starts."""

async def abefore_agent(
self, state: StateT, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Async logic to run before the agent execution starts."""

def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
"""Logic to run before the model is called."""

Expand Down Expand Up @@ -215,6 +228,14 @@ async def aretry_model_request(
None, self.retry_model_request, error, request, state, runtime, attempt
)

def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
"""Logic to run after the agent execution completes."""

async def aafter_agent(
self, state: StateT, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Async logic to run after the agent execution completes."""


class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
"""Callable with AgentState and Runtime as arguments."""
Expand Down Expand Up @@ -707,6 +728,279 @@ def wrapped(
return decorator


@overload
def before_agent(
func: _CallableWithStateAndRuntime[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...


@overload
def before_agent(
func: None = None,
*,
state_schema: type[StateT] | None = None,
tools: list[BaseTool] | None = None,
can_jump_to: list[JumpTo] | None = None,
name: str | None = None,
) -> Callable[
[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
]: ...


def before_agent(
func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None,
*,
state_schema: type[StateT] | None = None,
tools: list[BaseTool] | None = None,
can_jump_to: list[JumpTo] | None = None,
name: str | None = None,
) -> (
Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
| AgentMiddleware[StateT, ContextT]
):
"""Decorator used to dynamically create a middleware with the before_agent hook.

Args:
func: The function to be decorated. Must accept:
`state: StateT, runtime: Runtime[ContextT]` - State and runtime context
state_schema: Optional custom state schema type. If not provided, uses the default
AgentState schema.
tools: Optional list of additional tools to register with this middleware.
can_jump_to: Optional list of valid jump destinations for conditional edges.
Valid values are: "tools", "model", "end"
name: Optional name for the generated middleware class. If not provided,
uses the decorated function's name.

Returns:
Either an AgentMiddleware instance (if func is provided directly) or a decorator function
that can be applied to a function its wrapping.

The decorated function should return:
- `dict[str, Any]` - State updates to merge into the agent state
- `Command` - A command to control flow (e.g., jump to different node)
- `None` - No state updates or flow control

Examples:
Basic usage:
```python
@before_agent
def log_before_agent(state: AgentState, runtime: Runtime) -> None:
print(f"Starting agent with {len(state['messages'])} messages")
```

With conditional jumping:
```python
@before_agent(can_jump_to=["end"])
def conditional_before_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
if some_condition(state):
return {"jump_to": "end"}
return None
```

With custom state schema:
```python
@before_agent(state_schema=MyCustomState)
def custom_before_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
return {"custom_field": "initialized_value"}
```
"""

def decorator(
func: _CallableWithStateAndRuntime[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]:
is_async = iscoroutinefunction(func)

func_can_jump_to = (
can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
)

if is_async:

async def async_wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command | None:
return await func(state, runtime) # type: ignore[misc]

# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]

middleware_name = name or cast(
"str", getattr(func, "__name__", "BeforeAgentMiddleware")
)

return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"abefore_agent": async_wrapped,
},
)()

def wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command | None:
return func(state, runtime) # type: ignore[return-value]

# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]

# Use function name as default if no name provided
middleware_name = name or cast("str", getattr(func, "__name__", "BeforeAgentMiddleware"))

return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"before_agent": wrapped,
},
)()

if func is not None:
return decorator(func)
return decorator


@overload
def after_agent(
func: _CallableWithStateAndRuntime[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...


@overload
def after_agent(
func: None = None,
*,
state_schema: type[StateT] | None = None,
tools: list[BaseTool] | None = None,
can_jump_to: list[JumpTo] | None = None,
name: str | None = None,
) -> Callable[
[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
]: ...


def after_agent(
func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None,
*,
state_schema: type[StateT] | None = None,
tools: list[BaseTool] | None = None,
can_jump_to: list[JumpTo] | None = None,
name: str | None = None,
) -> (
Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
| AgentMiddleware[StateT, ContextT]
):
"""Decorator used to dynamically create a middleware with the after_agent hook.

Args:
func: The function to be decorated. Must accept:
`state: StateT, runtime: Runtime[ContextT]` - State and runtime context
state_schema: Optional custom state schema type. If not provided, uses the default
AgentState schema.
tools: Optional list of additional tools to register with this middleware.
can_jump_to: Optional list of valid jump destinations for conditional edges.
Valid values are: "tools", "model", "end"
name: Optional name for the generated middleware class. If not provided,
uses the decorated function's name.

Returns:
Either an AgentMiddleware instance (if func is provided) or a decorator function
that can be applied to a function.

The decorated function should return:
- `dict[str, Any]` - State updates to merge into the agent state
- `Command` - A command to control flow (e.g., jump to different node)
- `None` - No state updates or flow control

Examples:
Basic usage for logging agent completion:
```python
@after_agent
def log_completion(state: AgentState, runtime: Runtime) -> None:
print(f"Agent completed with {len(state['messages'])} messages")
```

With custom state schema:
```python
@after_agent(state_schema=MyCustomState, name="MyAfterAgentMiddleware")
def custom_after_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
return {"custom_field": "finalized_value"}
```
"""

def decorator(
func: _CallableWithStateAndRuntime[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]:
is_async = iscoroutinefunction(func)
# Extract can_jump_to from decorator parameter or from function metadata
func_can_jump_to = (
can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
)

if is_async:

async def async_wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command | None:
return await func(state, runtime) # type: ignore[misc]

# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]

middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware"))

return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"aafter_agent": async_wrapped,
},
)()

def wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command | None:
return func(state, runtime) # type: ignore[return-value]

# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]

# Use function name as default if no name provided
middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware"))

return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"after_agent": wrapped,
},
)()

if func is not None:
return decorator(func)
return decorator


@overload
def dynamic_prompt(
func: _CallableReturningPromptString[StateT, ContextT],
Expand Down
Loading