Skip to content

Commit a9aa3f2

Browse files
feat(langchain_v1): Add retry_model_request middleware hook, add ModelFallbackMiddleware (#33275)
- retry_model_request hook lets a middleware decide to retry a failed model request, with full ability to modify as much or as little of the request before doing so - ModelFallbackMiddleware tries each fallback model in order, until one is successful, or fallback list is exhausted Co-authored-by: Sydney Runkle <[email protected]>
1 parent 20514f5 commit a9aa3f2

File tree

6 files changed

+498
-25
lines changed

6 files changed

+498
-25
lines changed

.claude/settings.local.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"permissions": {
33
"allow": [
44
"Bash(uv run:*)",
5-
"Bash(make:*)"
5+
"Bash(make:*)",
66
"WebSearch",
77
"WebFetch(domain:ai.pydantic.dev)",
88
"WebFetch(domain:openai.github.io)",
@@ -12,4 +12,4 @@
1212
"deny": [],
1313
"ask": []
1414
}
15-
}
15+
}

libs/langchain_v1/langchain/agents/middleware/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Middleware plugins for agents."""
22

33
from .human_in_the_loop import HumanInTheLoopMiddleware
4+
from .model_fallback import ModelFallbackMiddleware
45
from .pii import PIIDetectionError, PIIMiddleware
56
from .planning import PlanningMiddleware
67
from .prompt_caching import AnthropicPromptCachingMiddleware
@@ -25,6 +26,7 @@
2526
"AnthropicPromptCachingMiddleware",
2627
"HumanInTheLoopMiddleware",
2728
"LLMToolSelectorMiddleware",
29+
"ModelFallbackMiddleware",
2830
"ModelRequest",
2931
"PIIDetectionError",
3032
"PIIMiddleware",
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
"""Model fallback middleware for agents."""

from __future__ import annotations

from typing import TYPE_CHECKING

from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
from langchain.chat_models import init_chat_model

if TYPE_CHECKING:
    from langchain_core.language_models.chat_models import BaseChatModel
    from langgraph.runtime import Runtime


class ModelFallbackMiddleware(AgentMiddleware):
    """Middleware that provides automatic model fallback on errors.

    This middleware attempts to retry failed model calls with alternative models
    in sequence. When a model call fails, it tries the next model in the fallback
    list until either a call succeeds or all models have been exhausted.

    Example:
        ```python
        from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
        from langchain.agents import create_agent

        # Create middleware with fallback models (not including primary)
        fallback = ModelFallbackMiddleware(
            "openai:gpt-4o-mini",  # First fallback
            "anthropic:claude-3-5-sonnet-20241022",  # Second fallback
        )

        agent = create_agent(
            model="openai:gpt-4o",  # Primary model
            middleware=[fallback],
        )

        # If gpt-4o fails, automatically tries gpt-4o-mini, then claude
        result = await agent.ainvoke({"messages": [HumanMessage("Hello")]})
        ```
    """

    def __init__(
        self,
        first_model: str | BaseChatModel,
        *additional_models: str | BaseChatModel,
    ) -> None:
        """Initialize the model fallback middleware.

        Args:
            first_model: The first fallback model to try when the primary model fails.
                Can be a model name string or BaseChatModel instance.
            *additional_models: Additional fallback models to try, in order.
                Can be model name strings or BaseChatModel instances.
        """
        super().__init__()

        # Coerce model-name strings into chat model instances eagerly, so a bad
        # model name fails at construction time rather than mid-retry.
        self.models: list[BaseChatModel] = [
            init_chat_model(model) if isinstance(model, str) else model
            for model in (first_model, *additional_models)
        ]

    def retry_model_request(
        self,
        error: Exception,  # noqa: ARG002
        request: ModelRequest,
        state: AgentState,  # noqa: ARG002
        runtime: Runtime,  # noqa: ARG002
        attempt: int,
    ) -> ModelRequest | None:
        """Retry with the next fallback model.

        Args:
            error: The exception that occurred during model invocation.
            request: The original model request that failed.
            state: The current agent state.
            runtime: The langgraph runtime.
            attempt: The current attempt number (1-indexed).

        Returns:
            ModelRequest with the next fallback model, or None if all models exhausted.
        """
        # attempt 1 = primary model failed, so models[0] is the first fallback.
        fallback_index = attempt - 1
        if fallback_index >= len(self.models):
            # All fallback models exhausted: returning None re-raises the error.
            return None
        # Rebind the failed request to the next fallback model and retry.
        request.model = self.models[fallback_index]
        return request

libs/langchain_v1/langchain/agents/middleware/types.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,54 @@ async def aafter_model(
167167
) -> dict[str, Any] | None:
168168
"""Async logic to run after the model is called."""
169169

170+
def retry_model_request(
    self,
    error: Exception,  # noqa: ARG002
    request: ModelRequest,  # noqa: ARG002
    state: StateT,  # noqa: ARG002
    runtime: Runtime[ContextT],  # noqa: ARG002
    attempt: int,  # noqa: ARG002
) -> ModelRequest | None:
    """Hook for handling model invocation errors, with an optional retry.

    The default implementation never retries: it returns ``None``, which tells
    the agent to propagate the error. Subclasses override this hook to inspect
    the failure and return a (possibly modified) request to try again with.

    Args:
        error: The exception that occurred during model invocation.
        request: The original model request that failed.
        state: The current agent state.
        runtime: The langgraph runtime.
        attempt: The current attempt number (1-indexed).

    Returns:
        ModelRequest: Modified request to retry with.
        None: Propagate the error (re-raise).
    """
    return None
192+
193+
async def aretry_model_request(
    self,
    error: Exception,
    request: ModelRequest,
    state: StateT,
    runtime: Runtime[ContextT],
    attempt: int,
) -> ModelRequest | None:
    """Async counterpart of ``retry_model_request``.

    By default this delegates to the synchronous ``retry_model_request`` hook
    on an executor thread, so a middleware only needs to override one of the
    two variants to customize retry behavior.

    Args:
        error: The exception that occurred during model invocation.
        request: The original model request that failed.
        state: The current agent state.
        runtime: The langgraph runtime.
        attempt: The current attempt number (1-indexed).

    Returns:
        ModelRequest: Modified request to retry with.
        None: Propagate the error (re-raise).
    """
    # run_in_executor keeps a blocking sync hook off the event loop.
    return await run_in_executor(
        None, self.retry_model_request, error, request, state, runtime, attempt
    )
217+
170218

171219
class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
172220
"""Callable with AgentState and Runtime as arguments."""

libs/langchain_v1/langchain/agents/middleware_agent.py

Lines changed: 79 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,12 @@ def create_agent( # noqa: PLR0915
278278
if m.__class__.after_model is not AgentMiddleware.after_model
279279
or m.__class__.aafter_model is not AgentMiddleware.aafter_model
280280
]
281+
middleware_w_retry = [
282+
m
283+
for m in middleware
284+
if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request
285+
or m.__class__.aretry_model_request is not AgentMiddleware.aretry_model_request
286+
]
281287

282288
state_schemas = {m.state_schema for m in middleware}
283289
state_schemas.add(AgentState)
@@ -526,18 +532,47 @@ def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, An
526532
)
527533
raise TypeError(msg)
528534

529-
# Get the bound model (with auto-detection if needed)
530-
model_, effective_response_format = _get_bound_model(request)
531-
messages = request.messages
532-
if request.system_prompt:
533-
messages = [SystemMessage(request.system_prompt), *messages]
535+
# Retry loop for model invocation with error handling
536+
# Hard limit of 100 attempts to prevent infinite loops from buggy middleware
537+
max_attempts = 100
538+
for attempt in range(1, max_attempts + 1):
539+
try:
540+
# Get the bound model (with auto-detection if needed)
541+
model_, effective_response_format = _get_bound_model(request)
542+
messages = request.messages
543+
if request.system_prompt:
544+
messages = [SystemMessage(request.system_prompt), *messages]
545+
546+
output = model_.invoke(messages)
547+
return {
548+
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
549+
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
550+
**_handle_model_output(output, effective_response_format),
551+
}
552+
except Exception as error:
553+
# Try retry_model_request on each middleware
554+
for m in middleware_w_retry:
555+
if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request:
556+
if retry_request := m.retry_model_request(
557+
error, request, state, runtime, attempt
558+
):
559+
# Break on first middleware that wants to retry
560+
request = retry_request
561+
break
562+
else:
563+
msg = (
564+
f"No synchronous function provided for "
565+
f'{m.__class__.__name__}.aretry_model_request".'
566+
"\nEither initialize with a synchronous function or invoke"
567+
" via the async API (ainvoke, astream, etc.)"
568+
)
569+
raise TypeError(msg)
570+
else:
571+
raise
534572

535-
output = model_.invoke(messages)
536-
return {
537-
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
538-
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
539-
**_handle_model_output(output, effective_response_format),
540-
}
573+
# If we exit the loop, max attempts exceeded
574+
msg = f"Maximum retry attempts ({max_attempts}) exceeded"
575+
raise RuntimeError(msg)
541576

542577
async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
543578
"""Async model request handler with sequential middleware processing."""
@@ -554,18 +589,39 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
554589
for m in middleware_w_modify_model_request:
555590
await m.amodify_model_request(request, state, runtime)
556591

557-
# Get the bound model (with auto-detection if needed)
558-
model_, effective_response_format = _get_bound_model(request)
559-
messages = request.messages
560-
if request.system_prompt:
561-
messages = [SystemMessage(request.system_prompt), *messages]
562-
563-
output = await model_.ainvoke(messages)
564-
return {
565-
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
566-
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
567-
**_handle_model_output(output, effective_response_format),
568-
}
592+
# Retry loop for model invocation with error handling
593+
# Hard limit of 100 attempts to prevent infinite loops from buggy middleware
594+
max_attempts = 100
595+
for attempt in range(1, max_attempts + 1):
596+
try:
597+
# Get the bound model (with auto-detection if needed)
598+
model_, effective_response_format = _get_bound_model(request)
599+
messages = request.messages
600+
if request.system_prompt:
601+
messages = [SystemMessage(request.system_prompt), *messages]
602+
603+
output = await model_.ainvoke(messages)
604+
return {
605+
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
606+
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
607+
**_handle_model_output(output, effective_response_format),
608+
}
609+
except Exception as error:
610+
# Try retry_model_request on each middleware
611+
for m in middleware_w_retry:
612+
if retry_request := await m.aretry_model_request(
613+
error, request, state, runtime, attempt
614+
):
615+
# Break on first middleware that wants to retry
616+
request = retry_request
617+
break
618+
else:
619+
# If no middleware wants to retry, re-raise the error
620+
raise
621+
622+
# If we exit the loop, max attempts exceeded
623+
msg = f"Maximum retry attempts ({max_attempts}) exceeded"
624+
raise RuntimeError(msg)
569625

570626
# Use sync or async based on model capabilities
571627
from langgraph._internal._runnable import RunnableCallable

0 commit comments

Comments
 (0)