@@ -278,6 +278,12 @@ def create_agent( # noqa: PLR0915
278
278
if m .__class__ .after_model is not AgentMiddleware .after_model
279
279
or m .__class__ .aafter_model is not AgentMiddleware .aafter_model
280
280
]
281
+ middleware_w_retry = [
282
+ m
283
+ for m in middleware
284
+ if m .__class__ .retry_model_request is not AgentMiddleware .retry_model_request
285
+ or m .__class__ .aretry_model_request is not AgentMiddleware .aretry_model_request
286
+ ]
281
287
282
288
state_schemas = {m .state_schema for m in middleware }
283
289
state_schemas .add (AgentState )
@@ -526,18 +532,47 @@ def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, An
526
532
)
527
533
raise TypeError (msg )
528
534
529
- # Get the bound model (with auto-detection if needed)
530
- model_ , effective_response_format = _get_bound_model (request )
531
- messages = request .messages
532
- if request .system_prompt :
533
- messages = [SystemMessage (request .system_prompt ), * messages ]
535
+ # Retry loop for model invocation with error handling
536
+ # Hard limit of 100 attempts to prevent infinite loops from buggy middleware
537
+ max_attempts = 100
538
+ for attempt in range (1 , max_attempts + 1 ):
539
+ try :
540
+ # Get the bound model (with auto-detection if needed)
541
+ model_ , effective_response_format = _get_bound_model (request )
542
+ messages = request .messages
543
+ if request .system_prompt :
544
+ messages = [SystemMessage (request .system_prompt ), * messages ]
545
+
546
+ output = model_ .invoke (messages )
547
+ return {
548
+ "thread_model_call_count" : state .get ("thread_model_call_count" , 0 ) + 1 ,
549
+ "run_model_call_count" : state .get ("run_model_call_count" , 0 ) + 1 ,
550
+ ** _handle_model_output (output , effective_response_format ),
551
+ }
552
+ except Exception as error :
553
+ # Try retry_model_request on each middleware
554
+ for m in middleware_w_retry :
555
+ if m .__class__ .retry_model_request is not AgentMiddleware .retry_model_request :
556
+ if retry_request := m .retry_model_request (
557
+ error , request , state , runtime , attempt
558
+ ):
559
+ # Break on first middleware that wants to retry
560
+ request = retry_request
561
+ break
562
+ else :
563
+ msg = (
564
+ f"No synchronous function provided for "
565
+ f'{ m .__class__ .__name__ } .aretry_model_request".'
566
+ "\n Either initialize with a synchronous function or invoke"
567
+ " via the async API (ainvoke, astream, etc.)"
568
+ )
569
+ raise TypeError (msg )
570
+ else :
571
+ raise
534
572
535
- output = model_ .invoke (messages )
536
- return {
537
- "thread_model_call_count" : state .get ("thread_model_call_count" , 0 ) + 1 ,
538
- "run_model_call_count" : state .get ("run_model_call_count" , 0 ) + 1 ,
539
- ** _handle_model_output (output , effective_response_format ),
540
- }
573
+ # The loop only runs to completion when every attempt failed and a middleware requested a retry each time, i.e. max attempts exceeded
574
+ msg = f"Maximum retry attempts ({ max_attempts } ) exceeded"
575
+ raise RuntimeError (msg )
541
576
542
577
async def amodel_request (state : AgentState , runtime : Runtime [ContextT ]) -> dict [str , Any ]:
543
578
"""Async model request handler with sequential middleware processing."""
@@ -554,18 +589,39 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
554
589
for m in middleware_w_modify_model_request :
555
590
await m .amodify_model_request (request , state , runtime )
556
591
557
- # Get the bound model (with auto-detection if needed)
558
- model_ , effective_response_format = _get_bound_model (request )
559
- messages = request .messages
560
- if request .system_prompt :
561
- messages = [SystemMessage (request .system_prompt ), * messages ]
562
-
563
- output = await model_ .ainvoke (messages )
564
- return {
565
- "thread_model_call_count" : state .get ("thread_model_call_count" , 0 ) + 1 ,
566
- "run_model_call_count" : state .get ("run_model_call_count" , 0 ) + 1 ,
567
- ** _handle_model_output (output , effective_response_format ),
568
- }
592
+ # Retry loop for model invocation with error handling
593
+ # Hard limit of 100 attempts to prevent infinite loops from buggy middleware
594
+ max_attempts = 100
595
+ for attempt in range (1 , max_attempts + 1 ):
596
+ try :
597
+ # Get the bound model (with auto-detection if needed)
598
+ model_ , effective_response_format = _get_bound_model (request )
599
+ messages = request .messages
600
+ if request .system_prompt :
601
+ messages = [SystemMessage (request .system_prompt ), * messages ]
602
+
603
+ output = await model_ .ainvoke (messages )
604
+ return {
605
+ "thread_model_call_count" : state .get ("thread_model_call_count" , 0 ) + 1 ,
606
+ "run_model_call_count" : state .get ("run_model_call_count" , 0 ) + 1 ,
607
+ ** _handle_model_output (output , effective_response_format ),
608
+ }
609
+ except Exception as error :
610
+ # Try aretry_model_request on each middleware
611
+ for m in middleware_w_retry :
612
+ if retry_request := await m .aretry_model_request (
613
+ error , request , state , runtime , attempt
614
+ ):
615
+ # Break on first middleware that wants to retry
616
+ request = retry_request
617
+ break
618
+ else :
619
+ # If no middleware wants to retry, re-raise the error
620
+ raise
621
+
622
+ # The loop only runs to completion when every attempt failed and a middleware requested a retry each time, i.e. max attempts exceeded
623
+ msg = f"Maximum retry attempts ({ max_attempts } ) exceeded"
624
+ raise RuntimeError (msg )
569
625
570
626
# Use sync or async based on model capabilities
571
627
from langgraph ._internal ._runnable import RunnableCallable
0 commit comments