diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index b5e5b1a1e4..68c9fc0907 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -554,7 +554,6 @@ def run( # noqa: PLR0915 - "messages": List of all messages exchanged during the agent's run. - "last_message": The last message exchanged during the agent's run. - Any additional keys defined in the `state_schema`. - :raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`. :raises BreakpointException: If an agent breakpoint is triggered. """ # We pop parent_snapshot from kwargs to avoid passing it into State. @@ -722,7 +721,6 @@ async def run_async( - "messages": List of all messages exchanged during the agent's run. - "last_message": The last message exchanged during the agent's run. - Any additional keys defined in the `state_schema`. - :raises RuntimeError: If the Agent component wasn't warmed up before calling `run_async()`. :raises BreakpointException: If an agent breakpoint is triggered. """ # We pop parent_snapshot from kwargs to avoid passing it into State. diff --git a/haystack/components/audio/whisper_local.py b/haystack/components/audio/whisper_local.py index 66bf2a3d0d..cb1d332b99 100644 --- a/haystack/components/audio/whisper_local.py +++ b/haystack/components/audio/whisper_local.py @@ -127,9 +127,7 @@ def run(self, sources: list[Union[str, Path, ByteStream]], whisper_params: Optio for the transcription. """ if self._model is None: - raise RuntimeError( - "The component LocalWhisperTranscriber was not warmed up. Run 'warm_up()' before calling 'run()'." - ) + self.warm_up() if whisper_params is None: whisper_params = self.whisper_params @@ -158,6 +156,20 @@ def transcribe(self, sources: list[Union[str, Path, ByteStream]], **kwargs) -> l documents.append(doc) return documents + def _get_path(self, source: Union[str, Path, ByteStream]) -> Path: + """Return a local file path for the source, writing a ByteStream without a 'file_path' meta entry to a temporary file.""" + if isinstance(source, (Path, str)): + return Path(source) + + potential_path = source.meta.get("file_path") + if potential_path is not None: + return Path(potential_path) + else: + with tempfile.NamedTemporaryFile(delete=False) as fp: + path = Path(fp.name) + source.to_file(path) + return path + def _raw_transcribe(self, sources: list[Union[str, Path, ByteStream]], **kwargs) -> dict[Path, Any]: """ Transcribes the given audio files. Returns the output of the model, a dictionary, for each input file. @@ -172,20 +183,16 @@ def _raw_transcribe(self, sources: list[Union[str, Path, ByteStream]], **kwargs) A dictionary mapping 'file_path' to 'transcription'. 
""" if self._model is None: - raise RuntimeError("Model is not loaded, please run 'warm_up()' before calling 'run()'") + self.warm_up() return_segments = kwargs.pop("return_segments", False) transcriptions = {} for source in sources: - path = Path(source) if not isinstance(source, ByteStream) else source.meta.get("file_path") - - if isinstance(source, ByteStream) and path is None: - with tempfile.NamedTemporaryFile(delete=False) as fp: - path = Path(fp.name) - source.to_file(path) + path = self._get_path(source) - transcription = self._model.transcribe(str(path), **kwargs) + # mypy doesn't know this is set in warm_up + transcription = self._model.transcribe(str(path), **kwargs) # type: ignore[attr-defined] if not return_segments: transcription.pop("segments", None) diff --git a/haystack/components/classifiers/zero_shot_document_classifier.py b/haystack/components/classifiers/zero_shot_document_classifier.py index 94794e816b..0d1d788cd4 100644 --- a/haystack/components/classifiers/zero_shot_document_classifier.py +++ b/haystack/components/classifiers/zero_shot_document_classifier.py @@ -203,10 +203,7 @@ def run(self, documents: list[Document], batch_size: int = 1): """ if self.pipeline is None: - raise RuntimeError( - "The component TransformerZeroShotDocumentClassifier wasn't warmed up. " - "Run 'warm_up()' before calling 'run()'." - ) + self.warm_up() if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): raise TypeError( @@ -231,7 +228,10 @@ def run(self, documents: list[Document], batch_size: int = 1): for doc in documents ] - predictions = self.pipeline(texts, self.labels, multi_label=self.multi_label, batch_size=batch_size) + # mypy doesn't know this is set in warm_up + predictions = self.pipeline( # type: ignore[misc] + texts, self.labels, multi_label=self.multi_label, batch_size=batch_size + ) new_documents = [] for prediction, document in zip(predictions, documents): diff --git a/haystack/components/embedders/image/sentence_transformers_doc_image_embedder.py b/haystack/components/embedders/image/sentence_transformers_doc_image_embedder.py index a78eb9bab8..f174cd3f25 100644 --- a/haystack/components/embedders/image/sentence_transformers_doc_image_embedder.py +++ b/haystack/components/embedders/image/sentence_transformers_doc_image_embedder.py @@ -237,7 +237,7 @@ def run(self, documents: list[Document]) -> dict[str, list[Document]]: "In case you want to embed a string, please use the SentenceTransformersTextEmbedder." ) if self._embedding_backend is None: - raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.") + self.warm_up() images_source_info = _extract_image_sources_info( documents=documents, file_path_meta_field=self.file_path_meta_field, root_path=self.root_path @@ -270,7 +270,8 @@ def run(self, documents: list[Document]) -> dict[str, list[Document]]: if none_images_doc_ids: raise RuntimeError(f"Conversion failed for some documents. 
Document IDs: {none_images_doc_ids}.") - embeddings = self._embedding_backend.embed( + # mypy doesn't know this is set in warm_up + embeddings = self._embedding_backend.embed( # type: ignore[union-attr] data=images_to_embed, batch_size=self.batch_size, show_progress_bar=self.progress_bar, diff --git a/haystack/components/embedders/sentence_transformers_document_embedder.py b/haystack/components/embedders/sentence_transformers_document_embedder.py index de92d8184b..3ed1e7b12a 100644 --- a/haystack/components/embedders/sentence_transformers_document_embedder.py +++ b/haystack/components/embedders/sentence_transformers_document_embedder.py @@ -22,7 +22,7 @@ class SentenceTransformersDocumentEmbedder: It stores the embeddings in the `embedding` metadata field of each document. You can also embed documents' metadata. Use this component in indexing pipelines to embed input documents - and send them to DocumentWriter to write a into a Document Store. + and send them to DocumentWriter to write into a Document Store. ### Usage example: @@ -244,7 +244,7 @@ def run(self, documents: list[Document]): "In case you want to embed a string, please use the SentenceTransformersTextEmbedder." ) if self.embedding_backend is None: - raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.") + self.warm_up() texts_to_embed = [] for doc in documents: @@ -256,7 +256,8 @@ def run(self, documents: list[Document]): ) texts_to_embed.append(text_to_embed) - embeddings = self.embedding_backend.embed( + # # mypy doesn't know this is set in warm_up + embeddings = self.embedding_backend.embed( # type: ignore[union-attr] texts_to_embed, batch_size=self.batch_size, show_progress_bar=self.progress_bar, diff --git a/haystack/components/embedders/sentence_transformers_sparse_document_embedder.py b/haystack/components/embedders/sentence_transformers_sparse_document_embedder.py index cc48574730..22b087254d 100644 --- a/haystack/components/embedders/sentence_transformers_sparse_document_embedder.py +++ b/haystack/components/embedders/sentence_transformers_sparse_document_embedder.py @@ -216,7 +216,7 @@ def run(self, documents: list[Document]): "In case you want to embed a list of strings, please use the SentenceTransformersSparseTextEmbedder." ) if self.embedding_backend is None: - raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.") + self.warm_up() texts_to_embed = [] for doc in documents: @@ -228,7 +228,8 @@ def run(self, documents: list[Document]): ) texts_to_embed.append(text_to_embed) - embeddings = self.embedding_backend.embed( + # mypy doesn't know this is set in warm_up + embeddings = self.embedding_backend.embed( # type: ignore[union-attr] data=texts_to_embed, batch_size=self.batch_size, show_progress_bar=self.progress_bar ) diff --git a/haystack/components/embedders/sentence_transformers_sparse_text_embedder.py b/haystack/components/embedders/sentence_transformers_sparse_text_embedder.py index ab8da5f0e0..314ba09eab 100644 --- a/haystack/components/embedders/sentence_transformers_sparse_text_embedder.py +++ b/haystack/components/embedders/sentence_transformers_sparse_text_embedder.py @@ -192,10 +192,11 @@ def run(self, text: str): "SentenceTransformersSparseDocumentEmbedder." ) if self.embedding_backend is None: - raise RuntimeError("The embedding model has not been loaded. 
Please call warm_up() before running.") + self.warm_up() text_to_embed = self.prefix + text + self.suffix - sparse_embedding = self.embedding_backend.embed(data=[text_to_embed])[0] + # mypy doesn't know this is set in warm_up + sparse_embedding = self.embedding_backend.embed(data=[text_to_embed])[0] # type: ignore[union-attr] return {"sparse_embedding": sparse_embedding} diff --git a/haystack/components/embedders/sentence_transformers_text_embedder.py b/haystack/components/embedders/sentence_transformers_text_embedder.py index da1e7a4b38..4924daeaf3 100644 --- a/haystack/components/embedders/sentence_transformers_text_embedder.py +++ b/haystack/components/embedders/sentence_transformers_text_embedder.py @@ -229,10 +229,12 @@ def run(self, text: str): "In case you want to embed a list of Documents, please use the SentenceTransformersDocumentEmbedder." ) if self.embedding_backend is None: - raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.") + self.warm_up() text_to_embed = self.prefix + text + self.suffix - embedding = self.embedding_backend.embed( + + # mypy doesn't know this is set in warm_up + embedding = self.embedding_backend.embed( # type: ignore[union-attr] [text_to_embed], batch_size=self.batch_size, show_progress_bar=self.progress_bar, diff --git a/haystack/components/evaluators/llm_evaluator.py b/haystack/components/evaluators/llm_evaluator.py index 878e5636cd..367703fd75 100644 --- a/haystack/components/evaluators/llm_evaluator.py +++ b/haystack/components/evaluators/llm_evaluator.py @@ -108,6 +108,17 @@ def __init__( # pylint: disable=too-many-positional-arguments generation_kwargs = {"response_format": {"type": "json_object"}, "seed": 42} self._chat_generator = OpenAIChatGenerator(generation_kwargs=generation_kwargs) + self._is_warmed_up = False + + def warm_up(self): + """ + Warm up the component by warming up the underlying chat generator. + """ + if not self._is_warmed_up: + if hasattr(self._chat_generator, "warm_up"): + self._chat_generator.warm_up() + self._is_warmed_up = True + @staticmethod def validate_init_parameters( inputs: list[tuple[str, type[list]]], outputs: list[str], examples: list[dict[str, Any]] @@ -181,6 +192,9 @@ def run(self, **inputs) -> dict[str, Any]: Only in the case that `raise_on_failure` is set to True and the received inputs are not lists or have different lengths, or if the output is not a valid JSON or doesn't contain the expected keys. """ + if not self._is_warmed_up: + self.warm_up() + self.validate_input_parameters(dict(self.inputs), inputs) # inputs is a dictionary with keys being input names and values being a list of input values diff --git a/haystack/components/evaluators/sas_evaluator.py b/haystack/components/evaluators/sas_evaluator.py index b2d7d91dbf..13bfc239b5 100644 --- a/haystack/components/evaluators/sas_evaluator.py +++ b/haystack/components/evaluators/sas_evaluator.py @@ -164,8 +164,7 @@ def run(self, ground_truth_answers: list[str], predicted_answers: list[str]) -> return {"score": 0.0, "individual_scores": [0.0]} if not self._similarity_model: - msg = "The model has not been initialized. Call warm_up() before running the evaluator." 
- raise RuntimeError(msg) + self.warm_up() if isinstance(self._similarity_model, CrossEncoder): # For Cross Encoders we create a list of pairs of predictions and labels diff --git a/haystack/components/extractors/named_entity_extractor.py b/haystack/components/extractors/named_entity_extractor.py index 4d45e9b368..06039cc099 100644 --- a/haystack/components/extractors/named_entity_extractor.py +++ b/haystack/components/extractors/named_entity_extractor.py @@ -192,8 +192,7 @@ def run(self, documents: list[Document], batch_size: int = 1) -> dict[str, Any]: If the backend fails to process a document. """ if not self._warmed_up: - msg = "The component NamedEntityExtractor was not warmed up. Call warm_up() before running the component." - raise RuntimeError(msg) + self.warm_up() texts = [doc.content if doc.content is not None else "" for doc in documents] annotations = self._backend.annotate(texts, batch_size=batch_size) diff --git a/haystack/components/generators/chat/fallback.py b/haystack/components/generators/chat/fallback.py index 26f4f3a84f..0d216dc6f5 100644 --- a/haystack/components/generators/chat/fallback.py +++ b/haystack/components/generators/chat/fallback.py @@ -47,7 +47,7 @@ class FallbackChatGenerator: - Any other exception """ - def __init__(self, chat_generators: list[ChatGenerator]): + def __init__(self, chat_generators: list[ChatGenerator]) -> None: """ Creates an instance of FallbackChatGenerator. @@ -58,6 +58,7 @@ def __init__(self, chat_generators: list[ChatGenerator]): raise ValueError(msg) self.chat_generators = list(chat_generators) + self._is_warmed_up = False def to_dict(self) -> dict[str, Any]: """Serialize the component, including nested chat generators when they support serialization.""" @@ -87,10 +88,15 @@ def warm_up(self) -> None: This method calls warm_up() on each underlying generator that supports it. """ + if self._is_warmed_up: + return + for gen in self.chat_generators: if hasattr(gen, "warm_up") and callable(gen.warm_up): gen.warm_up() + self._is_warmed_up = True + def _run_single_sync( # pylint: disable=too-many-positional-arguments self, gen: Any, @@ -133,7 +139,7 @@ def run( generation_kwargs: Union[dict[str, Any], None] = None, tools: Optional[ToolsType] = None, streaming_callback: Union[StreamingCallbackT, None] = None, - ) -> dict[str, Any]: + ) -> dict[str, Union[list[ChatMessage], dict[str, Any]]]: """ Execute chat generators sequentially until one succeeds. @@ -147,6 +153,9 @@ def run( total_attempts, failed_chat_generators, plus any metadata from the successful generator. :raises RuntimeError: If all chat generators fail. """ + if not self._is_warmed_up: + self.warm_up() + failed: list[str] = [] last_error: Union[BaseException, None] = None @@ -186,7 +195,7 @@ async def run_async( generation_kwargs: Union[dict[str, Any], None] = None, tools: Optional[ToolsType] = None, streaming_callback: Union[StreamingCallbackT, None] = None, - ) -> dict[str, Any]: + ) -> dict[str, Union[list[ChatMessage], dict[str, Any]]]: """ Asynchronously execute chat generators sequentially until one succeeds. @@ -200,6 +209,9 @@ async def run_async( total_attempts, failed_chat_generators, plus any metadata from the successful generator. :raises RuntimeError: If all chat generators fail. 
""" + if not self._is_warmed_up: + self.warm_up() + failed: list[str] = [] last_error: Union[BaseException, None] = None diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index fb6d6dbb3b..0caf6d2eb7 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -436,7 +436,7 @@ def run( generation_kwargs: Optional[dict[str, Any]] = None, tools: Optional[ToolsType] = None, streaming_callback: Optional[StreamingCallbackT] = None, - ): + ) -> dict[str, list[ChatMessage]]: """ Invoke the text generation inference based on the provided messages and generation parameters. @@ -454,6 +454,8 @@ def run( :returns: A dictionary with the following keys: - `replies`: A list containing the generated responses as ChatMessage objects. """ + if not self._is_warmed_up: + self.warm_up() # update generation kwargs by merging with the default ones generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} @@ -485,7 +487,7 @@ async def run_async( generation_kwargs: Optional[dict[str, Any]] = None, tools: Optional[ToolsType] = None, streaming_callback: Optional[StreamingCallbackT] = None, - ): + ) -> dict[str, list[ChatMessage]]: """ Asynchronously invokes the text generation inference based on the provided messages and generation parameters. @@ -506,6 +508,8 @@ async def run_async( :returns: A dictionary with the following keys: - `replies`: A list containing the generated responses as ChatMessage objects. """ + if not self._is_warmed_up: + self.warm_up() # update generation kwargs by merging with the default ones generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index 2296b4f7f1..d1661245b8 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -361,6 +361,9 @@ def run( :returns: A dictionary with the following keys: - `replies`: A list containing the generated responses as ChatMessage instances. """ + if self.pipeline is None: + self.warm_up() + prepared_inputs = self._prepare_inputs( messages=messages, generation_kwargs=generation_kwargs, streaming_callback=streaming_callback, tools=tools ) @@ -475,6 +478,9 @@ async def run_async( :returns: A dictionary with the following keys: - `replies`: A list containing the generated responses as ChatMessage instances. """ + if self.pipeline is None: + self.warm_up() + prepared_inputs = self._prepare_inputs( messages=messages, generation_kwargs=generation_kwargs, streaming_callback=streaming_callback, tools=tools ) @@ -542,21 +548,16 @@ def _prepare_inputs( :param streaming_callback: An optional callable for handling streaming responses. :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls. :returns: A dictionary containing the prepared prompt, tokenizer, generation kwargs, and tools. - :raises RuntimeError: If the generation model has not been loaded. :raises ValueError: If both tools and streaming_callback are provided. """ - if self.pipeline is None: - raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.") - tools = tools or self.tools if tools and streaming_callback is not None: raise ValueError("Using tools and streaming at the same time is not supported. 
Please choose one.") flat_tools = flatten_tools_or_toolsets(tools) _check_duplicate_tool_names(flat_tools) - tokenizer = self.pipeline.tokenizer - # initialized text-generation/text2text-generation pipelines always have a non-None tokenizer - assert tokenizer is not None + # mypy doesn't know this is set in warm_up + tokenizer = self.pipeline.tokenizer # type: ignore[union-attr] # Check and update generation parameters generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} @@ -577,14 +578,23 @@ def _prepare_inputs( stop_words = self._validate_stop_words(stop_words) # Set up stop words criteria if stop words exist - stop_words_criteria = StopWordsCriteria(tokenizer, stop_words, self.pipeline.device) if stop_words else None + stop_words_criteria = ( + StopWordsCriteria( + tokenizer, # type: ignore[arg-type] + stop_words, + self.pipeline.device, # type: ignore[union-attr] + ) + if stop_words + else None + ) if stop_words_criteria: generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria]) # convert messages to HF format hf_messages = [convert_message_to_hf_format(message) for message in messages] - prepared_prompt = tokenizer.apply_chat_template( + # mypy doesn't know tokenizer is set in warm_up + prepared_prompt = tokenizer.apply_chat_template( # type: ignore[union-attr] hf_messages, tokenize=False, chat_template=self.chat_template, @@ -595,8 +605,9 @@ def _prepare_inputs( assert isinstance(prepared_prompt, str) # Avoid some unnecessary warnings in the generation pipeline call + # mypy doesn't know tokenizer is set in warm_up generation_kwargs["pad_token_id"] = ( - generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id + generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id # type: ignore[union-attr] ) return { diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index cd227e7aba..3c06df1112 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -286,7 +286,7 @@ def run( *, tools: Optional[ToolsType] = None, tools_strict: Optional[bool] = None, - ): + ) -> dict[str, list[ChatMessage]]: """ Invokes chat completion based on the provided messages and generation parameters. @@ -310,6 +310,9 @@ def run( A dictionary with the following key: - `replies`: A list containing the generated responses as ChatMessage instances. """ + if not self._is_warmed_up: + self.warm_up() + if len(messages) == 0: return {"replies": []} @@ -358,7 +361,7 @@ async def run_async( *, tools: Optional[ToolsType] = None, tools_strict: Optional[bool] = None, - ): + ) -> dict[str, list[ChatMessage]]: """ Asynchronously invokes chat completion based on the provided messages and generation parameters. @@ -385,6 +388,9 @@ async def run_async( A dictionary with the following key: - `replies`: A list containing the generated responses as ChatMessage instances. 
""" + if not self._is_warmed_up: + self.warm_up() + # validate and select the streaming callback streaming_callback = select_streaming_callback( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True diff --git a/haystack/components/generators/chat/openai_responses.py b/haystack/components/generators/chat/openai_responses.py index 33b6cb9c6c..91b11bb04a 100644 --- a/haystack/components/generators/chat/openai_responses.py +++ b/haystack/components/generators/chat/openai_responses.py @@ -297,7 +297,7 @@ def run( generation_kwargs: Optional[dict[str, Any]] = None, tools: Optional[Union[ToolsType, list[dict]]] = None, tools_strict: Optional[bool] = None, - ): + ) -> dict[str, list[ChatMessage]]: """ Invokes response generation based on the provided messages and generation parameters. @@ -326,6 +326,9 @@ def run( A dictionary with the following key: - `replies`: A list containing the generated responses as ChatMessage instances. """ + if not self._is_warmed_up: + self.warm_up() + if len(messages) == 0: return {"replies": []} @@ -364,7 +367,7 @@ async def run_async( generation_kwargs: Optional[dict[str, Any]] = None, tools: Optional[Union[ToolsType, list[dict]]] = None, tools_strict: Optional[bool] = None, - ): + ) -> dict[str, list[ChatMessage]]: """ Asynchronously invokes response generation based on the provided messages and generation parameters. @@ -395,6 +398,9 @@ async def run_async( A dictionary with the following key: - `replies`: A list containing the generated responses as ChatMessage instances. """ + if not self._is_warmed_up: + self.warm_up() + # validate and select the streaming callback streaming_callback = select_streaming_callback( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True diff --git a/haystack/components/generators/hugging_face_local.py b/haystack/components/generators/hugging_face_local.py index 2a4bb73364..ad747f0bac 100644 --- a/haystack/components/generators/hugging_face_local.py +++ b/haystack/components/generators/hugging_face_local.py @@ -229,9 +229,7 @@ def run( - replies: A list of strings representing the generated replies. """ if not self._warmed_up: - raise RuntimeError( - "The component HuggingFaceLocalGenerator was not warmed up. Please call warm_up() before running." - ) + self.warm_up() # at this point, we know that the pipeline has been initialized assert self.pipeline is not None diff --git a/haystack/components/generators/openai.py b/haystack/components/generators/openai.py index 4914e19490..48bb0b680d 100644 --- a/haystack/components/generators/openai.py +++ b/haystack/components/generators/openai.py @@ -192,7 +192,7 @@ def run( system_prompt: Optional[str] = None, streaming_callback: Optional[StreamingCallbackT] = None, generation_kwargs: Optional[dict[str, Any]] = None, - ): + ) -> dict[str, Union[list[str], list[dict[str, Any]]]]: """ Invoke the text generation inference based on the provided messages and generation parameters. 
@@ -265,4 +265,7 @@ def run( for response in completions: _check_finish_reason(response.meta) - return {"replies": [message.text for message in completions], "meta": [message.meta for message in completions]} + return { + "replies": [message.text or "" for message in completions], + "meta": [message.meta for message in completions], + } diff --git a/haystack/components/generators/openai_dalle.py b/haystack/components/generators/openai_dalle.py index 7a907e2ecc..e5a640312b 100644 --- a/haystack/components/generators/openai_dalle.py +++ b/haystack/components/generators/openai_dalle.py @@ -117,14 +117,12 @@ def run( to the prompt made by OpenAI. """ if self.client is None: - raise RuntimeError( - "The component DALLEImageGenerator wasn't warmed up. Run 'warm_up()' before calling 'run()'." - ) + self.warm_up() size = size or self.size quality = quality or self.quality response_format = response_format or self.response_format - response = self.client.images.generate( + response = self.client.images.generate( # type: ignore[union-attr] model=self.model, prompt=prompt, size=size, quality=quality, response_format=response_format, n=1 ) if response.data is not None: diff --git a/haystack/components/preprocessors/document_splitter.py b/haystack/components/preprocessors/document_splitter.py index 1316b9c996..6162e364c9 100644 --- a/haystack/components/preprocessors/document_splitter.py +++ b/haystack/components/preprocessors/document_splitter.py @@ -187,9 +187,7 @@ def run(self, documents: list[Document]): :raises ValueError: if the content of a document is None. """ if self._use_sentence_splitter and self.sentence_splitter is None: - raise RuntimeError( - "The component DocumentSplitter wasn't warmed up. Run 'warm_up()' before calling 'run()'." - ) + self.warm_up() if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): raise TypeError("DocumentSplitter expects a List of Documents as input.") diff --git a/haystack/components/preprocessors/recursive_splitter.py b/haystack/components/preprocessors/recursive_splitter.py index c4f56e96c5..ab23d3e329 100644 --- a/haystack/components/preprocessors/recursive_splitter.py +++ b/haystack/components/preprocessors/recursive_splitter.py @@ -461,14 +461,10 @@ def run(self, documents: list[Document]) -> dict[str, list[Document]]: :returns: A dictionary containing a key "documents" with a List of Documents with smaller chunks of text corresponding to the input documents. - - :raises RuntimeError: If the component wasn't warmed up but requires it for sentence splitting or tokenization. """ if not self._is_warmed_up and ("sentence" in self.separators or self.split_units == "token"): - raise RuntimeError( - "The component RecursiveDocumentSplitter wasn't warmed up but requires it " - "for sentence splitting or tokenization. Call 'warm_up()' before calling 'run()'." - ) + self.warm_up() + docs = [] for doc in documents: if not doc.content or doc.content == "": diff --git a/haystack/components/rankers/sentence_transformers_diversity.py b/haystack/components/rankers/sentence_transformers_diversity.py index ee77469f00..64385d4154 100644 --- a/haystack/components/rankers/sentence_transformers_diversity.py +++ b/haystack/components/rankers/sentence_transformers_diversity.py @@ -406,14 +406,9 @@ def run( - `documents`: List of Document objects that have been selected based on the diversity ranking. :raises ValueError: If the top_k value is less than or equal to 0. - :raises RuntimeError: If the component has not been warmed up. 
""" if self.model is None: - error_msg = ( - "The component SentenceTransformersDiversityRanker wasn't warmed up. " - "Run 'warm_up()' before calling 'run()'." - ) - raise RuntimeError(error_msg) + self.warm_up() if not documents: return {"documents": []} diff --git a/haystack/components/rankers/sentence_transformers_similarity.py b/haystack/components/rankers/sentence_transformers_similarity.py index c0657f469c..e503f46f47 100644 --- a/haystack/components/rankers/sentence_transformers_similarity.py +++ b/haystack/components/rankers/sentence_transformers_similarity.py @@ -233,14 +233,9 @@ def run( :raises ValueError: If `top_k` is not > 0. - :raises RuntimeError: - If the model is not loaded because `warm_up()` was not called before. """ if self._cross_encoder is None: - raise RuntimeError( - "The component SentenceTransformersSimilarityRanker wasn't warmed up. " - "Run 'warm_up()' before calling 'run()'." - ) + self.warm_up() if not documents: return {"documents": []} @@ -264,7 +259,8 @@ def run( activation_fn = Sigmoid() if scale_score else Identity() - ranking_result = self._cross_encoder.rank( + # mypy doesn't know this is set in warm_up + ranking_result = self._cross_encoder.rank( # type: ignore[attr-defined] query=prepared_query, documents=prepared_documents, batch_size=self.batch_size, diff --git a/haystack/components/rankers/transformers_similarity.py b/haystack/components/rankers/transformers_similarity.py index 77958b4115..89a8a26d14 100644 --- a/haystack/components/rankers/transformers_similarity.py +++ b/haystack/components/rankers/transformers_similarity.py @@ -123,7 +123,7 @@ def __init__( # noqa: PLR0913, pylint: disable=too-many-positional-arguments self.query_prefix = query_prefix self.document_prefix = document_prefix self.tokenizer = None - self.device = None + self.device: Optional[ComponentDevice] = None self.top_k = top_k self.token = token self.meta_fields_to_embed = meta_fields_to_embed or [] @@ -165,8 +165,10 @@ def warm_up(self): token=self.token.resolve_value() if self.token else None, **self.tokenizer_kwargs, ) - assert self.model is not None - self.device = ComponentDevice.from_multiple(device_map=DeviceMap.from_hf(self.model.hf_device_map)) + # mypy doesn't know this is set right above + self.device = ComponentDevice.from_multiple( + device_map=DeviceMap.from_hf(self.model.hf_device_map) # type: ignore[attr-defined] + ) def to_dict(self) -> dict[str, Any]: """ @@ -249,14 +251,10 @@ def run( # pylint: disable=too-many-positional-arguments :raises ValueError: If `top_k` is not > 0. If `scale_score` is True and `calibration_factor` is not provided. - :raises RuntimeError: - If the model is not loaded because `warm_up()` was not called before. """ # If a model path is provided but the model isn't loaded if self.model is None: - raise RuntimeError( - "The component TransformersSimilarityRanker wasn't warmed up. Run 'warm_up()' before calling 'run()'." 
- ) + self.warm_up() if not documents: return {"documents": []} @@ -292,30 +290,33 @@ def __len__(self): def __getitem__(self, item): return {key: self.batch_encoding.data[key][item] for key in self.batch_encoding.data.keys()} - batch_enc = self.tokenizer(query_doc_pairs, padding=True, truncation=True, return_tensors="pt").to( - self.device.first_device.to_torch() + # mypy doesn't know this is set in warm_up + batch_enc = self.tokenizer( # type: ignore[misc] + query_doc_pairs, padding=True, truncation=True, return_tensors="pt" + ).to( + self.device.first_device.to_torch() # type: ignore[union-attr] ) dataset = _Dataset(batch_enc) inp_dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False) - similarity_scores = [] + list_similarity_scores = [] with torch.inference_mode(): for features in inp_dataloader: - model_preds = self.model(**features).logits.squeeze(dim=1) - similarity_scores.extend(model_preds) - similarity_scores = torch.stack(similarity_scores) + model_preds = self.model(**features).logits.squeeze(dim=1) # type: ignore[misc] + list_similarity_scores.extend(model_preds) + similarity_scores = torch.stack(list_similarity_scores) - if scale_score: + if scale_score and calibration_factor is not None: similarity_scores = torch.sigmoid(similarity_scores * calibration_factor) _, sorted_indices = torch.sort(similarity_scores, descending=True) - sorted_indices = sorted_indices.cpu().tolist() - similarity_scores = similarity_scores.cpu().tolist() + list_sorted_indices = sorted_indices.cpu().tolist() + list_similarity_scores = similarity_scores.cpu().tolist() ranked_docs = [] - for sorted_index in sorted_indices: + for sorted_index in list_sorted_indices: i = sorted_index - documents[i].score = similarity_scores[i] + documents[i].score = list_similarity_scores[i] ranked_docs.append(documents[i]) if score_threshold is not None: diff --git a/haystack/components/readers/extractive.py b/haystack/components/readers/extractive.py index 74f7d5594f..42e518bfec 100644 --- a/haystack/components/readers/extractive.py +++ b/haystack/components/readers/extractive.py @@ -225,7 +225,8 @@ def _preprocess( # pylint: disable=too-many-positional-arguments document_ids.append(i) document_contents.append(doc.content) - encodings_pt = self.tokenizer( # type: ignore + # mypy doesn't know this is set in warm_up + encodings_pt = self.tokenizer( # type: ignore[misc] queries, document_contents, padding=True, @@ -236,12 +237,9 @@ def _preprocess( # pylint: disable=too-many-positional-arguments stride=stride, ) - # To make mypy happy even though self.device is set in warm_up() - assert self.device is not None - assert self.device.first_device is not None - # Take the first device used by `accelerate`. Needed to pass inputs from the tokenizer to the correct device. - first_device = self.device.first_device.to_torch() + # mypy doesn't know this is set in warm_up + first_device = self.device.first_device.to_torch() # type: ignore[union-attr] input_ids = encodings_pt.input_ids.to(first_device) attention_mask = encodings_pt.attention_mask.to(first_device) @@ -575,14 +573,9 @@ def run( # pylint: disable=too-many-positional-arguments If None is provided then all answers are kept. :returns: List of answers sorted by (desc.) answer score. - - :raises RuntimeError: - If the component was not warmed up by calling 'warm_up()' before. """ if self.model is None: - raise RuntimeError( - "The component ExtractiveReader was not warmed up. Run 'warm_up()' before calling 'run()'." 
- ) + self.warm_up() if not documents: return {"answers": []} @@ -622,7 +615,8 @@ def run( # pylint: disable=too-many-positional-arguments cur_attention_mask = attention_mask[start_index:end_index] with torch.inference_mode(): - output = self.model(input_ids=cur_input_ids, attention_mask=cur_attention_mask) + # mypy doesn't know this is set in warm_up + output = self.model(input_ids=cur_input_ids, attention_mask=cur_attention_mask) # type: ignore[misc] cur_start_logits = output.start_logits cur_end_logits = output.end_logits if num_batches != 1: diff --git a/haystack/components/routers/transformers_text_router.py b/haystack/components/routers/transformers_text_router.py index 5bcd6bc067..af651c23f5 100644 --- a/haystack/components/routers/transformers_text_router.py +++ b/haystack/components/routers/transformers_text_router.py @@ -188,17 +188,16 @@ def run(self, text: str) -> dict[str, str]: :raises TypeError: If the input is not a str. - :raises RuntimeError: - If the pipeline has not been loaded because warm_up() was not called before. """ if self.pipeline is None: - raise RuntimeError( - "The component TextTransformersRouter wasn't warmed up. Run 'warm_up()' before calling 'run()'." - ) + self.warm_up() if not isinstance(text, str): raise TypeError("TransformersTextRouter expects a str as input.") - prediction = self.pipeline([text], return_all_scores=False, function_to_apply="none") + # mypy doesn't know this is set in warm_up + prediction = self.pipeline( # type: ignore[misc] + [text], return_all_scores=False, function_to_apply="none" + ) label = prediction[0]["label"] return {label: text} diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index 0baaa913d6..3f1cc6dca0 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -199,18 +199,17 @@ def run(self, text: str) -> dict[str, str]: :raises TypeError: If the input is not a str. - :raises RuntimeError: - If the pipeline has not been loaded because warm_up() was not called before. """ if self.pipeline is None: - raise RuntimeError( - "The component TransformersZeroShotTextRouter wasn't warmed up. Run 'warm_up()' before calling 'run()'." - ) + self.warm_up() if not isinstance(text, str): raise TypeError("TransformersZeroShotTextRouter expects a str as input.") - prediction = self.pipeline([text], candidate_labels=self.labels, multi_label=self.multi_label) + # mypy doesn't know this is set in warm_up + prediction = self.pipeline( # type: ignore[misc] + [text], candidate_labels=self.labels, multi_label=self.multi_label + ) predicted_scores = prediction[0]["scores"] max_score_index = max(range(len(predicted_scores)), key=predicted_scores.__getitem__) label = prediction[0]["labels"][max_score_index] diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index 5c49d1371a..98213f6bc6 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -540,6 +540,9 @@ def run( :raises ToolOutputMergeError: If merging tool outputs into state fails and `raise_on_failure` is True. """ + if not self._is_warmed_up: + self.warm_up() + tools_with_names = self._tools_with_names if tools is not None: tools_with_names = self._validate_and_prepare_tools(tools) @@ -672,6 +675,8 @@ async def run_async( :raises ToolOutputMergeError: If merging tool outputs into state fails and `raise_on_failure` is True. 
""" + if not self._is_warmed_up: + self.warm_up() tools_with_names = self._tools_with_names if tools is not None: diff --git a/releasenotes/notes/refactor-warm-up-components-c2777fef28a70b61.yaml b/releasenotes/notes/refactor-warm-up-components-c2777fef28a70b61.yaml new file mode 100644 index 0000000000..cae8b434f9 --- /dev/null +++ b/releasenotes/notes/refactor-warm-up-components-c2777fef28a70b61.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Components that define ``warm_up`` now auto run it on first use, removing the need for manual calls and preventing errors in standalone usage. diff --git a/test/components/embedders/image/test_sentence_transformers_doc_image_embedder.py b/test/components/embedders/image/test_sentence_transformers_doc_image_embedder.py index 2e6b1525b3..4d8c6bb274 100644 --- a/test/components/embedders/image/test_sentence_transformers_doc_image_embedder.py +++ b/test/components/embedders/image/test_sentence_transformers_doc_image_embedder.py @@ -220,12 +220,6 @@ def test_run(self, test_files_path): assert new_doc.meta["embedding_source"]["type"] == "image" assert "file_path_meta_field" in new_doc.meta["embedding_source"] - def test_run_no_warmup(self): - embedder = SentenceTransformersDocumentImageEmbedder(model="model") - - with pytest.raises(RuntimeError, match="The embedding model has not been loaded."): - embedder.run(documents=[Document(content="test")]) - def test_run_wrong_input_format(self): embedder = SentenceTransformersDocumentImageEmbedder(model="model") @@ -329,7 +323,6 @@ def test_live_run(self, test_files_path, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) # https://github.com/deepset-ai/haystack/issues/8811 embedder = SentenceTransformersDocumentImageEmbedder(model="sentence-transformers/clip-ViT-B-32") - embedder.warm_up() documents = [ Document( diff --git a/test/components/embedders/test_sentence_transformers_sparse_document_embedder.py b/test/components/embedders/test_sentence_transformers_sparse_document_embedder.py index 1574ffc134..4e05823f54 100644 --- a/test/components/embedders/test_sentence_transformers_sparse_document_embedder.py +++ b/test/components/embedders/test_sentence_transformers_sparse_document_embedder.py @@ -286,11 +286,6 @@ def test_run_wrong_input_format(self): ): embedder.run(documents=list_integers_input) - def test_run_no_warmup(self): - embedder = SentenceTransformersSparseDocumentEmbedder(model="model") - with pytest.raises(RuntimeError, match="The embedding model has not been loaded."): - embedder.run(documents=[Document(content="test")]) - def test_embed_metadata(self): embedder = SentenceTransformersSparseDocumentEmbedder( model="model", meta_fields_to_embed=["meta_field"], embedding_separator="\n" @@ -435,7 +430,6 @@ def test_live_run_sparse_document_embedder(self, monkeypatch): embedding_separator=" | ", device=ComponentDevice.from_str("cpu"), ) - embedder.warm_up() result = embedder.run(documents=docs) documents_with_embeddings = result["documents"] diff --git a/test/components/embedders/test_sentence_transformers_sparse_text_embedder.py b/test/components/embedders/test_sentence_transformers_sparse_text_embedder.py index 4e018d01e5..c8a36f75c1 100644 --- a/test/components/embedders/test_sentence_transformers_sparse_text_embedder.py +++ b/test/components/embedders/test_sentence_transformers_sparse_text_embedder.py @@ -210,11 +210,6 @@ def test_warmup_doesnt_reload(self, mocked_factory): embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once() - def test_run_no_warmup(self): - 
embedder = SentenceTransformersSparseTextEmbedder(model="model") - with pytest.raises(RuntimeError, match="The embedding model has not been loaded"): - embedder.run(text="a nice text to embed") - def test_run(self): embedder = SentenceTransformersSparseTextEmbedder(model="model") embedder.embedding_backend = MagicMock() @@ -337,7 +332,6 @@ def test_live_run_sparse_text_embedder(self, monkeypatch): embedder = SentenceTransformersSparseTextEmbedder( model="sparse-encoder-testing/splade-bert-tiny-nq", device=ComponentDevice.from_str("cpu") ) - embedder.warm_up() result = embedder.run(text=text) sparse_embedding = result["sparse_embedding"] diff --git a/test/components/evaluators/test_sas_evaluator.py b/test/components/evaluators/test_sas_evaluator.py index 6a345ec33e..bbb8210991 100644 --- a/test/components/evaluators/test_sas_evaluator.py +++ b/test/components/evaluators/test_sas_evaluator.py @@ -89,21 +89,6 @@ def test_run_with_none_in_predictions(self): with pytest.raises(ValueError): evaluator.run(ground_truth_answers=ground_truths, predicted_answers=predictions) - def test_run_not_warmed_up(self): - evaluator = SASEvaluator() - ground_truths = [ - "A construction budget of US $2.3 billion", - "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", - "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", - ] - predictions = [ - "A construction budget of US $2.3 billion", - "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", - "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", - ] - with pytest.raises(RuntimeError): - evaluator.run(ground_truth_answers=ground_truths, predicted_answers=predictions) - @pytest.mark.integration @pytest.mark.slow def test_run_with_matching_predictions(self, monkeypatch): @@ -119,7 +104,6 @@ def test_run_with_matching_predictions(self, monkeypatch): "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", ] - evaluator.warm_up() result = evaluator.run(ground_truth_answers=ground_truths, predicted_answers=predictions) assert len(result) == 2 @@ -135,7 +119,6 @@ def test_run_with_single_prediction(self, monkeypatch): ) ground_truths = ["US $2.3 billion"] - evaluator.warm_up() result = evaluator.run( ground_truth_answers=ground_truths, predicted_answers=["A construction budget of US $2.3 billion"] ) @@ -158,7 +141,6 @@ def test_run_with_mismatched_predictions(self, monkeypatch): "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", ] - evaluator.warm_up() result = evaluator.run(ground_truth_answers=ground_truths, predicted_answers=predictions) assert len(result) == 2 assert result["score"] == pytest.approx(0.912335) @@ -179,7 +161,6 @@ def test_run_with_cross_encoder_model(self, monkeypatch): "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", ] - evaluator.warm_up() result = evaluator.run(ground_truth_answers=ground_truths, predicted_answers=predictions) assert len(result) == 2 assert result["score"] == pytest.approx(0.938108, abs=1e-5) diff --git a/test/components/extractors/test_named_entity_extractor.py b/test/components/extractors/test_named_entity_extractor.py index 06322f1eaf..ad8b444137 100644 --- 
a/test/components/extractors/test_named_entity_extractor.py +++ b/test/components/extractors/test_named_entity_extractor.py @@ -166,13 +166,3 @@ def test_named_entity_extractor_run(): assert "named_entities" in result["documents"][0].meta assert result["documents"][0].meta["named_entities"] == expected_annotations[0] assert "named_entities" not in documents[0].meta - - -def test_named_entity_extractor_run_not_warmed_up(): - """Test that run method raises error when not warmed up.""" - extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER") - - documents = [Document(content="Test document")] - - with pytest.raises(RuntimeError, match="The component NamedEntityExtractor was not warmed up"): - extractor.run(documents=documents) diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 78f24073a4..650af4097a 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -493,7 +493,6 @@ def test_live_run(self, monkeypatch): llm = HuggingFaceLocalChatGenerator( model="Qwen/Qwen2.5-0.5B-Instruct", generation_kwargs={"max_new_tokens": 50} ) - llm.warm_up() result = llm.run(messages) @@ -679,10 +678,6 @@ async def test_async_error_handling(self, model_info_mock, mock_pipeline_with_to """Test error handling in async context""" generator = HuggingFaceLocalChatGenerator(model="mocked-model") - # Test without warm_up - with pytest.raises(RuntimeError, match="The generation model has not been loaded"): - await generator.run_async(messages=[ChatMessage.from_user("test")]) - # Test with invalid streaming callback generator.pipeline = mock_pipeline_with_tokenizer with pytest.raises(ValueError, match="Using tools and streaming at the same time is not supported"): @@ -809,7 +804,6 @@ async def streaming_callback(chunk: StreamingChunk) -> None: generation_kwargs={"max_new_tokens": 50}, streaming_callback=streaming_callback, ) - llm.warm_up() response = await llm.run_async( messages=[ChatMessage.from_user("Please create a summary about the following topic: Capital of France")] diff --git a/test/components/generators/test_hugging_face_local_generator.py b/test/components/generators/test_hugging_face_local_generator.py index c6985ab22d..1c3821264e 100644 --- a/test/components/generators/test_hugging_face_local_generator.py +++ b/test/components/generators/test_hugging_face_local_generator.py @@ -383,14 +383,6 @@ def streaming_callback_handler(x): # the tokenizer should be set, here it is a mock assert streamer.tokenizer - def test_run_fails_without_warm_up(self): - generator = HuggingFaceLocalGenerator( - model="google/flan-t5-base", task="text2text-generation", generation_kwargs={"max_new_tokens": 100} - ) - - with pytest.raises(RuntimeError, match="The component HuggingFaceLocalGenerator was not warmed up"): - generator.run(prompt="irrelevant") - def test_stop_words_criteria_with_a_mocked_tokenizer(self): """ Test that StopWordsCriteria will caught stop word tokens in a continuous and sequential order in the input_ids @@ -471,7 +463,6 @@ def test_hf_pipeline_runs_with_our_criteria(self, monkeypatch): def test_live_run(self, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) # https://github.com/deepset-ai/haystack/issues/8811 llm = HuggingFaceLocalGenerator(model="Qwen/Qwen2.5-0.5B-Instruct", generation_kwargs={"max_new_tokens": 50}) - llm.warm_up() # You must use the 
`apply_chat_template` method to add the generation prompt to properly include the instruction # tokens in the prompt. Otherwise, the model will not generate the expected output. diff --git a/test/components/preprocessors/test_recursive_splitter.py b/test/components/preprocessors/test_recursive_splitter.py index 2858c69562..d91ac883bd 100644 --- a/test/components/preprocessors/test_recursive_splitter.py +++ b/test/components/preprocessors/test_recursive_splitter.py @@ -120,7 +120,6 @@ def test_run_using_custom_sentence_tokenizer(): separators=["\n\n", "\n", "sentence", " "], sentence_splitter_params={"language": "en", "use_split_rules": True, "keep_white_spaces": False}, ) - splitter.warm_up() text = """Artificial intelligence (AI) - Introduction AI, in its broadest sense, is intelligence exhibited by machines, particularly computer systems. @@ -317,7 +316,6 @@ def test_run_split_by_sentence_count_page_breaks_split_unit_char() -> None: document_splitter = RecursiveDocumentSplitter( separators=["sentence"], split_length=28, split_overlap=0, split_unit="char" ) - document_splitter.warm_up() text = ( "Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f" @@ -462,7 +460,6 @@ def test_run_custom_sentence_tokenizer_document_and_overlap_char_unit(): splitter = RecursiveDocumentSplitter(split_length=25, split_overlap=10, separators=["sentence"], split_unit="char") text = "This is sentence one. This is sentence two. This is sentence three." - splitter.warm_up() doc = Document(content=text) doc_chunks = splitter.run([doc])["documents"] @@ -668,7 +665,6 @@ def test_run_split_by_sentence_count_page_breaks_word_unit() -> None: document_splitter = RecursiveDocumentSplitter( separators=["sentence"], split_length=7, split_overlap=0, split_unit="word" ) - document_splitter.warm_up() text = ( "Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f" @@ -847,7 +843,6 @@ def test_run_serialization_in_pipeline(): @pytest.mark.integration def test_run_split_by_token_count(): splitter = RecursiveDocumentSplitter(split_length=5, separators=["."], split_unit="token") - splitter.warm_up() text = "This is a test. This is another test. This is the final test." doc = Document(content=text) @@ -863,7 +858,6 @@ def test_run_split_by_token_count(): @pytest.mark.integration def test_run_split_by_token_count_with_html_tags(): splitter = RecursiveDocumentSplitter(split_length=4, separators=["."], split_unit="token") - splitter.warm_up() text = "This is a test. This is the final test." doc = Document(content=text) @@ -875,7 +869,6 @@ def test_run_split_by_token_count_with_html_tags(): @pytest.mark.integration def test_run_split_by_token_with_sentence_tokenizer(): splitter = RecursiveDocumentSplitter(split_length=4, separators=["sentence"], split_unit="token") - splitter.warm_up() text = "This is sentence one. This is sentence two." 
doc = Document(content=text) @@ -891,7 +884,6 @@ def test_run_split_by_token_with_sentence_tokenizer(): @pytest.mark.integration def test_run_split_by_token_with_empty_document(caplog: LogCaptureFixture): splitter = RecursiveDocumentSplitter(split_length=4, separators=["."], split_unit="token") - splitter.warm_up() empty_doc = Document(content="") doc_chunks = splitter.run([empty_doc]) @@ -904,7 +896,6 @@ def test_run_split_by_token_with_empty_document(caplog: LogCaptureFixture): @pytest.mark.integration def test_run_split_by_token_with_fallback(): splitter = RecursiveDocumentSplitter(split_length=2, separators=["."], split_unit="token") - splitter.warm_up() text = "This is a very long sentence that will need to be split into smaller chunks." doc = Document(content=text) @@ -918,7 +909,6 @@ def test_run_split_by_token_with_fallback(): @pytest.mark.integration def test_run_split_by_token_with_overlap_and_fallback(): splitter = RecursiveDocumentSplitter(split_length=4, split_overlap=2, separators=["."], split_unit="token") - splitter.warm_up() text = "This is a test. This is another test. This is the final test." doc = Document(content=text) @@ -935,24 +925,6 @@ def test_run_split_by_token_with_overlap_and_fallback(): assert chunks[7].content == " final test." -def test_run_without_warm_up_raises_error(): - docs = [Document(content="text")] - - splitter_token = RecursiveDocumentSplitter(split_unit="token") - with pytest.raises(RuntimeError, match="warmed up"): - splitter_token.run(docs) - - splitter_sentence = RecursiveDocumentSplitter(separators=["sentence"]) - with pytest.raises(RuntimeError, match="warmed up"): - splitter_sentence.run(docs) - - # Test that no error is raised when warm_up is not required - splitter_no_warmup = RecursiveDocumentSplitter(separators=["."]) - result = splitter_no_warmup.run(docs) - assert len(result["documents"]) == 1 - assert result["documents"][0].content == "text" - - def test_run_complex_text_with_multiple_separators(): """ Test that RecursiveDocumentSplitter correctly handles complex text with multiple separators and chunks that exceed @@ -975,7 +947,6 @@ def test_run_complex_text_with_multiple_separators(): splitter = RecursiveDocumentSplitter( split_length=200, split_overlap=0, split_unit="char", separators=["\n\n", "\n", " "] ) - splitter.warm_up() result = splitter.run([doc]) chunks = result["documents"] @@ -1000,7 +971,6 @@ def test_recursive_splitter_generates_unique_ids_and_correct_meta(): source_doc = Document(content=text) splitter = RecursiveDocumentSplitter(split_length=3) - splitter.warm_up() chunks = splitter.run([source_doc])["documents"] diff --git a/test/components/rankers/test_sentence_transformers_diversity.py b/test/components/rankers/test_sentence_transformers_diversity.py index 69ca3d536e..6bef271984 100644 --- a/test/components/rankers/test_sentence_transformers_diversity.py +++ b/test/components/rankers/test_sentence_transformers_diversity.py @@ -254,20 +254,6 @@ def test_run_invalid_strategy(self): model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine", strategy=strategy ) - @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) - def test_run_without_warm_up(self, similarity): - """ - Tests that run method raises ComponentError if model is not warmed up - """ - ranker = SentenceTransformersDiversityRanker( - model="sentence-transformers/all-MiniLM-L6-v2", top_k=1, similarity=similarity - ) - documents = [Document(content="doc1"), Document(content="doc2")] - - error_msg = "The component 
SentenceTransformersDiversityRanker wasn't warmed up." - with pytest.raises(RuntimeError, match=error_msg): - ranker.run(query="test query", documents=documents) - @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_warm_up(self, similarity, monkeypatch): """ @@ -585,7 +571,6 @@ def test_run_real_world_use_case(self, similarity, monkeypatch): ranker = SentenceTransformersDiversityRanker( model="sentence-transformers-testing/stsb-bert-tiny-safetensors", similarity=similarity ) - ranker.warm_up() query = "What are the reasons for long-standing animosities between Russia and Poland?" doc1 = Document( @@ -674,7 +659,6 @@ def test_run_with_maximum_margin_relevance_strategy(self, similarity, monkeypatc similarity=similarity, strategy="maximum_margin_relevance", ) - ranker.warm_up() # lambda_threshold=1, the most relevant document should be returned first results = ranker.run(query=query, documents=docs, lambda_threshold=1, top_k=len(docs)) diff --git a/test/components/rankers/test_sentence_transformers_similarity.py b/test/components/rankers/test_sentence_transformers_similarity.py index f0622c630a..3a4bed1625 100644 --- a/test/components/rankers/test_sentence_transformers_similarity.py +++ b/test/components/rankers/test_sentence_transformers_similarity.py @@ -270,11 +270,6 @@ def test_returns_empty_list_if_no_documents_are_provided(self): output = ranker.run(query="City in Germany", documents=[]) assert not output["documents"] - def test_raises_component_error_if_model_not_warmed_up(self): - ranker = SentenceTransformersSimilarityRanker() - with pytest.raises(RuntimeError): - ranker.run(query="query", documents=[Document(content="document")]) - def test_embed_meta(self): ranker = SentenceTransformersSimilarityRanker( model="model", meta_fields_to_embed=["meta_field"], embedding_separator="\n" @@ -373,7 +368,6 @@ def test_scores_cast_to_python_float_when_numpy_scalars_returned(self): @pytest.mark.slow def test_run(self): ranker = SentenceTransformersSimilarityRanker(model="cross-encoder-testing/reranker-bert-tiny-gooaq-bce") - ranker.warm_up() query = "City in Bosnia and Herzegovina" docs_before_texts = ["Berlin", "Belgrade", "Sarajevo"] @@ -401,7 +395,6 @@ def test_run_top_k(self): ranker = SentenceTransformersSimilarityRanker( model="cross-encoder-testing/reranker-bert-tiny-gooaq-bce", top_k=2 ) - ranker.warm_up() query = "City in Bosnia and Herzegovina" docs_before_texts = ["Berlin", "Belgrade", "Sarajevo"] @@ -426,7 +419,6 @@ def test_run_single_document(self): ranker = SentenceTransformersSimilarityRanker( model="cross-encoder-testing/reranker-bert-tiny-gooaq-bce", device=None ) - ranker.warm_up() docs_before = [Document(content="Berlin")] output = ranker.run(query="City in Germany", documents=docs_before) docs_after = output["documents"] diff --git a/test/components/rankers/test_transformers_similarity.py b/test/components/rankers/test_transformers_similarity.py index d85f7c75e7..a4e0c8a4fa 100644 --- a/test/components/rankers/test_transformers_similarity.py +++ b/test/components/rankers/test_transformers_similarity.py @@ -338,17 +338,14 @@ def __init__(self): def test_returns_empty_list_if_no_documents_are_provided(self): sampler = TransformersSimilarityRanker() + # Mock all attributes that are set during warm_up sampler.model = MagicMock() + sampler.tokenizer = MagicMock() + sampler.device = MagicMock() output = sampler.run(query="City in Germany", documents=[]) assert not output["documents"] - # Raises ComponentError if model is not warmed up - def 
test_raises_component_error_if_model_not_warmed_up(self): - sampler = TransformersSimilarityRanker() - with pytest.raises(RuntimeError): - sampler.run(query="query", documents=[Document(content="document")]) - @pytest.mark.integration @pytest.mark.slow def test_run(self):
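A minimal usage sketch of the release-note behavior ("components that define ``warm_up`` now automatically run it on first use"). This snippet is illustrative only and not part of the patch; it assumes the `sentence-transformers` extra is installed and that the `sentence-transformers/all-MiniLM-L6-v2` checkpoint (used elsewhere in this diff's tests) can be downloaded:

```python
from haystack import Document
from haystack.components.embedders import SentenceTransformersDocumentEmbedder

# Before this change, calling run() without warm_up() raised
# RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")
embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")

# run() now lazily triggers warm_up() on first use, so standalone usage works directly.
result = embedder.run(documents=[Document(content="Berlin is the capital of Germany.")])
print(len(result["documents"][0].embedding))  # embedding dimension, e.g. 384 for this model

# An explicit warm_up() remains supported and is idempotent (see test_warmup_doesnt_reload),
# so existing pipelines that call it eagerly are unaffected.
embedder.warm_up()
```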