Skip to content

Commit a336afa

Browse files
feat(langchain): use decorators for jumps instead (#33179)
The old `before_model_jump_to` classvar approach was quite clunky, this is nicer imo and easier to document. Also moving from `jump_to` to `can_jump_to` which is more idiomatic. Before: ```py class MyMiddleware(AgentMiddleware): before_model_jump_to: ClassVar[list[JumpTo]] = ["end"] def before_model(state, runtime) -> dict[str, Any]: return {"jump_to": "end"} ``` After ```py class MyMiddleware(AgentMiddleware): @hook_config(can_jump_to=["end"]) def before_model(state, runtime) -> dict[str, Any]: return {"jump_to": "end"} ```
1 parent af07949 commit a336afa

File tree

7 files changed

+499
-46
lines changed

7 files changed

+499
-46
lines changed

libs/langchain_v1/langchain/agents/middleware/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ModelRequest,
1111
after_model,
1212
before_model,
13+
hook_config,
1314
modify_model_request,
1415
)
1516

@@ -24,5 +25,6 @@
2425
"SummarizationMiddleware",
2526
"after_model",
2627
"before_model",
28+
"hook_config",
2729
"modify_model_request",
2830
]

libs/langchain_v1/langchain/agents/middleware/types.py

Lines changed: 86 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
from __future__ import annotations
44

5+
from collections.abc import Callable
56
from dataclasses import dataclass, field
67
from inspect import iscoroutinefunction
78
from typing import (
89
TYPE_CHECKING,
910
Annotated,
1011
Any,
11-
ClassVar,
1212
Generic,
1313
Literal,
1414
Protocol,
@@ -30,8 +30,6 @@
3030
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
3131

3232
if TYPE_CHECKING:
33-
from collections.abc import Callable
34-
3533
from langchain_core.language_models.chat_models import BaseChatModel
3634
from langchain_core.tools import BaseTool
3735
from langgraph.runtime import Runtime
@@ -46,6 +44,7 @@
4644
"ModelRequest",
4745
"OmitFromSchema",
4846
"PublicAgentState",
47+
"hook_config",
4948
]
5049

5150
JumpTo = Literal["tools", "model", "end"]
@@ -123,12 +122,6 @@ class AgentMiddleware(Generic[StateT, ContextT]):
123122
tools: list[BaseTool]
124123
"""Additional tools registered by the middleware."""
125124

126-
before_model_jump_to: ClassVar[list[JumpTo]] = []
127-
"""Valid jump destinations for before_model hook. Used to establish conditional edges."""
128-
129-
after_model_jump_to: ClassVar[list[JumpTo]] = []
130-
"""Valid jump destinations for after_model hook. Used to establish conditional edges."""
131-
132125
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
133126
"""Logic to run before the model is called."""
134127

@@ -184,6 +177,57 @@ def __call__(
184177
...
185178

186179

180+
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
181+
182+
183+
def hook_config(
184+
*,
185+
can_jump_to: list[JumpTo] | None = None,
186+
) -> Callable[[CallableT], CallableT]:
187+
"""Decorator to configure hook behavior in middleware methods.
188+
189+
Use this decorator on `before_model` or `after_model` methods in middleware classes
190+
to configure their behavior. Currently supports specifying which destinations they
191+
can jump to, which establishes conditional edges in the agent graph.
192+
193+
Args:
194+
can_jump_to: Optional list of valid jump destinations. Can be:
195+
- "tools": Jump to the tools node
196+
- "model": Jump back to the model node
197+
- "end": Jump to the end of the graph
198+
199+
Returns:
200+
Decorator function that marks the method with configuration metadata.
201+
202+
Examples:
203+
Using decorator on a class method:
204+
```python
205+
class MyMiddleware(AgentMiddleware):
206+
@hook_config(can_jump_to=["end", "model"])
207+
def before_model(self, state: AgentState) -> dict[str, Any] | None:
208+
if some_condition(state):
209+
return {"jump_to": "end"}
210+
return None
211+
```
212+
213+
Alternative: Use the `can_jump_to` parameter in `before_model`/`after_model` decorators:
214+
```python
215+
@before_model(can_jump_to=["end"])
216+
def conditional_middleware(state: AgentState) -> dict[str, Any] | None:
217+
if should_exit(state):
218+
return {"jump_to": "end"}
219+
return None
220+
```
221+
"""
222+
223+
def decorator(func: CallableT) -> CallableT:
224+
if can_jump_to is not None:
225+
func.__can_jump_to__ = can_jump_to # type: ignore[attr-defined]
226+
return func
227+
228+
return decorator
229+
230+
187231
@overload
188232
def before_model(
189233
func: _CallableWithStateAndRuntime[StateT, ContextT],
@@ -196,7 +240,7 @@ def before_model(
196240
*,
197241
state_schema: type[StateT] | None = None,
198242
tools: list[BaseTool] | None = None,
199-
jump_to: list[JumpTo] | None = None,
243+
can_jump_to: list[JumpTo] | None = None,
200244
name: str | None = None,
201245
) -> Callable[
202246
[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
@@ -208,7 +252,7 @@ def before_model(
208252
*,
209253
state_schema: type[StateT] | None = None,
210254
tools: list[BaseTool] | None = None,
211-
jump_to: list[JumpTo] | None = None,
255+
can_jump_to: list[JumpTo] | None = None,
212256
name: str | None = None,
213257
) -> (
214258
Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
@@ -222,7 +266,7 @@ def before_model(
222266
state_schema: Optional custom state schema type. If not provided, uses the default
223267
AgentState schema.
224268
tools: Optional list of additional tools to register with this middleware.
225-
jump_to: Optional list of valid jump destinations for conditional edges.
269+
can_jump_to: Optional list of valid jump destinations for conditional edges.
226270
Valid values are: "tools", "model", "end"
227271
name: Optional name for the generated middleware class. If not provided,
228272
uses the decorated function's name.
@@ -246,7 +290,7 @@ def log_before_model(state: AgentState, runtime: Runtime) -> None:
246290
247291
With conditional jumping:
248292
```python
249-
@before_model(jump_to=["end"])
293+
@before_model(can_jump_to=["end"])
250294
def conditional_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
251295
if some_condition(state):
252296
return {"jump_to": "end"}
@@ -266,6 +310,10 @@ def decorator(
266310
) -> AgentMiddleware[StateT, ContextT]:
267311
is_async = iscoroutinefunction(func)
268312

313+
func_can_jump_to = (
314+
can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
315+
)
316+
269317
if is_async:
270318

271319
async def async_wrapped(
@@ -275,6 +323,10 @@ async def async_wrapped(
275323
) -> dict[str, Any] | Command | None:
276324
return await func(state, runtime) # type: ignore[misc]
277325

326+
# Preserve can_jump_to metadata on the wrapped function
327+
if func_can_jump_to:
328+
async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
329+
278330
middleware_name = name or cast(
279331
"str", getattr(func, "__name__", "BeforeModelMiddleware")
280332
)
@@ -285,7 +337,6 @@ async def async_wrapped(
285337
{
286338
"state_schema": state_schema or AgentState,
287339
"tools": tools or [],
288-
"before_model_jump_to": jump_to or [],
289340
"abefore_model": async_wrapped,
290341
},
291342
)()
@@ -297,6 +348,11 @@ def wrapped(
297348
) -> dict[str, Any] | Command | None:
298349
return func(state, runtime) # type: ignore[return-value]
299350

351+
# Preserve can_jump_to metadata on the wrapped function
352+
if func_can_jump_to:
353+
wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
354+
355+
# Use function name as default if no name provided
300356
middleware_name = name or cast("str", getattr(func, "__name__", "BeforeModelMiddleware"))
301357

302358
return type(
@@ -305,7 +361,6 @@ def wrapped(
305361
{
306362
"state_schema": state_schema or AgentState,
307363
"tools": tools or [],
308-
"before_model_jump_to": jump_to or [],
309364
"before_model": wrapped,
310365
},
311366
)()
@@ -464,7 +519,7 @@ def after_model(
464519
*,
465520
state_schema: type[StateT] | None = None,
466521
tools: list[BaseTool] | None = None,
467-
jump_to: list[JumpTo] | None = None,
522+
can_jump_to: list[JumpTo] | None = None,
468523
name: str | None = None,
469524
) -> Callable[
470525
[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
@@ -476,7 +531,7 @@ def after_model(
476531
*,
477532
state_schema: type[StateT] | None = None,
478533
tools: list[BaseTool] | None = None,
479-
jump_to: list[JumpTo] | None = None,
534+
can_jump_to: list[JumpTo] | None = None,
480535
name: str | None = None,
481536
) -> (
482537
Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
@@ -490,7 +545,7 @@ def after_model(
490545
state_schema: Optional custom state schema type. If not provided, uses the default
491546
AgentState schema.
492547
tools: Optional list of additional tools to register with this middleware.
493-
jump_to: Optional list of valid jump destinations for conditional edges.
548+
can_jump_to: Optional list of valid jump destinations for conditional edges.
494549
Valid values are: "tools", "model", "end"
495550
name: Optional name for the generated middleware class. If not provided,
496551
uses the decorated function's name.
@@ -524,6 +579,10 @@ def decorator(
524579
func: _CallableWithStateAndRuntime[StateT, ContextT],
525580
) -> AgentMiddleware[StateT, ContextT]:
526581
is_async = iscoroutinefunction(func)
582+
# Extract can_jump_to from decorator parameter or from function metadata
583+
func_can_jump_to = (
584+
can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
585+
)
527586

528587
if is_async:
529588

@@ -534,6 +593,10 @@ async def async_wrapped(
534593
) -> dict[str, Any] | Command | None:
535594
return await func(state, runtime) # type: ignore[misc]
536595

596+
# Preserve can_jump_to metadata on the wrapped function
597+
if func_can_jump_to:
598+
async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
599+
537600
middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
538601

539602
return type(
@@ -542,7 +605,6 @@ async def async_wrapped(
542605
{
543606
"state_schema": state_schema or AgentState,
544607
"tools": tools or [],
545-
"after_model_jump_to": jump_to or [],
546608
"aafter_model": async_wrapped,
547609
},
548610
)()
@@ -554,6 +616,11 @@ def wrapped(
554616
) -> dict[str, Any] | Command | None:
555617
return func(state, runtime) # type: ignore[return-value]
556618

619+
# Preserve can_jump_to metadata on the wrapped function
620+
if func_can_jump_to:
621+
wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
622+
623+
# Use function name as default if no name provided
557624
middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
558625

559626
return type(
@@ -562,7 +629,6 @@ def wrapped(
562629
{
563630
"state_schema": state_schema or AgentState,
564631
"tools": tools or [],
565-
"after_model_jump_to": jump_to or [],
566632
"after_model": wrapped,
567633
},
568634
)()

0 commit comments

Comments
 (0)