|
| 1 | +# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +from concurrent.futures import ThreadPoolExecutor |
| 6 | +from typing import Any, List, Optional |
| 7 | + |
| 8 | +from haystack import Document, component, default_from_dict, default_to_dict |
| 9 | +from haystack.components.embedders.types.protocol import TextEmbedder |
| 10 | +from haystack.core.serialization import component_to_dict |
| 11 | +from haystack.utils.deserialization import deserialize_component_inplace |
| 12 | + |
| 13 | +from haystack_experimental.components.retrievers.types import EmbeddingRetriever |
| 14 | + |
| 15 | + |
| 16 | +@component |
| 17 | +class MultiQueryEmbeddingRetriever: |
| 18 | + """ |
| 19 | + A component that retrieves documents using multiple queries in parallel with an embedding-based retriever. |
| 20 | +
|
| 21 | + This component takes a list of text queries, converts them to embeddings using a query embedder, |
| 22 | + and then uses an embedding-based retriever to find relevant documents for each query in parallel. |
| 23 | + The results are combined and sorted by relevance score. |
| 24 | +
|
| 25 | + ### Usage example |
| 26 | +
|
| 27 | + ```python |
| 28 | + from haystack import Document |
| 29 | + from haystack.document_stores.in_memory import InMemoryDocumentStore |
| 30 | + from haystack.document_stores.types import DuplicatePolicy |
| 31 | + from haystack.components.embedders import SentenceTransformersTextEmbedder |
| 32 | + from haystack.components.embedders import SentenceTransformersDocumentEmbedder |
| 33 | + from haystack.components.retrievers import InMemoryEmbeddingRetriever |
| 34 | + from haystack.components.writers import DocumentWriter |
| 35 | + from haystack_experimental.components.retrievers import MultiQueryEmbeddingRetriever |
| 36 | +
|
| 37 | + documents = [ |
| 38 | + Document(content="Renewable energy is energy that is collected from renewable resources."), |
| 39 | + Document(content="Solar energy is a type of green energy that is harnessed from the sun."), |
| 40 | + Document(content="Wind energy is another type of green energy that is generated by wind turbines."), |
| 41 | + Document(content="Geothermal energy is heat that comes from the sub-surface of the earth."), |
| 42 | + Document(content="Biomass energy is produced from organic materials, such as plant and animal waste."), |
| 43 | + Document(content="Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources."), |
| 44 | + ] |
| 45 | +
|
| 46 | + # Populate the document store |
| 47 | + doc_store = InMemoryDocumentStore() |
| 48 | + doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2") |
| 49 | + doc_embedder.warm_up() |
| 50 | + doc_writer = DocumentWriter(document_store=doc_store, policy=DuplicatePolicy.SKIP) |
| 51 | + documents = doc_embedder.run(documents)["documents"] |
| 52 | + doc_writer.run(documents=documents) |
| 53 | +
|
| 54 | + # Run the multi-query retriever |
| 55 | + in_memory_retriever = InMemoryEmbeddingRetriever(document_store=doc_store, top_k=1) |
| 56 | + query_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2") |
| 57 | +
|
| 58 | + multi_query_retriever = MultiQueryEmbeddingRetriever( |
| 59 | + retriever=in_memory_retriever, |
| 60 | + query_embedder=query_embedder, |
| 61 | + max_workers=3 |
| 62 | + ) |
| 63 | +
|
| 64 | + queries = ["Geothermal energy", "natural gas", "turbines"] |
| 65 | + result = multi_query_retriever.run(queries=queries) |
| 66 | + for doc in result["documents"]: |
| 67 | + print(f"Content: {doc.content}, Score: {doc.score}") |
| 68 | + >> Content: Geothermal energy is heat that comes from the sub-surface of the earth., Score: 0.8509603046266574 |
| 69 | + >> Content: Renewable energy is energy that is collected from renewable resources., Score: 0.42763211298893034 |
| 70 | + >> Content: Solar energy is a type of green energy that is harnessed from the sun., Score: 0.40077417016494354 |
| 71 | + >> Content: Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources., Score: 0.3774863680995796 |
| 72 | + >> Content: Wind energy is another type of green energy that is generated by wind turbines., Score: 0.3091423972562246 |
| 73 | + >> Content: Biomass energy is produced from organic materials, such as plant and animal waste., Score: 0.25173074243668087 |
| 74 | + ``` |
| 75 | + """ # noqa E501 |
| 76 | + |
| 77 | + def __init__( |
| 78 | + self, |
| 79 | + *, |
| 80 | + retriever: EmbeddingRetriever, |
| 81 | + query_embedder: TextEmbedder, |
| 82 | + max_workers: int = 3, |
| 83 | + ): |
| 84 | + """ |
| 85 | + Initialize MultiQueryEmbeddingRetriever. |
| 86 | +
|
| 87 | + :param retriever: The embedding-based retriever to use for document retrieval. |
| 88 | + :param query_embedder: The query embedder to convert text queries to embeddings. |
| 89 | + :param max_workers: Maximum number of worker threads for parallel processing. |
| 90 | + """ |
| 91 | + self.retriever = retriever |
| 92 | + self.query_embedder = query_embedder |
| 93 | + self.max_workers = max_workers |
| 94 | + self._is_warmed_up = False |
| 95 | + |
| 96 | + def warm_up(self) -> None: |
| 97 | + """ |
| 98 | + Warm up the query embedder and the retriever if any has a warm_up method. |
| 99 | + """ |
| 100 | + if not self._is_warmed_up: |
| 101 | + if hasattr(self.query_embedder, "warm_up") and callable(getattr(self.query_embedder, "warm_up")): |
| 102 | + self.query_embedder.warm_up() |
| 103 | + if hasattr(self.retriever, "warm_up") and callable(getattr(self.retriever, "warm_up")): |
| 104 | + self.retriever.warm_up() |
| 105 | + self._is_warmed_up = True |
| 106 | + |
| 107 | + @component.output_types(documents=List[Document]) |
| 108 | + def run( |
| 109 | + self, |
| 110 | + queries: List[str], |
| 111 | + retriever_kwargs: Optional[dict[str, Any]] = None, |
| 112 | + ) -> dict[str, Any]: |
| 113 | + """ |
| 114 | + Retrieve documents using multiple queries in parallel. |
| 115 | +
|
| 116 | + :param queries: List of text queries to process. |
| 117 | + :param retriever_kwargs: Optional dictionary of arguments to pass to the retriever's run method. |
| 118 | + :returns: |
| 119 | + A dictionary containing: |
| 120 | + - `documents`: List of retrieved documents sorted by relevance score. |
| 121 | + """ |
| 122 | + docs: list[Document] = [] |
| 123 | + seen_contents = set() |
| 124 | + retriever_kwargs = retriever_kwargs or {} |
| 125 | + |
| 126 | + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: |
| 127 | + queries_results = executor.map(lambda query: self._run_on_thread(query, retriever_kwargs), queries) |
| 128 | + for result in queries_results: |
| 129 | + if not result: |
| 130 | + continue |
| 131 | + for doc in result: |
| 132 | + # deduplicate based on content |
| 133 | + if doc.content not in seen_contents: |
| 134 | + docs.append(doc) |
| 135 | + seen_contents.add(doc.content) |
| 136 | + |
| 137 | + docs.sort(key=lambda x: x.score or 0.0, reverse=True) |
| 138 | + return {"documents": docs} |
| 139 | + |
| 140 | + def _run_on_thread(self, query: str, retriever_kwargs: Optional[dict[str, Any]] = None) -> Optional[List[Document]]: |
| 141 | + """ |
| 142 | + Process a single query on a separate thread. |
| 143 | +
|
| 144 | + :param query: The text query to process. |
| 145 | + :returns: |
| 146 | + List of retrieved documents or None if no results. |
| 147 | + """ |
| 148 | + embedding_result = self.query_embedder.run(text=query) |
| 149 | + query_embedding = embedding_result["embedding"] |
| 150 | + result = self.retriever.run(query_embedding=query_embedding, **(retriever_kwargs or {})) |
| 151 | + if result and "documents" in result: |
| 152 | + return result["documents"] |
| 153 | + return None |
| 154 | + |
| 155 | + def to_dict(self) -> dict[str, Any]: |
| 156 | + """ |
| 157 | + Serializes the component to a dictionary. |
| 158 | +
|
| 159 | + :returns: |
| 160 | + A dictionary representing the serialized component. |
| 161 | + """ |
| 162 | + return default_to_dict( |
| 163 | + self, |
| 164 | + retriever=component_to_dict(obj=self.retriever, name="retriever"), |
| 165 | + query_embedder=component_to_dict(obj=self.query_embedder, name="query_embedder"), |
| 166 | + max_workers=self.max_workers, |
| 167 | + ) |
| 168 | + |
| 169 | + @classmethod |
| 170 | + def from_dict(cls, data: dict[str, Any]) -> "MultiQueryEmbeddingRetriever": |
| 171 | + """ |
| 172 | + Deserializes the component from a dictionary. |
| 173 | +
|
| 174 | + :param data: The dictionary to deserialize from. |
| 175 | + :returns: |
| 176 | + The deserialized component. |
| 177 | + """ |
| 178 | + deserialize_component_inplace(data["init_parameters"], key="retriever") |
| 179 | + deserialize_component_inplace(data["init_parameters"], key="query_embedder") |
| 180 | + return default_from_dict(cls, data) |
0 commit comments