2
2
3
3
from __future__ import annotations
4
4
5
+ from collections .abc import Callable
5
6
from dataclasses import dataclass , field
6
7
from inspect import iscoroutinefunction
7
8
from typing import (
8
9
TYPE_CHECKING ,
9
10
Annotated ,
10
11
Any ,
11
- ClassVar ,
12
12
Generic ,
13
13
Literal ,
14
14
Protocol ,
30
30
from typing_extensions import NotRequired , Required , TypedDict , TypeVar
31
31
32
32
if TYPE_CHECKING :
33
- from collections .abc import Callable
34
-
35
33
from langchain_core .language_models .chat_models import BaseChatModel
36
34
from langchain_core .tools import BaseTool
37
35
from langgraph .runtime import Runtime
46
44
"ModelRequest" ,
47
45
"OmitFromSchema" ,
48
46
"PublicAgentState" ,
47
+ "hook_config" ,
49
48
]
50
49
51
50
JumpTo = Literal ["tools" , "model" , "end" ]
@@ -123,12 +122,6 @@ class AgentMiddleware(Generic[StateT, ContextT]):
123
122
tools : list [BaseTool ]
124
123
"""Additional tools registered by the middleware."""
125
124
126
- before_model_jump_to : ClassVar [list [JumpTo ]] = []
127
- """Valid jump destinations for before_model hook. Used to establish conditional edges."""
128
-
129
- after_model_jump_to : ClassVar [list [JumpTo ]] = []
130
- """Valid jump destinations for after_model hook. Used to establish conditional edges."""
131
-
132
125
def before_model (self , state : StateT , runtime : Runtime [ContextT ]) -> dict [str , Any ] | None :
133
126
"""Logic to run before the model is called."""
134
127
@@ -184,6 +177,57 @@ def __call__(
184
177
...
185
178
186
179
180
+ CallableT = TypeVar ("CallableT" , bound = Callable [..., Any ])
181
+
182
+
183
+ def hook_config (
184
+ * ,
185
+ can_jump_to : list [JumpTo ] | None = None ,
186
+ ) -> Callable [[CallableT ], CallableT ]:
187
+ """Decorator to configure hook behavior in middleware methods.
188
+
189
+ Use this decorator on `before_model` or `after_model` methods in middleware classes
190
+ to configure their behavior. Currently supports specifying which destinations they
191
+ can jump to, which establishes conditional edges in the agent graph.
192
+
193
+ Args:
194
+ can_jump_to: Optional list of valid jump destinations. Can be:
195
+ - "tools": Jump to the tools node
196
+ - "model": Jump back to the model node
197
+ - "end": Jump to the end of the graph
198
+
199
+ Returns:
200
+ Decorator function that marks the method with configuration metadata.
201
+
202
+ Examples:
203
+ Using decorator on a class method:
204
+ ```python
205
+ class MyMiddleware(AgentMiddleware):
206
+ @hook_config(can_jump_to=["end", "model"])
207
+ def before_model(self, state: AgentState) -> dict[str, Any] | None:
208
+ if some_condition(state):
209
+ return {"jump_to": "end"}
210
+ return None
211
+ ```
212
+
213
+ Alternative: Use the `can_jump_to` parameter in `before_model`/`after_model` decorators:
214
+ ```python
215
+ @before_model(can_jump_to=["end"])
216
+ def conditional_middleware(state: AgentState) -> dict[str, Any] | None:
217
+ if should_exit(state):
218
+ return {"jump_to": "end"}
219
+ return None
220
+ ```
221
+ """
222
+
223
+ def decorator (func : CallableT ) -> CallableT :
224
+ if can_jump_to is not None :
225
+ func .__can_jump_to__ = can_jump_to # type: ignore[attr-defined]
226
+ return func
227
+
228
+ return decorator
229
+
230
+
187
231
@overload
188
232
def before_model (
189
233
func : _CallableWithStateAndRuntime [StateT , ContextT ],
@@ -196,7 +240,7 @@ def before_model(
196
240
* ,
197
241
state_schema : type [StateT ] | None = None ,
198
242
tools : list [BaseTool ] | None = None ,
199
- jump_to : list [JumpTo ] | None = None ,
243
+ can_jump_to : list [JumpTo ] | None = None ,
200
244
name : str | None = None ,
201
245
) -> Callable [
202
246
[_CallableWithStateAndRuntime [StateT , ContextT ]], AgentMiddleware [StateT , ContextT ]
@@ -208,7 +252,7 @@ def before_model(
208
252
* ,
209
253
state_schema : type [StateT ] | None = None ,
210
254
tools : list [BaseTool ] | None = None ,
211
- jump_to : list [JumpTo ] | None = None ,
255
+ can_jump_to : list [JumpTo ] | None = None ,
212
256
name : str | None = None ,
213
257
) -> (
214
258
Callable [[_CallableWithStateAndRuntime [StateT , ContextT ]], AgentMiddleware [StateT , ContextT ]]
@@ -222,7 +266,7 @@ def before_model(
222
266
state_schema: Optional custom state schema type. If not provided, uses the default
223
267
AgentState schema.
224
268
tools: Optional list of additional tools to register with this middleware.
225
- jump_to : Optional list of valid jump destinations for conditional edges.
269
+ can_jump_to : Optional list of valid jump destinations for conditional edges.
226
270
Valid values are: "tools", "model", "end"
227
271
name: Optional name for the generated middleware class. If not provided,
228
272
uses the decorated function's name.
@@ -246,7 +290,7 @@ def log_before_model(state: AgentState, runtime: Runtime) -> None:
246
290
247
291
With conditional jumping:
248
292
```python
249
- @before_model(jump_to =["end"])
293
+ @before_model(can_jump_to =["end"])
250
294
def conditional_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
251
295
if some_condition(state):
252
296
return {"jump_to": "end"}
@@ -266,6 +310,10 @@ def decorator(
266
310
) -> AgentMiddleware [StateT , ContextT ]:
267
311
is_async = iscoroutinefunction (func )
268
312
313
+ func_can_jump_to = (
314
+ can_jump_to if can_jump_to is not None else getattr (func , "__can_jump_to__" , [])
315
+ )
316
+
269
317
if is_async :
270
318
271
319
async def async_wrapped (
@@ -275,6 +323,10 @@ async def async_wrapped(
275
323
) -> dict [str , Any ] | Command | None :
276
324
return await func (state , runtime ) # type: ignore[misc]
277
325
326
+ # Preserve can_jump_to metadata on the wrapped function
327
+ if func_can_jump_to :
328
+ async_wrapped .__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
329
+
278
330
middleware_name = name or cast (
279
331
"str" , getattr (func , "__name__" , "BeforeModelMiddleware" )
280
332
)
@@ -285,7 +337,6 @@ async def async_wrapped(
285
337
{
286
338
"state_schema" : state_schema or AgentState ,
287
339
"tools" : tools or [],
288
- "before_model_jump_to" : jump_to or [],
289
340
"abefore_model" : async_wrapped ,
290
341
},
291
342
)()
@@ -297,6 +348,11 @@ def wrapped(
297
348
) -> dict [str , Any ] | Command | None :
298
349
return func (state , runtime ) # type: ignore[return-value]
299
350
351
+ # Preserve can_jump_to metadata on the wrapped function
352
+ if func_can_jump_to :
353
+ wrapped .__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
354
+
355
+ # Use function name as default if no name provided
300
356
middleware_name = name or cast ("str" , getattr (func , "__name__" , "BeforeModelMiddleware" ))
301
357
302
358
return type (
@@ -305,7 +361,6 @@ def wrapped(
305
361
{
306
362
"state_schema" : state_schema or AgentState ,
307
363
"tools" : tools or [],
308
- "before_model_jump_to" : jump_to or [],
309
364
"before_model" : wrapped ,
310
365
},
311
366
)()
@@ -464,7 +519,7 @@ def after_model(
464
519
* ,
465
520
state_schema : type [StateT ] | None = None ,
466
521
tools : list [BaseTool ] | None = None ,
467
- jump_to : list [JumpTo ] | None = None ,
522
+ can_jump_to : list [JumpTo ] | None = None ,
468
523
name : str | None = None ,
469
524
) -> Callable [
470
525
[_CallableWithStateAndRuntime [StateT , ContextT ]], AgentMiddleware [StateT , ContextT ]
@@ -476,7 +531,7 @@ def after_model(
476
531
* ,
477
532
state_schema : type [StateT ] | None = None ,
478
533
tools : list [BaseTool ] | None = None ,
479
- jump_to : list [JumpTo ] | None = None ,
534
+ can_jump_to : list [JumpTo ] | None = None ,
480
535
name : str | None = None ,
481
536
) -> (
482
537
Callable [[_CallableWithStateAndRuntime [StateT , ContextT ]], AgentMiddleware [StateT , ContextT ]]
@@ -490,7 +545,7 @@ def after_model(
490
545
state_schema: Optional custom state schema type. If not provided, uses the default
491
546
AgentState schema.
492
547
tools: Optional list of additional tools to register with this middleware.
493
- jump_to : Optional list of valid jump destinations for conditional edges.
548
+ can_jump_to : Optional list of valid jump destinations for conditional edges.
494
549
Valid values are: "tools", "model", "end"
495
550
name: Optional name for the generated middleware class. If not provided,
496
551
uses the decorated function's name.
@@ -524,6 +579,10 @@ def decorator(
524
579
func : _CallableWithStateAndRuntime [StateT , ContextT ],
525
580
) -> AgentMiddleware [StateT , ContextT ]:
526
581
is_async = iscoroutinefunction (func )
582
+ # Extract can_jump_to from decorator parameter or from function metadata
583
+ func_can_jump_to = (
584
+ can_jump_to if can_jump_to is not None else getattr (func , "__can_jump_to__" , [])
585
+ )
527
586
528
587
if is_async :
529
588
@@ -534,6 +593,10 @@ async def async_wrapped(
534
593
) -> dict [str , Any ] | Command | None :
535
594
return await func (state , runtime ) # type: ignore[misc]
536
595
596
+ # Preserve can_jump_to metadata on the wrapped function
597
+ if func_can_jump_to :
598
+ async_wrapped .__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
599
+
537
600
middleware_name = name or cast ("str" , getattr (func , "__name__" , "AfterModelMiddleware" ))
538
601
539
602
return type (
@@ -542,7 +605,6 @@ async def async_wrapped(
542
605
{
543
606
"state_schema" : state_schema or AgentState ,
544
607
"tools" : tools or [],
545
- "after_model_jump_to" : jump_to or [],
546
608
"aafter_model" : async_wrapped ,
547
609
},
548
610
)()
@@ -554,6 +616,11 @@ def wrapped(
554
616
) -> dict [str , Any ] | Command | None :
555
617
return func (state , runtime ) # type: ignore[return-value]
556
618
619
+ # Preserve can_jump_to metadata on the wrapped function
620
+ if func_can_jump_to :
621
+ wrapped .__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
622
+
623
+ # Use function name as default if no name provided
557
624
middleware_name = name or cast ("str" , getattr (func , "__name__" , "AfterModelMiddleware" ))
558
625
559
626
return type (
@@ -562,7 +629,6 @@ def wrapped(
562
629
{
563
630
"state_schema" : state_schema or AgentState ,
564
631
"tools" : tools or [],
565
- "after_model_jump_to" : jump_to or [],
566
632
"after_model" : wrapped ,
567
633
},
568
634
)()
0 commit comments