Skip to content

Commit 0ccc0cb

Browse files
feat(langchain_v1): before_agent and after_agent hooks (#33279)
We're adding enough new nodes that I think a refactor in terms of graph building is warranted here, but not necessarily required for merging.
1 parent 7404338 commit 0ccc0cb

File tree

5 files changed

+757
-48
lines changed

5 files changed

+757
-48
lines changed

libs/langchain_v1/langchain/agents/middleware/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
AgentMiddleware,
1818
AgentState,
1919
ModelRequest,
20+
after_agent,
2021
after_model,
22+
before_agent,
2123
before_model,
2224
dynamic_prompt,
2325
hook_config,
@@ -41,7 +43,9 @@
4143
"PlanningMiddleware",
4244
"SummarizationMiddleware",
4345
"ToolCallLimitMiddleware",
46+
"after_agent",
4447
"after_model",
48+
"before_agent",
4549
"before_model",
4650
"dynamic_prompt",
4751
"hook_config",

libs/langchain_v1/langchain/agents/middleware/types.py

Lines changed: 295 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,13 @@
4444
"ModelRequest",
4545
"OmitFromSchema",
4646
"PublicAgentState",
47+
"after_agent",
48+
"after_model",
49+
"before_agent",
50+
"before_model",
4751
"dynamic_prompt",
4852
"hook_config",
53+
"modify_model_request",
4954
]
5055

5156
JumpTo = Literal["tools", "model", "end"]
@@ -93,7 +98,7 @@ class AgentState(TypedDict, Generic[ResponseT]):
9398

9499
messages: Required[Annotated[list[AnyMessage], add_messages]]
95100
jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
96-
structured_response: NotRequired[ResponseT]
101+
structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]]
97102
thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
98103
run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]
99104

@@ -133,6 +138,14 @@ def name(self) -> str:
133138
"""
134139
return self.__class__.__name__
135140

141+
def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
142+
"""Logic to run before the agent execution starts."""
143+
144+
async def abefore_agent(
145+
self, state: StateT, runtime: Runtime[ContextT]
146+
) -> dict[str, Any] | None:
147+
"""Async logic to run before the agent execution starts."""
148+
136149
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
137150
"""Logic to run before the model is called."""
138151

@@ -215,6 +228,14 @@ async def aretry_model_request(
215228
None, self.retry_model_request, error, request, state, runtime, attempt
216229
)
217230

231+
def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
232+
"""Logic to run after the agent execution completes."""
233+
234+
async def aafter_agent(
235+
self, state: StateT, runtime: Runtime[ContextT]
236+
) -> dict[str, Any] | None:
237+
"""Async logic to run after the agent execution completes."""
238+
218239

219240
class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
220241
"""Callable with AgentState and Runtime as arguments."""
@@ -707,6 +728,279 @@ def wrapped(
707728
return decorator
708729

709730

731+
@overload
732+
def before_agent(
733+
func: _CallableWithStateAndRuntime[StateT, ContextT],
734+
) -> AgentMiddleware[StateT, ContextT]: ...
735+
736+
737+
@overload
738+
def before_agent(
739+
func: None = None,
740+
*,
741+
state_schema: type[StateT] | None = None,
742+
tools: list[BaseTool] | None = None,
743+
can_jump_to: list[JumpTo] | None = None,
744+
name: str | None = None,
745+
) -> Callable[
746+
[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
747+
]: ...
748+
749+
750+
def before_agent(
751+
func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None,
752+
*,
753+
state_schema: type[StateT] | None = None,
754+
tools: list[BaseTool] | None = None,
755+
can_jump_to: list[JumpTo] | None = None,
756+
name: str | None = None,
757+
) -> (
758+
Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
759+
| AgentMiddleware[StateT, ContextT]
760+
):
761+
"""Decorator used to dynamically create a middleware with the before_agent hook.
762+
763+
Args:
764+
func: The function to be decorated. Must accept:
765+
`state: StateT, runtime: Runtime[ContextT]` - State and runtime context
766+
state_schema: Optional custom state schema type. If not provided, uses the default
767+
AgentState schema.
768+
tools: Optional list of additional tools to register with this middleware.
769+
can_jump_to: Optional list of valid jump destinations for conditional edges.
770+
Valid values are: "tools", "model", "end"
771+
name: Optional name for the generated middleware class. If not provided,
772+
uses the decorated function's name.
773+
774+
Returns:
775+
Either an AgentMiddleware instance (if func is provided directly) or a decorator function
776+
that can be applied to a function its wrapping.
777+
778+
The decorated function should return:
779+
- `dict[str, Any]` - State updates to merge into the agent state
780+
- `Command` - A command to control flow (e.g., jump to different node)
781+
- `None` - No state updates or flow control
782+
783+
Examples:
784+
Basic usage:
785+
```python
786+
@before_agent
787+
def log_before_agent(state: AgentState, runtime: Runtime) -> None:
788+
print(f"Starting agent with {len(state['messages'])} messages")
789+
```
790+
791+
With conditional jumping:
792+
```python
793+
@before_agent(can_jump_to=["end"])
794+
def conditional_before_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
795+
if some_condition(state):
796+
return {"jump_to": "end"}
797+
return None
798+
```
799+
800+
With custom state schema:
801+
```python
802+
@before_agent(state_schema=MyCustomState)
803+
def custom_before_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
804+
return {"custom_field": "initialized_value"}
805+
```
806+
"""
807+
808+
def decorator(
809+
func: _CallableWithStateAndRuntime[StateT, ContextT],
810+
) -> AgentMiddleware[StateT, ContextT]:
811+
is_async = iscoroutinefunction(func)
812+
813+
func_can_jump_to = (
814+
can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
815+
)
816+
817+
if is_async:
818+
819+
async def async_wrapped(
820+
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
821+
state: StateT,
822+
runtime: Runtime[ContextT],
823+
) -> dict[str, Any] | Command | None:
824+
return await func(state, runtime) # type: ignore[misc]
825+
826+
# Preserve can_jump_to metadata on the wrapped function
827+
if func_can_jump_to:
828+
async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
829+
830+
middleware_name = name or cast(
831+
"str", getattr(func, "__name__", "BeforeAgentMiddleware")
832+
)
833+
834+
return type(
835+
middleware_name,
836+
(AgentMiddleware,),
837+
{
838+
"state_schema": state_schema or AgentState,
839+
"tools": tools or [],
840+
"abefore_agent": async_wrapped,
841+
},
842+
)()
843+
844+
def wrapped(
845+
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
846+
state: StateT,
847+
runtime: Runtime[ContextT],
848+
) -> dict[str, Any] | Command | None:
849+
return func(state, runtime) # type: ignore[return-value]
850+
851+
# Preserve can_jump_to metadata on the wrapped function
852+
if func_can_jump_to:
853+
wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
854+
855+
# Use function name as default if no name provided
856+
middleware_name = name or cast("str", getattr(func, "__name__", "BeforeAgentMiddleware"))
857+
858+
return type(
859+
middleware_name,
860+
(AgentMiddleware,),
861+
{
862+
"state_schema": state_schema or AgentState,
863+
"tools": tools or [],
864+
"before_agent": wrapped,
865+
},
866+
)()
867+
868+
if func is not None:
869+
return decorator(func)
870+
return decorator
871+
872+
873+
@overload
874+
def after_agent(
875+
func: _CallableWithStateAndRuntime[StateT, ContextT],
876+
) -> AgentMiddleware[StateT, ContextT]: ...
877+
878+
879+
@overload
880+
def after_agent(
881+
func: None = None,
882+
*,
883+
state_schema: type[StateT] | None = None,
884+
tools: list[BaseTool] | None = None,
885+
can_jump_to: list[JumpTo] | None = None,
886+
name: str | None = None,
887+
) -> Callable[
888+
[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
889+
]: ...
890+
891+
892+
def after_agent(
893+
func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None,
894+
*,
895+
state_schema: type[StateT] | None = None,
896+
tools: list[BaseTool] | None = None,
897+
can_jump_to: list[JumpTo] | None = None,
898+
name: str | None = None,
899+
) -> (
900+
Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
901+
| AgentMiddleware[StateT, ContextT]
902+
):
903+
"""Decorator used to dynamically create a middleware with the after_agent hook.
904+
905+
Args:
906+
func: The function to be decorated. Must accept:
907+
`state: StateT, runtime: Runtime[ContextT]` - State and runtime context
908+
state_schema: Optional custom state schema type. If not provided, uses the default
909+
AgentState schema.
910+
tools: Optional list of additional tools to register with this middleware.
911+
can_jump_to: Optional list of valid jump destinations for conditional edges.
912+
Valid values are: "tools", "model", "end"
913+
name: Optional name for the generated middleware class. If not provided,
914+
uses the decorated function's name.
915+
916+
Returns:
917+
Either an AgentMiddleware instance (if func is provided) or a decorator function
918+
that can be applied to a function.
919+
920+
The decorated function should return:
921+
- `dict[str, Any]` - State updates to merge into the agent state
922+
- `Command` - A command to control flow (e.g., jump to different node)
923+
- `None` - No state updates or flow control
924+
925+
Examples:
926+
Basic usage for logging agent completion:
927+
```python
928+
@after_agent
929+
def log_completion(state: AgentState, runtime: Runtime) -> None:
930+
print(f"Agent completed with {len(state['messages'])} messages")
931+
```
932+
933+
With custom state schema:
934+
```python
935+
@after_agent(state_schema=MyCustomState, name="MyAfterAgentMiddleware")
936+
def custom_after_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
937+
return {"custom_field": "finalized_value"}
938+
```
939+
"""
940+
941+
def decorator(
942+
func: _CallableWithStateAndRuntime[StateT, ContextT],
943+
) -> AgentMiddleware[StateT, ContextT]:
944+
is_async = iscoroutinefunction(func)
945+
# Extract can_jump_to from decorator parameter or from function metadata
946+
func_can_jump_to = (
947+
can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
948+
)
949+
950+
if is_async:
951+
952+
async def async_wrapped(
953+
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
954+
state: StateT,
955+
runtime: Runtime[ContextT],
956+
) -> dict[str, Any] | Command | None:
957+
return await func(state, runtime) # type: ignore[misc]
958+
959+
# Preserve can_jump_to metadata on the wrapped function
960+
if func_can_jump_to:
961+
async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
962+
963+
middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware"))
964+
965+
return type(
966+
middleware_name,
967+
(AgentMiddleware,),
968+
{
969+
"state_schema": state_schema or AgentState,
970+
"tools": tools or [],
971+
"aafter_agent": async_wrapped,
972+
},
973+
)()
974+
975+
def wrapped(
976+
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
977+
state: StateT,
978+
runtime: Runtime[ContextT],
979+
) -> dict[str, Any] | Command | None:
980+
return func(state, runtime) # type: ignore[return-value]
981+
982+
# Preserve can_jump_to metadata on the wrapped function
983+
if func_can_jump_to:
984+
wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
985+
986+
# Use function name as default if no name provided
987+
middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware"))
988+
989+
return type(
990+
middleware_name,
991+
(AgentMiddleware,),
992+
{
993+
"state_schema": state_schema or AgentState,
994+
"tools": tools or [],
995+
"after_agent": wrapped,
996+
},
997+
)()
998+
999+
if func is not None:
1000+
return decorator(func)
1001+
return decorator
1002+
1003+
7101004
@overload
7111005
def dynamic_prompt(
7121006
func: _CallableReturningPromptString[StateT, ContextT],

0 commit comments

Comments
 (0)