Skip to content

Commit 90edcda

Browse files
vblagojesjrlAmnah199
authored
feat: Add FallbackChatGenerator (#9859)
* Add FallbackChatGenerator * Update licence files * Use typing.Optional/Union for Python 3.9 compat * Use the right logger * Lint fix * PR review * Rewrite release note * Add FallbackChatGenerator to docs * Update haystack/components/generators/chat/fallback.py Co-authored-by: Sebastian Husch Lee <[email protected]> * Rename generator -> chat_generators * Lint * Rename generators -> chat_generators in meta, docs, tests * Update haystack/components/generators/chat/fallback.py Co-authored-by: Amna Mubashar <[email protected]> * Update pydocs * Minor pydocs fix --------- Co-authored-by: Sebastian Husch Lee <[email protected]> Co-authored-by: Amna Mubashar <[email protected]>
1 parent a43c47b commit 90edcda

File tree

6 files changed

+589
-0
lines changed

6 files changed

+589
-0
lines changed

docs/pydoc/config/generators_api.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ loaders:
1212
"chat/hugging_face_local",
1313
"chat/hugging_face_api",
1414
"chat/openai",
15+
"chat/fallback",
1516
]
1617
ignore_when_discovered: ["__init__"]
1718
processors:

docs/pydoc/config_docusaurus/generators_api.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ loaders:
1212
"chat/hugging_face_local",
1313
"chat/hugging_face_api",
1414
"chat/openai",
15+
"chat/fallback",
1516
]
1617
ignore_when_discovered: ["__init__"]
1718
processors:

haystack/components/generators/chat/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
"azure": ["AzureOpenAIChatGenerator"],
1313
"hugging_face_local": ["HuggingFaceLocalChatGenerator"],
1414
"hugging_face_api": ["HuggingFaceAPIChatGenerator"],
15+
"fallback": ["FallbackChatGenerator"],
1516
}
1617

1718
if TYPE_CHECKING:
1819
from .azure import AzureOpenAIChatGenerator as AzureOpenAIChatGenerator
20+
from .fallback import FallbackChatGenerator as FallbackChatGenerator
1921
from .hugging_face_api import HuggingFaceAPIChatGenerator as HuggingFaceAPIChatGenerator
2022
from .hugging_face_local import HuggingFaceLocalChatGenerator as HuggingFaceLocalChatGenerator
2123
from .openai import OpenAIChatGenerator as OpenAIChatGenerator
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from __future__ import annotations
6+
7+
import asyncio
8+
from typing import Any, Union
9+
10+
from haystack import component, default_from_dict, default_to_dict, logging
11+
from haystack.components.generators.chat.types import ChatGenerator
12+
from haystack.dataclasses import ChatMessage, StreamingCallbackT
13+
from haystack.tools import Tool, Toolset
14+
from haystack.utils.deserialization import deserialize_component_inplace
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
@component
class FallbackChatGenerator:
    """
    A chat generator wrapper that tries multiple chat generators sequentially.

    It forwards all parameters transparently to the underlying chat generators and returns the first successful result.
    Calls chat generators sequentially until one succeeds. Falls back on any exception raised by a generator.
    If all chat generators fail, it raises a RuntimeError with details.

    Timeout enforcement is fully delegated to the underlying chat generators. The fallback mechanism will only
    work correctly if the underlying chat generators implement proper timeout handling and raise exceptions
    when timeouts occur. For predictable latency guarantees, ensure your chat generators:
    - Support a `timeout` parameter in their initialization
    - Implement timeout as total wall-clock time (shared deadline for both streaming and non-streaming)
    - Raise timeout exceptions (e.g., TimeoutError, asyncio.TimeoutError, httpx.TimeoutException) when exceeded

    Note: Most well-implemented chat generators (OpenAI, Anthropic, Cohere, etc.) support timeout parameters
    with consistent semantics. For HTTP-based LLM providers, a single timeout value (e.g., `timeout=30`)
    typically applies to all connection phases: connection setup, read, write, and pool. For streaming
    responses, read timeout is the maximum gap between chunks. For non-streaming, it's the time limit for
    receiving the complete response.

    Failover is automatically triggered when a generator raises any exception, including:
    - Timeout errors (if the generator implements and raises them)
    - Rate limit errors (429)
    - Authentication errors (401)
    - Context length errors (400)
    - Server errors (500+)
    - Any other exception
    """

    def __init__(self, chat_generators: list[ChatGenerator]):
        """
        Creates an instance of FallbackChatGenerator.

        :param chat_generators: A non-empty list of chat generator components to try in order.
        :raises ValueError: If `chat_generators` is empty or None.
        """
        if not chat_generators:
            msg = "'chat_generators' must be a non-empty list"
            raise ValueError(msg)
        # Defensive copy so later mutation of the caller's list doesn't change the fallback order.
        self.chat_generators = list(chat_generators)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the component, including nested chat generators when they support serialization."""
        # NOTE(review): generators without a `to_dict` are silently omitted from the serialized form,
        # so a from_dict round-trip may yield fewer generators than the live instance had.
        return default_to_dict(
            self, chat_generators=[gen.to_dict() for gen in self.chat_generators if hasattr(gen, "to_dict")]
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> FallbackChatGenerator:
        """
        Rebuild the component from a serialized representation, restoring nested chat generators.

        :param data: The serialized representation produced by `to_dict`.
        :returns: A reconstructed FallbackChatGenerator instance.
        """
        # Reconstruct nested chat generators from their serialized dicts using the generic
        # component deserializer available in Haystack.
        init_params = data.get("init_parameters", {})
        serialized = init_params.get("chat_generators") or []
        deserialized: list[Any] = []
        for g in serialized:
            holder = {"component": g}
            deserialize_component_inplace(holder, key="component")
            deserialized.append(holder["component"])
        init_params["chat_generators"] = deserialized
        data["init_parameters"] = init_params
        return default_from_dict(cls, data)

    def _run_single_sync(  # pylint: disable=too-many-positional-arguments
        self,
        gen: Any,
        messages: list[ChatMessage],
        generation_kwargs: Union[dict[str, Any], None],
        tools: Union[list[Tool], Toolset, None],
        streaming_callback: Union[StreamingCallbackT, None],
    ) -> dict[str, Any]:
        """Invoke one generator synchronously, forwarding all call parameters unchanged."""
        return gen.run(
            messages=messages, generation_kwargs=generation_kwargs, tools=tools, streaming_callback=streaming_callback
        )

    async def _run_single_async(  # pylint: disable=too-many-positional-arguments
        self,
        gen: Any,
        messages: list[ChatMessage],
        generation_kwargs: Union[dict[str, Any], None],
        tools: Union[list[Tool], Toolset, None],
        streaming_callback: Union[StreamingCallbackT, None],
    ) -> dict[str, Any]:
        """Invoke one generator asynchronously, preferring its native `run_async` when available."""
        if hasattr(gen, "run_async") and callable(gen.run_async):
            return await gen.run_async(
                messages=messages,
                generation_kwargs=generation_kwargs,
                tools=tools,
                streaming_callback=streaming_callback,
            )
        # Sync-only generator: run it on a worker thread so the event loop is not blocked.
        return await asyncio.to_thread(
            gen.run,
            messages=messages,
            generation_kwargs=generation_kwargs,
            tools=tools,
            streaming_callback=streaming_callback,
        )

    @staticmethod
    def _success_payload(result: dict[str, Any], idx: int, gen_name: str, failed: list[str]) -> dict[str, Any]:
        """Build the output dict for the first successful generator, enriching its meta with attempt info."""
        meta = dict(result.get("meta", {}))
        meta.update(
            {
                "successful_chat_generator_index": idx,
                "successful_chat_generator_class": gen_name,
                "total_attempts": idx + 1,
                "failed_chat_generators": failed,
            }
        )
        return {"replies": result.get("replies", []), "meta": meta}

    def _all_failed_error(self, failed: list[str], last_error: Union[BaseException, None]) -> RuntimeError:
        """Create the RuntimeError raised when every configured generator has failed."""
        failed_names = ", ".join(failed)
        msg = (
            f"All {len(self.chat_generators)} chat generators failed. "
            f"Last error: {last_error}. Failed chat generators: [{failed_names}]"
        )
        return RuntimeError(msg)

    @component.output_types(replies=list[ChatMessage], meta=dict[str, Any])
    def run(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Union[dict[str, Any], None] = None,
        tools: Union[list[Tool], Toolset, None] = None,
        streaming_callback: Union[StreamingCallbackT, None] = None,
    ) -> dict[str, Any]:
        """
        Execute chat generators sequentially until one succeeds.

        :param messages: The conversation history as a list of ChatMessage instances.
        :param generation_kwargs: Optional parameters for the chat generator (e.g., temperature, max_tokens).
        :param tools: Optional Tool instances or Toolset for function calling capabilities.
        :param streaming_callback: Optional callable for handling streaming responses.
        :returns: A dictionary with:
            - "replies": Generated ChatMessage instances from the first successful generator.
            - "meta": Execution metadata including successful_chat_generator_index, successful_chat_generator_class,
              total_attempts, failed_chat_generators, plus any metadata from the successful generator.
        :raises RuntimeError: If all chat generators fail.
        """
        failed: list[str] = []
        last_error: Union[BaseException, None] = None

        for idx, gen in enumerate(self.chat_generators):
            gen_name = gen.__class__.__name__
            # Narrow try: only the generator invocation is fallback-protected, so a bug in
            # our own result handling is not misreported as a provider failure.
            try:
                result = self._run_single_sync(gen, messages, generation_kwargs, tools, streaming_callback)
            except Exception as e:  # noqa: BLE001 - fallback logic should handle any exception
                logger.warning(
                    "ChatGenerator {chat_generator} failed with error: {error}", chat_generator=gen_name, error=e
                )
                failed.append(gen_name)
                last_error = e
            else:
                return self._success_payload(result, idx, gen_name, failed)

        raise self._all_failed_error(failed, last_error)

    @component.output_types(replies=list[ChatMessage], meta=dict[str, Any])
    async def run_async(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Union[dict[str, Any], None] = None,
        tools: Union[list[Tool], Toolset, None] = None,
        streaming_callback: Union[StreamingCallbackT, None] = None,
    ) -> dict[str, Any]:
        """
        Asynchronously execute chat generators sequentially until one succeeds.

        :param messages: The conversation history as a list of ChatMessage instances.
        :param generation_kwargs: Optional parameters for the chat generator (e.g., temperature, max_tokens).
        :param tools: Optional Tool instances or Toolset for function calling capabilities.
        :param streaming_callback: Optional callable for handling streaming responses.
        :returns: A dictionary with:
            - "replies": Generated ChatMessage instances from the first successful generator.
            - "meta": Execution metadata including successful_chat_generator_index, successful_chat_generator_class,
              total_attempts, failed_chat_generators, plus any metadata from the successful generator.
        :raises RuntimeError: If all chat generators fail.
        """
        failed: list[str] = []
        last_error: Union[BaseException, None] = None

        for idx, gen in enumerate(self.chat_generators):
            gen_name = gen.__class__.__name__
            try:
                result = await self._run_single_async(gen, messages, generation_kwargs, tools, streaming_callback)
            except Exception as e:  # noqa: BLE001 - fallback logic should handle any exception
                logger.warning(
                    "ChatGenerator {chat_generator} failed with error: {error}", chat_generator=gen_name, error=e
                )
                failed.append(gen_name)
                last_error = e
            else:
                return self._success_payload(result, idx, gen_name, failed)

        raise self._all_failed_error(failed, last_error)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
highlights: >
3+
Introduced `FallbackChatGenerator`, which tries multiple chat providers one by one, improving reliability in production and ensuring you still get answers even when a provider fails.
4+
features:
5+
- |
6+
Added `FallbackChatGenerator`, which automatically retries different chat generators and returns the first successful response, with detailed information about which providers were tried.

0 commit comments

Comments
 (0)