2 changes: 0 additions & 2 deletions haystack/components/agents/agent.py
@@ -554,7 +554,6 @@ def run(  # noqa: PLR0915
             - "messages": List of all messages exchanged during the agent's run.
             - "last_message": The last message exchanged during the agent's run.
             - Any additional keys defined in the `state_schema`.
-        :raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`.
         :raises BreakpointException: If an agent breakpoint is triggered.
         """
         # We pop parent_snapshot from kwargs to avoid passing it into State.
@@ -722,7 +721,6 @@ async def run_async(
             - "messages": List of all messages exchanged during the agent's run.
             - "last_message": The last message exchanged during the agent's run.
             - Any additional keys defined in the `state_schema`.
-        :raises RuntimeError: If the Agent component wasn't warmed up before calling `run_async()`.
         :raises BreakpointException: If an agent breakpoint is triggered.
         """
         # We pop parent_snapshot from kwargs to avoid passing it into State.
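With the `:raises RuntimeError:` entries gone from the docstrings, `Agent.run()` and `run_async()` no longer demand a manual warm-up. A minimal sketch of the resulting calling pattern; the model name is illustrative and an OPENAI_API_KEY is assumed, neither comes from this diff:

from haystack.components.agents import Agent
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage

agent = Agent(chat_generator=OpenAIChatGenerator(model="gpt-4o-mini"))
# No explicit agent.warm_up() call: run() now warms the component up lazily.
result = agent.run(messages=[ChatMessage.from_user("What is Haystack?")])
print(result["last_message"].text)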
8 changes: 4 additions & 4 deletions haystack/components/audio/whisper_local.py
@@ -127,9 +127,7 @@ def run(self, sources: list[Union[str, Path, ByteStream]], whisper_params: Optio
             for the transcription.
         """
         if self._model is None:
-            raise RuntimeError(
-                "The component LocalWhisperTranscriber was not warmed up. Run 'warm_up()' before calling 'run()'."
-            )
+            self.warm_up()

         if whisper_params is None:
             whisper_params = self.whisper_params
@@ -172,7 +170,9 @@ def _raw_transcribe(self, sources: list[Union[str, Path, ByteStream]], **kwargs)
             A dictionary mapping 'file_path' to 'transcription'.
         """
         if self._model is None:
-            raise RuntimeError("Model is not loaded, please run 'warm_up()' before calling 'run()'")
+            self.warm_up()
+        # To make mypy happy even though this is set in warm_up()
+        assert self._model is not None

         return_segments = kwargs.pop("return_segments", False)
         transcriptions = {}
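This is the idiom the PR applies across components: `run()` checks whether the model is loaded, warms up lazily if not, then asserts non-None so mypy can narrow the Optional attribute. A condensed, self-contained sketch of the pattern; the class and the stand-in model load are illustrative only:

from typing import Optional

class LazyComponent:
    def __init__(self) -> None:
        self._model: Optional[object] = None

    def warm_up(self) -> None:
        # Load the heavy model exactly once.
        if self._model is None:
            self._model = object()  # stands in for an actual model load

    def run(self) -> None:
        if self._model is None:
            self.warm_up()
        # mypy cannot see that warm_up() sets _model, so narrow explicitly.
        assert self._model is not None

The assert carries no runtime cost on the happy path and keeps the attribute non-Optional for every statement below it.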
@@ -203,10 +203,9 @@ def run(self, documents: list[Document], batch_size: int = 1):
         """

         if self.pipeline is None:
-            raise RuntimeError(
-                "The component TransformerZeroShotDocumentClassifier wasn't warmed up. "
-                "Run 'warm_up()' before calling 'run()'."
-            )
+            self.warm_up()
+        # To make mypy happy even though this is set in warm_up()
+        assert self.pipeline is not None

         if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
             raise TypeError(
@@ -237,7 +237,9 @@ def run(self, documents: list[Document]) -> dict[str, list[Document]]:
                 "In case you want to embed a string, please use the SentenceTransformersTextEmbedder."
             )
         if self._embedding_backend is None:
-            raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")
+            self.warm_up()
+        # To make mypy happy even though this is set in warm_up()
+        assert self._embedding_backend is not None

         images_source_info = _extract_image_sources_info(
             documents=documents, file_path_meta_field=self.file_path_meta_field, root_path=self.root_path
@@ -22,7 +22,7 @@ class SentenceTransformersDocumentEmbedder:
     It stores the embeddings in the `embedding` metadata field of each document.
     You can also embed documents' metadata.
     Use this component in indexing pipelines to embed input documents
-    and send them to DocumentWriter to write a into a Document Store.
+    and send them to DocumentWriter to write into a Document Store.

     ### Usage example:

@@ -244,7 +244,9 @@ def run(self, documents: list[Document]):
                 "In case you want to embed a string, please use the SentenceTransformersTextEmbedder."
             )
         if self.embedding_backend is None:
-            raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")
+            self.warm_up()
+        # To make mypy happy even though this is set in warm_up()
+        assert self.embedding_backend is not None

         texts_to_embed = []
         for doc in documents:
@@ -216,7 +216,9 @@ def run(self, documents: list[Document]):
                 "In case you want to embed a list of strings, please use the SentenceTransformersSparseTextEmbedder."
             )
         if self.embedding_backend is None:
-            raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")
+            self.warm_up()
+        # To make mypy happy even though this is set in warm_up()
+        assert self.embedding_backend is not None

         texts_to_embed = []
         for doc in documents:
@@ -192,7 +192,9 @@ def run(self, text: str):
                 "SentenceTransformersSparseDocumentEmbedder."
             )
         if self.embedding_backend is None:
-            raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")
+            self.warm_up()
+        # To make mypy happy even though this is set in warm_up()
+        assert self.embedding_backend is not None

         text_to_embed = self.prefix + text + self.suffix

@@ -229,7 +229,9 @@ def run(self, text: str):
                 "In case you want to embed a list of Documents, please use the SentenceTransformersDocumentEmbedder."
             )
         if self.embedding_backend is None:
-            raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")
+            self.warm_up()
+        # To make mypy happy even though this is set in warm_up()
+        assert self.embedding_backend is not None

         text_to_embed = self.prefix + text + self.suffix
         embedding = self.embedding_backend.embed(
14 changes: 14 additions & 0 deletions haystack/components/evaluators/llm_evaluator.py
@@ -108,6 +108,17 @@ def __init__(  # pylint: disable=too-many-positional-arguments
             generation_kwargs = {"response_format": {"type": "json_object"}, "seed": 42}
             self._chat_generator = OpenAIChatGenerator(generation_kwargs=generation_kwargs)

+        self._is_warmed_up = False
+
+    def warm_up(self):
+        """
+        Warm up the component by warming up the underlying chat generator.
+        """
+        if not self._is_warmed_up:
+            if hasattr(self._chat_generator, "warm_up"):
+                self._chat_generator.warm_up()
+            self._is_warmed_up = True
+
     @staticmethod
     def validate_init_parameters(
         inputs: list[tuple[str, type[list]]], outputs: list[str], examples: list[dict[str, Any]]
@@ -181,6 +192,9 @@ def run(self, **inputs) -> dict[str, Any]:
            Only in the case that `raise_on_failure` is set to True and the received inputs are not lists or have
            different lengths, or if the output is not a valid JSON or doesn't contain the expected keys.
        """
+        if not self._is_warmed_up:
+            self.warm_up()
+
        self.validate_input_parameters(dict(self.inputs), inputs)

        # inputs is a dictionary with keys being input names and values being a list of input values
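LLMEvaluator wraps an arbitrary chat generator, so its new `warm_up()` delegates only when the wrapped object actually exposes one; API-backed generators may not need warming at all. A sketch of that guard in isolation; the wrapper class name is illustrative:

class Wrapper:
    def __init__(self, chat_generator) -> None:
        self._chat_generator = chat_generator
        self._is_warmed_up = False

    def warm_up(self) -> None:
        if not self._is_warmed_up:
            # Some generators (e.g. API clients) define no warm_up(); skip them.
            if hasattr(self._chat_generator, "warm_up"):
                self._chat_generator.warm_up()
            self._is_warmed_up = True

The `_is_warmed_up` flag makes the call idempotent, so `run()` can invoke it unconditionally on every request.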
3 changes: 1 addition & 2 deletions haystack/components/evaluators/sas_evaluator.py
@@ -164,8 +164,7 @@ def run(self, ground_truth_answers: list[str], predicted_answers: list[str]) ->
             return {"score": 0.0, "individual_scores": [0.0]}

         if not self._similarity_model:
-            msg = "The model has not been initialized. Call warm_up() before running the evaluator."
-            raise RuntimeError(msg)
+            self.warm_up()

         if isinstance(self._similarity_model, CrossEncoder):
             # For Cross Encoders we create a list of pairs of predictions and labels
3 changes: 1 addition & 2 deletions haystack/components/extractors/named_entity_extractor.py
@@ -192,8 +192,7 @@ def run(self, documents: list[Document], batch_size: int = 1) -> dict[str, Any]:
             If the backend fails to process a document.
         """
         if not self._warmed_up:
-            msg = "The component NamedEntityExtractor was not warmed up. Call warm_up() before running the component."
-            raise RuntimeError(msg)
+            self.warm_up()

         texts = [doc.content if doc.content is not None else "" for doc in documents]
         annotations = self._backend.annotate(texts, batch_size=batch_size)
18 changes: 15 additions & 3 deletions haystack/components/generators/chat/fallback.py
@@ -47,7 +47,7 @@ class FallbackChatGenerator:
     - Any other exception
     """

-    def __init__(self, chat_generators: list[ChatGenerator]):
+    def __init__(self, chat_generators: list[ChatGenerator]) -> None:
         """
         Creates an instance of FallbackChatGenerator.

@@ -58,6 +58,7 @@ def __init__(self, chat_generators: list[ChatGenerator]):
             raise ValueError(msg)

         self.chat_generators = list(chat_generators)
+        self._is_warmed_up = False

     def to_dict(self) -> dict[str, Any]:
         """Serialize the component, including nested chat generators when they support serialization."""
@@ -87,10 +88,15 @@ def warm_up(self) -> None:

         This method calls warm_up() on each underlying generator that supports it.
         """
+        if self._is_warmed_up:
+            return
+
         for gen in self.chat_generators:
             if hasattr(gen, "warm_up") and callable(gen.warm_up):
                 gen.warm_up()

+        self._is_warmed_up = True
+
     def _run_single_sync(  # pylint: disable=too-many-positional-arguments
         self,
         gen: Any,
@@ -133,7 +139,7 @@ def run(
         generation_kwargs: Union[dict[str, Any], None] = None,
         tools: Optional[ToolsType] = None,
         streaming_callback: Union[StreamingCallbackT, None] = None,
-    ) -> dict[str, Any]:
+    ) -> dict[str, Union[list[ChatMessage], dict[str, Any]]]:
         """
         Execute chat generators sequentially until one succeeds.

@@ -147,6 +153,9 @@ def run(
            total_attempts, failed_chat_generators, plus any metadata from the successful generator.
        :raises RuntimeError: If all chat generators fail.
        """
+        if not self._is_warmed_up:
+            self.warm_up()
+
        failed: list[str] = []
        last_error: Union[BaseException, None] = None

@@ -186,7 +195,7 @@ async def run_async(
         generation_kwargs: Union[dict[str, Any], None] = None,
         tools: Optional[ToolsType] = None,
         streaming_callback: Union[StreamingCallbackT, None] = None,
-    ) -> dict[str, Any]:
+    ) -> dict[str, Union[list[ChatMessage], dict[str, Any]]]:
         """
         Asynchronously execute chat generators sequentially until one succeeds.

@@ -200,6 +209,9 @@ async def run_async(
            total_attempts, failed_chat_generators, plus any metadata from the successful generator.
        :raises RuntimeError: If all chat generators fail.
        """
+        if not self._is_warmed_up:
+            self.warm_up()
+
        failed: list[str] = []
        last_error: Union[BaseException, None] = None
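Since `run()` and `run_async()` now trigger `warm_up()` themselves, a FallbackChatGenerator can be called directly, and repeated runs stay cheap because the `_is_warmed_up` flag turns later warm-ups into no-ops. A hedged usage sketch; the model names are illustrative and the import path follows this PR's file layout:

from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.generators.chat.fallback import FallbackChatGenerator
from haystack.dataclasses import ChatMessage

gen = FallbackChatGenerator(chat_generators=[
    OpenAIChatGenerator(model="gpt-4o-mini"),
    OpenAIChatGenerator(model="gpt-4o"),  # tried only if the first generator fails
])
# No explicit gen.warm_up(): the first run() performs it lazily.
out = gen.run(messages=[ChatMessage.from_user("Hello!")])
print(out["replies"][0].text)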
8 changes: 6 additions & 2 deletions haystack/components/generators/chat/hugging_face_api.py
@@ -436,7 +436,7 @@ def run(
         generation_kwargs: Optional[dict[str, Any]] = None,
         tools: Optional[ToolsType] = None,
         streaming_callback: Optional[StreamingCallbackT] = None,
-    ):
+    ) -> dict[str, list[ChatMessage]]:
         """
         Invoke the text generation inference based on the provided messages and generation parameters.

@@ -454,6 +454,8 @@ def run(
         :returns: A dictionary with the following keys:
             - `replies`: A list containing the generated responses as ChatMessage objects.
         """
+        if not self._is_warmed_up:
+            self.warm_up()

         # update generation kwargs by merging with the default ones
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
@@ -485,7 +487,7 @@ async def run_async(
         generation_kwargs: Optional[dict[str, Any]] = None,
         tools: Optional[ToolsType] = None,
         streaming_callback: Optional[StreamingCallbackT] = None,
-    ):
+    ) -> dict[str, list[ChatMessage]]:
         """
         Asynchronously invokes the text generation inference based on the provided messages and generation parameters.

@@ -506,6 +508,8 @@ async def run_async(
         :returns: A dictionary with the following keys:
             - `replies`: A list containing the generated responses as ChatMessage objects.
         """
+        if not self._is_warmed_up:
+            self.warm_up()

         # update generation kwargs by merging with the default ones
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
13 changes: 9 additions & 4 deletions haystack/components/generators/chat/hugging_face_local.py
@@ -361,6 +361,9 @@ def run(
         :returns: A dictionary with the following keys:
             - `replies`: A list containing the generated responses as ChatMessage instances.
         """
+        if self.pipeline is None:
+            self.warm_up()
+
         prepared_inputs = self._prepare_inputs(
             messages=messages, generation_kwargs=generation_kwargs, streaming_callback=streaming_callback, tools=tools
         )
@@ -475,6 +478,9 @@ async def run_async(
         :returns: A dictionary with the following keys:
             - `replies`: A list containing the generated responses as ChatMessage instances.
         """
+        if self.pipeline is None:
+            self.warm_up()
+
         prepared_inputs = self._prepare_inputs(
             messages=messages, generation_kwargs=generation_kwargs, streaming_callback=streaming_callback, tools=tools
         )
@@ -542,18 +548,17 @@ def _prepare_inputs(
         :param streaming_callback: An optional callable for handling streaming responses.
         :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
         :returns: A dictionary containing the prepared prompt, tokenizer, generation kwargs, and tools.
-        :raises RuntimeError: If the generation model has not been loaded.
         :raises ValueError: If both tools and streaming_callback are provided.
         """
-        if self.pipeline is None:
-            raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")
-
         tools = tools or self.tools
         if tools and streaming_callback is not None:
             raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
         flat_tools = flatten_tools_or_toolsets(tools)
         _check_duplicate_tool_names(flat_tools)

+        # To make mypy happy even though this is set in warm_up()
+        assert self.pipeline is not None
+
         tokenizer = self.pipeline.tokenizer
         # initialized text-generation/text2text-generation pipelines always have a non-None tokenizer
         assert tokenizer is not None
10 changes: 8 additions & 2 deletions haystack/components/generators/chat/openai.py
@@ -286,7 +286,7 @@ def run(
         *,
         tools: Optional[ToolsType] = None,
         tools_strict: Optional[bool] = None,
-    ):
+    ) -> dict[str, list[ChatMessage]]:
         """
         Invokes chat completion based on the provided messages and generation parameters.

@@ -310,6 +310,9 @@ def run(
             A dictionary with the following key:
             - `replies`: A list containing the generated responses as ChatMessage instances.
         """
+        if not self._is_warmed_up:
+            self.warm_up()
+
         if len(messages) == 0:
             return {"replies": []}

@@ -358,7 +361,7 @@ async def run_async(
         *,
         tools: Optional[ToolsType] = None,
         tools_strict: Optional[bool] = None,
-    ):
+    ) -> dict[str, list[ChatMessage]]:
         """
         Asynchronously invokes chat completion based on the provided messages and generation parameters.

@@ -385,6 +388,9 @@ async def run_async(
             A dictionary with the following key:
             - `replies`: A list containing the generated responses as ChatMessage instances.
         """
+        if not self._is_warmed_up:
+            self.warm_up()
+
         # validate and select the streaming callback
         streaming_callback = select_streaming_callback(
             init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True
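A side benefit of replacing the bare `):` with `) -> dict[str, list[ChatMessage]]:` is that callers become type-checkable; mypy treats a missing return annotation as `Any`, so misuse of the result previously went unnoticed. A small sketch of a call site that now verifies; the model name is illustrative and an API key is assumed:

from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage

gen = OpenAIChatGenerator(model="gpt-4o-mini")
# mypy can now confirm that "replies" holds ChatMessage objects.
replies: list[ChatMessage] = gen.run(messages=[ChatMessage.from_user("Hi")])["replies"]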
10 changes: 8 additions & 2 deletions haystack/components/generators/chat/openai_responses.py
@@ -297,7 +297,7 @@ def run(
         generation_kwargs: Optional[dict[str, Any]] = None,
         tools: Optional[Union[ToolsType, list[dict]]] = None,
         tools_strict: Optional[bool] = None,
-    ):
+    ) -> dict[str, list[ChatMessage]]:
         """
         Invokes response generation based on the provided messages and generation parameters.

@@ -326,6 +326,9 @@ def run(
             A dictionary with the following key:
             - `replies`: A list containing the generated responses as ChatMessage instances.
         """
+        if not self._is_warmed_up:
+            self.warm_up()
+
         if len(messages) == 0:
             return {"replies": []}

@@ -364,7 +367,7 @@ async def run_async(
         generation_kwargs: Optional[dict[str, Any]] = None,
         tools: Optional[Union[ToolsType, list[dict]]] = None,
         tools_strict: Optional[bool] = None,
-    ):
+    ) -> dict[str, list[ChatMessage]]:
         """
         Asynchronously invokes response generation based on the provided messages and generation parameters.

@@ -395,6 +398,9 @@ async def run_async(
             A dictionary with the following key:
             - `replies`: A list containing the generated responses as ChatMessage instances.
         """
+        if not self._is_warmed_up:
+            self.warm_up()
+
         # validate and select the streaming callback
         streaming_callback = select_streaming_callback(
             init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True
4 changes: 1 addition & 3 deletions haystack/components/generators/hugging_face_local.py
@@ -229,9 +229,7 @@ def run(
             - replies: A list of strings representing the generated replies.
         """
         if not self._warmed_up:
-            raise RuntimeError(
-                "The component HuggingFaceLocalGenerator was not warmed up. Please call warm_up() before running."
-            )
+            self.warm_up()

         # at this point, we know that the pipeline has been initialized
         assert self.pipeline is not None