Commit 5502eff

davidsbatista and sjrl authored
feat: adding new components MultiQueryKeywordRetriever and MultiQueryEmbeddingRetriever (#358)
* initial import
* adding retrievers Protocol description
* WIP: writing keyword retrieval component
* adding both components and fixing linting and typing issues
* wip: writing/updating tests
* wip: writing/updating tests
* adding integration test
* adding multiquery embedding retrieval tests
* fixing integrations tests
* fixing integrations tests
* simplying integrations tests
* simplying integrations tests
* simplying tests
* updating tests
* updating tests
* Update haystack_experimental/components/retrievers/multi_query_embedding_retriever.py
  Co-authored-by: Sebastian Husch Lee <[email protected]>
* Update haystack_experimental/components/retrievers/multi_query_text_retriever.py
  Co-authored-by: Sebastian Husch Lee <[email protected]>
* Update haystack_experimental/components/retrievers/multi_query_text_retriever.py
  Co-authored-by: Sebastian Husch Lee <[email protected]>
* attending PR comments
* explicitly warming up in integration test without pipeline

---------

Co-authored-by: Sebastian Husch Lee <[email protected]>
1 parent 9ac6a92 commit 5502eff

9 files changed, +803 -1 lines changed

README.md

Lines changed: 6 additions & 0 deletions
@@ -48,6 +48,8 @@ that includes it. Once it reaches the end of its lifespan, the experiment will b
 | [`ChatMessageWriter`][3] | Memory Component | December 2024 | None | <a href="https://colab.research.google.com/github/deepset-ai/haystack-cookbook/blob/main/notebooks/conversational_rag_using_memory.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> | [Discuss][4] |
 | [`QueryExpander`][5] | Query Expansion Component | October 2025 | None | None | [Discuss][6] |
 | [`EmbeddingBasedDocumentSplitter`][8] | EmbeddingBasedDocumentSplitter | August 2025 | None | None | [Discuss][7] |
+| [`MultiQueryEmbeddingRetriever`][13] | MultiQueryEmbeddingRetriever | November 2025 | None | None | [Discuss][11] |
+| [`MultiQueryTextRetriever`][14] | MultiQueryTextRetriever | November 2025 | None | None | [Discuss][12] |
 | [`OpenAIChatGenerator`][9] | Chat Generator Component | November 2025 | None | None | [Discuss][10] |

 [1]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/chat_message_stores/in_memory.py
@@ -60,6 +62,10 @@ that includes it. Once it reaches the end of its lifespan, the experiment will b
 [8]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/preprocessors/embedding_based_document_splitter.py
 [9]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/generators/chat/openai.py
 [10]: https://github.com/deepset-ai/haystack-experimental/discussions/XXX
+[11]: https://github.com/deepset-ai/haystack-experimental/discussions/<>
+[12]: https://github.com/deepset-ai/haystack-experimental/discussions/<>
+[13]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/retrievers/multi_query_embedding_retriever.py
+[14]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/retrievers/multi_query_text_retriever.py

 ### Adopted experiments
 | Name | Type | Final release |

docs/pydoc/config/retrievers_api.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ loaders:
44
modules:
55
[
66
"haystack_experimental.components.retrievers.chat_message_retriever",
7+
"haystack_experimental.components.retrievers.multi_query_embedding_retriever",
8+
"haystack_experimental.components.retrievers.multi_query_text_retriever",
79
]
810
ignore_when_discovered: ["__init__"]
911
processors:

haystack_experimental/components/retrievers/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from haystack_experimental.components.retrievers.chat_message_retriever import ChatMessageRetriever
6+
from haystack_experimental.components.retrievers.multi_query_embedding_retriever import MultiQueryEmbeddingRetriever
7+
from haystack_experimental.components.retrievers.multi_query_text_retriever import MultiQueryTextRetriever
68

7-
_all_ = ["ChatMessageRetriever"]
9+
_all_ = ["ChatMessageRetriever", "MultiQueryTextRetriever", "MultiQueryEmbeddingRetriever"]
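
The export list in this `__init__.py` only takes effect when it is spelled with double underscores: Python consults `__all__` for `from package import *`, while a single-underscore `_all_` is just an ordinary private variable. The following is a minimal sketch of that difference using a throwaway module; the module name `demo_mod` and the helper `star_import_names` are hypothetical and exist only for this illustration.

```python
import sys
import types


def star_import_names(source: str) -> set[str]:
    """Build a throwaway module from `source` and return the names that
    `from demo_mod import *` would expose."""
    module = types.ModuleType("demo_mod")
    exec(source, module.__dict__)
    sys.modules["demo_mod"] = module
    try:
        namespace: dict = {}
        exec("from demo_mod import *", namespace)
    finally:
        del sys.modules["demo_mod"]
    # exec() adds __builtins__ to the namespace; it is not an imported name
    namespace.pop("__builtins__", None)
    return set(namespace)


body = "class A: ...\nclass B: ...\nhelper = 1\n"

# With the dunder spelling, only the listed names are exported
with_dunder = star_import_names(body + '__all__ = ["A"]\n')

# With the single-underscore spelling, the list is ignored and every
# public name (anything not starting with an underscore) is exported
with_single = star_import_names(body + '_all_ = ["A"]\n')

print(sorted(with_dunder))
print(sorted(with_single))
```

Explicit imports such as `from haystack_experimental.components.retrievers import MultiQueryTextRetriever` work with either spelling, so the difference only shows up with star-imports and with tools that read `__all__`.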

haystack_experimental/components/retrievers/multi_query_embedding_retriever.py

Lines changed: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, List, Optional
+
+from haystack import Document, component, default_from_dict, default_to_dict
+from haystack.components.embedders.types.protocol import TextEmbedder
+from haystack.core.serialization import component_to_dict
+from haystack.utils.deserialization import deserialize_component_inplace
+
+from haystack_experimental.components.retrievers.types import EmbeddingRetriever
+
+
+@component
+class MultiQueryEmbeddingRetriever:
+    """
+    A component that retrieves documents using multiple queries in parallel with an embedding-based retriever.
+
+    This component takes a list of text queries, converts them to embeddings using a query embedder,
+    and then uses an embedding-based retriever to find relevant documents for each query in parallel.
+    The results are combined and sorted by relevance score.
+
+    ### Usage example
+
+    ```python
+    from haystack import Document
+    from haystack.document_stores.in_memory import InMemoryDocumentStore
+    from haystack.document_stores.types import DuplicatePolicy
+    from haystack.components.embedders import SentenceTransformersTextEmbedder
+    from haystack.components.embedders import SentenceTransformersDocumentEmbedder
+    from haystack.components.retrievers import InMemoryEmbeddingRetriever
+    from haystack.components.writers import DocumentWriter
+    from haystack_experimental.components.retrievers import MultiQueryEmbeddingRetriever
+
+    documents = [
+        Document(content="Renewable energy is energy that is collected from renewable resources."),
+        Document(content="Solar energy is a type of green energy that is harnessed from the sun."),
+        Document(content="Wind energy is another type of green energy that is generated by wind turbines."),
+        Document(content="Geothermal energy is heat that comes from the sub-surface of the earth."),
+        Document(content="Biomass energy is produced from organic materials, such as plant and animal waste."),
+        Document(content="Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources."),
+    ]
+
+    # Populate the document store
+    doc_store = InMemoryDocumentStore()
+    doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
+    doc_embedder.warm_up()
+    doc_writer = DocumentWriter(document_store=doc_store, policy=DuplicatePolicy.SKIP)
+    documents = doc_embedder.run(documents)["documents"]
+    doc_writer.run(documents=documents)
+
+    # Run the multi-query retriever
+    in_memory_retriever = InMemoryEmbeddingRetriever(document_store=doc_store)
+    query_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
+
+    multi_query_retriever = MultiQueryEmbeddingRetriever(
+        retriever=in_memory_retriever,
+        query_embedder=query_embedder,
+        max_workers=3
+    )
+    multi_query_retriever.warm_up()  # load the query embedder model before the first run
+    queries = ["Geothermal energy", "natural gas", "turbines"]
+    result = multi_query_retriever.run(queries=queries)
+    for doc in result["documents"]:
+        print(f"Content: {doc.content}, Score: {doc.score}")
+    >> Content: Geothermal energy is heat that comes from the sub-surface of the earth., Score: 0.8509603046266574
+    >> Content: Renewable energy is energy that is collected from renewable resources., Score: 0.42763211298893034
+    >> Content: Solar energy is a type of green energy that is harnessed from the sun., Score: 0.40077417016494354
+    >> Content: Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources., Score: 0.3774863680995796
+    >> Content: Wind energy is another type of green energy that is generated by wind turbines., Score: 0.3091423972562246
+    >> Content: Biomass energy is produced from organic materials, such as plant and animal waste., Score: 0.25173074243668087
+    ```
+    """  # noqa E501
+
+    def __init__(
+        self,
+        *,
+        retriever: EmbeddingRetriever,
+        query_embedder: TextEmbedder,
+        max_workers: int = 3,
+    ):
+        """
+        Initialize MultiQueryEmbeddingRetriever.
+
+        :param retriever: The embedding-based retriever to use for document retrieval.
+        :param query_embedder: The query embedder to convert text queries to embeddings.
+        :param max_workers: Maximum number of worker threads for parallel processing.
+        """
+        self.retriever = retriever
+        self.query_embedder = query_embedder
+        self.max_workers = max_workers
+        self._is_warmed_up = False
+
+    def warm_up(self) -> None:
+        """
+        Warm up the query embedder and the retriever if any has a warm_up method.
+        """
+        if not self._is_warmed_up:
+            if hasattr(self.query_embedder, "warm_up") and callable(getattr(self.query_embedder, "warm_up")):
+                self.query_embedder.warm_up()
+            if hasattr(self.retriever, "warm_up") and callable(getattr(self.retriever, "warm_up")):
+                self.retriever.warm_up()
+            self._is_warmed_up = True
+
+    @component.output_types(documents=List[Document])
+    def run(
+        self,
+        queries: List[str],
+        retriever_kwargs: Optional[dict[str, Any]] = None,
+    ) -> dict[str, Any]:
+        """
+        Retrieve documents using multiple queries in parallel.
+
+        :param queries: List of text queries to process.
+        :param retriever_kwargs: Optional dictionary of arguments to pass to the retriever's run method.
+        :returns:
+            A dictionary containing:
+            - `documents`: List of retrieved documents sorted by relevance score.
+        """
+        docs: list[Document] = []
+        seen_contents = set()
+        retriever_kwargs = retriever_kwargs or {}
+
+        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+            queries_results = executor.map(lambda query: self._run_on_thread(query, retriever_kwargs), queries)
+            for result in queries_results:
+                if not result:
+                    continue
+                for doc in result:
+                    # deduplicate based on content
+                    if doc.content not in seen_contents:
+                        docs.append(doc)
+                        seen_contents.add(doc.content)
+
+        docs.sort(key=lambda x: x.score or 0.0, reverse=True)
+        return {"documents": docs}
+
+    def _run_on_thread(self, query: str, retriever_kwargs: Optional[dict[str, Any]] = None) -> Optional[List[Document]]:
+        """
+        Process a single query on a separate thread.
+
+        :param query: The text query to process.
+        :returns:
+            List of retrieved documents or None if no results.
+        """
+        embedding_result = self.query_embedder.run(text=query)
+        query_embedding = embedding_result["embedding"]
+        result = self.retriever.run(query_embedding=query_embedding, **(retriever_kwargs or {}))
+        if result and "documents" in result:
+            return result["documents"]
+        return None
+
+    def to_dict(self) -> dict[str, Any]:
+        """
+        Serializes the component to a dictionary.
+
+        :returns:
+            A dictionary representing the serialized component.
+        """
+        return default_to_dict(
+            self,
+            retriever=component_to_dict(obj=self.retriever, name="retriever"),
+            query_embedder=component_to_dict(obj=self.query_embedder, name="query_embedder"),
+            max_workers=self.max_workers,
+        )
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> "MultiQueryEmbeddingRetriever":
+        """
+        Deserializes the component from a dictionary.
+
+        :param data: The dictionary to deserialize from.
+        :returns:
+            The deserialized component.
+        """
+        deserialize_component_inplace(data["init_parameters"], key="retriever")
+        deserialize_component_inplace(data["init_parameters"], key="query_embedder")
+        return default_from_dict(cls, data)
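
The fan-out, deduplicate, and sort steps in `run` can be tried without Haystack installed. The sketch below mirrors that logic under stated assumptions: `Doc` is a hypothetical stand-in for `haystack.Document` with only `content` and `score`, and `fake_retrieve` replaces the embedder plus retriever call. It is an illustration of the merging behaviour, not the component itself.

```python
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Doc:
    """Hypothetical stand-in for haystack's Document (content and score only)."""
    content: str
    score: Optional[float] = None


def fan_out_and_merge(
    queries: list[str],
    retrieve: Callable[[str], Optional[list[Doc]]],
    max_workers: int = 3,
) -> list[Doc]:
    merged: list[Doc] = []
    seen: set[str] = set()
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # executor.map yields results in the order of `queries`,
        # regardless of which worker finishes first
        for result in executor.map(retrieve, queries):
            for doc in result or []:
                # deduplicate on content: the first copy seen is kept
                if doc.content not in seen:
                    seen.add(doc.content)
                    merged.append(doc)
    # documents without a score sort as 0.0
    merged.sort(key=lambda d: d.score or 0.0, reverse=True)
    return merged


def fake_retrieve(query: str) -> list[Doc]:
    """Made-up retrieval results, keyed by query, for illustration only."""
    index = {
        "solar": [Doc("Solar energy comes from the sun.", 0.9), Doc("Wind turbines generate power.", 0.2)],
        "wind": [Doc("Wind turbines generate power.", 0.8)],
    }
    return index.get(query, [])


docs = fan_out_and_merge(["solar", "wind", "unknown"], fake_retrieve)
for doc in docs:
    print(doc.content, doc.score)
```

One consequence worth noticing: because duplicates are dropped in query order, a document returned by several queries keeps the score from the first query that returned it, which is not necessarily its highest score. Here the wind document keeps 0.2 from the "solar" query rather than 0.8 from the "wind" query.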

haystack_experimental/components/retrievers/multi_query_text_retriever.py

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, List, Optional
+
+from haystack import Document, component, default_from_dict, default_to_dict
+from haystack.core.serialization import component_to_dict
+from haystack.utils.deserialization import deserialize_component_inplace
+
+from haystack_experimental.components.retrievers.types import TextRetriever
+
+
+@component
+class MultiQueryTextRetriever:
+    """
+    A component that retrieves documents using multiple queries in parallel with a text-based retriever.
+
+    This component takes a list of text queries and uses a text-based retriever to find relevant documents for each
+    query in parallel, using a thread pool to manage concurrent execution. The results are combined and sorted by
+    relevance score.
+
+    You can use this component in combination with QueryExpander component to enhance the retrieval process.
+
+    ### Usage example
+    ```python
+    from haystack import Document
+    from haystack.components.writers import DocumentWriter
+    from haystack.document_stores.in_memory import InMemoryDocumentStore
+    from haystack.document_stores.types import DuplicatePolicy
+    from haystack.components.retrievers import InMemoryBM25Retriever
+    from haystack_experimental.components.query import QueryExpander
+    from haystack_experimental.components.retrievers.multi_query_text_retriever import MultiQueryTextRetriever
+
+    documents = [
+        Document(content="Renewable energy is energy that is collected from renewable resources."),
+        Document(content="Solar energy is a type of green energy that is harnessed from the sun."),
+        Document(content="Wind energy is another type of green energy that is generated by wind turbines."),
+        Document(content="Hydropower is a form of renewable energy using the flow of water to generate electricity."),
+        Document(content="Geothermal energy is heat that comes from the sub-surface of the earth.")
+    ]
+
+    document_store = InMemoryDocumentStore()
+    doc_writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP)
+    doc_writer.run(documents=documents)
+
+    in_memory_retriever = InMemoryBM25Retriever(document_store=document_store, top_k=1)
+    multiquery_retriever = MultiQueryTextRetriever(retriever=in_memory_retriever)
+    results = multiquery_retriever.run(queries=["renewable energy?", "Geothermal", "Hydropower"])
+    for doc in results["documents"]:
+        print(f"Content: {doc.content}, Score: {doc.score}")
+    >>
+    >> Content: Geothermal energy is heat that comes from the sub-surface of the earth., Score: 1.6474448833731097
+    >> Content: Hydropower is a form of renewable energy using the flow of water to generate electricity., Score: 1.6157822790079805
+    >> Content: Renewable energy is energy that is collected from renewable resources., Score: 1.5255309812344944
+    ```
+    """  # noqa E501
+
+    def __init__(
+        self,
+        retriever: TextRetriever,
+        max_workers: int = 3,
+    ):
+        """
+        Initialize MultiQueryTextRetriever.
+
+        :param retriever: The text-based retriever to use for document retrieval.
+        :param max_workers: Maximum number of worker threads for parallel processing. Default is 3.
+        """
+        self.retriever = retriever
+        self.max_workers = max_workers
+        self._is_warmed_up = False
+
+    def warm_up(self) -> None:
+        """
+        Warm up the retriever if it has a warm_up method.
+        """
+        if not self._is_warmed_up:
+            if hasattr(self.retriever, "warm_up") and callable(getattr(self.retriever, "warm_up")):
+                self.retriever.warm_up()
+            self._is_warmed_up = True
+
+    @component.output_types(documents=list[Document])
+    def run(
+        self,
+        queries: List[str],
+        retriever_kwargs: Optional[dict[str, Any]] = None,
+    ) -> dict[str, Any]:
+        """
+        Retrieve documents using multiple queries in parallel.
+
+        :param queries: List of text queries to process.
+        :param retriever_kwargs: Optional dictionary of arguments to pass to the retriever's run method.
+        :returns:
+            A dictionary containing:
+                `documents`: List of retrieved documents sorted by relevance score.
+        """
+        docs: list[Document] = []
+        seen_contents = set()
+        retriever_kwargs = retriever_kwargs or {}
+
+        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+            queries_results = executor.map(lambda query: self._run_on_thread(query, retriever_kwargs), queries)
+            for result in queries_results:
+                if not result:
+                    continue
+                # deduplicate based on content
+                for doc in result:
+                    if doc.content not in seen_contents:
+                        docs.append(doc)
+                        seen_contents.add(doc.content)
+
+        docs.sort(key=lambda x: x.score or 0.0, reverse=True)
+        return {"documents": docs}
+
+    def _run_on_thread(
+        self,
+        query: str,
+        retriever_kwargs: Optional[dict[str, Any]] = None,
+    ) -> Optional[list[Document]]:
+        """
+        Process a single query on a separate thread.
+
+        :param query: The text query to process.
+        :param retriever_kwargs: Optional dictionary of arguments to pass to the retriever's run method.
+        :returns:
+            List of retrieved documents or None if no results.
+        """
+        result = self.retriever.run(query=query, **(retriever_kwargs or {}))
+        if result and "documents" in result:
+            return result["documents"]
+        return None
+
+    def to_dict(self) -> dict[str, Any]:
+        """
+        Serializes the component to a dictionary.
+
+        :returns:
+            The serialized component as a dictionary.
+        """
+        return default_to_dict(
+            self,
+            retriever=component_to_dict(obj=self.retriever, name="retriever"),
+            max_workers=self.max_workers,
+        )
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> "MultiQueryTextRetriever":
+        """
+        Deserializes the component from a dictionary.
+
+        :param data: The dictionary to deserialize from.
+        :returns:
+            The deserialized component.
+        """
+        deserialize_component_inplace(data["init_parameters"], key="retriever")
+        return default_from_dict(cls, data)
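
The `retriever_kwargs` argument is unpacked into every per-query call, so a run-time setting such as `top_k` overrides whatever the wrapped retriever was constructed with. A minimal sketch of that forwarding follows, assuming a hypothetical `StubTextRetriever` that scores by word overlap and returns `(text, score)` tuples instead of Haystack documents; `run_one` copies the shape of `_run_on_thread`.

```python
from typing import Any, Optional


class StubTextRetriever:
    """Hypothetical keyword retriever: scores a text by how many query words it contains."""

    def __init__(self, corpus: list[str], top_k: int = 1):
        self.corpus = corpus
        self.top_k = top_k

    def run(self, query: str, top_k: Optional[int] = None) -> dict[str, Any]:
        words = query.lower().split()
        scored = [(text, sum(word in text.lower() for word in words)) for text in self.corpus]
        hits = [pair for pair in scored if pair[1] > 0]
        hits.sort(key=lambda pair: pair[1], reverse=True)
        # a run-time top_k takes precedence over the constructor value
        return {"documents": hits[: top_k or self.top_k]}


def run_one(
    retriever: StubTextRetriever,
    query: str,
    retriever_kwargs: Optional[dict[str, Any]] = None,
) -> Optional[list[tuple[str, int]]]:
    # Same shape as _run_on_thread: unpack the optional kwargs into the retriever call
    result = retriever.run(query=query, **(retriever_kwargs or {}))
    if result and "documents" in result:
        return result["documents"]
    return None


retriever = StubTextRetriever(
    corpus=["wind energy from turbines", "solar energy from the sun", "geothermal heat"],
)
default_hits = run_one(retriever, "energy")
wider_hits = run_one(retriever, "energy", {"top_k": 2})
print(len(default_hits), len(wider_hits))
```

Because the same dictionary is applied to every query, a key the wrapped retriever's `run` method does not accept would raise a `TypeError` in each worker call.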

haystack_experimental/components/retrievers/types/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from .protocol import EmbeddingRetriever, TextRetriever
+
+__all__ = ["TextRetriever", "EmbeddingRetriever"]
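
The `protocol` module that defines `TextRetriever` and `EmbeddingRetriever` is not shown on this page. As a rough sketch of what such structural types could look like, the block below defines two hypothetical protocols, `TextRetrieverSketch` and `EmbeddingRetrieverSketch`, whose `run` signatures are assumptions inferred only from how the two components call their retrievers (`run(query=...)` and `run(query_embedding=...)`). The real definitions may differ.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class TextRetrieverSketch(Protocol):
    """Assumed shape: run(query=...) returning a dict with a "documents" key."""

    def run(self, query: str, **kwargs: Any) -> dict[str, Any]: ...


@runtime_checkable
class EmbeddingRetrieverSketch(Protocol):
    """Assumed shape: run(query_embedding=...) returning a dict with a "documents" key."""

    def run(self, query_embedding: list[float], **kwargs: Any) -> dict[str, Any]: ...


class EchoRetriever:
    """Toy class that satisfies the text protocol without inheriting from it."""

    def run(self, query: str, **kwargs: Any) -> dict[str, Any]:
        return {"documents": [query]}


conforms = isinstance(EchoRetriever(), TextRetrieverSketch)
rejects = isinstance(object(), TextRetrieverSketch)
print(conforms, rejects)
```

Note that `isinstance` against a `runtime_checkable` protocol only checks that a `run` attribute exists, not its parameters, so `EchoRetriever` would also pass the embedding check at run time. Telling the two apart is the job of a static type checker.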