Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions haystack/components/retrievers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
"auto_merging_retriever": ["AutoMergingRetriever"],
"filter_retriever": ["FilterRetriever"],
"in_memory": ["InMemoryBM25Retriever", "InMemoryEmbeddingRetriever"],
"multi_retriever": ["MultiRetriever"],
"multi_query_embedding_retriever": ["MultiQueryEmbeddingRetriever"],
"multi_query_text_retriever": ["MultiQueryTextRetriever"],
"sentence_window_retriever": ["SentenceWindowRetriever"],
"query_embedding_retriever": ["QueryEmbeddingRetriever"],
}

if TYPE_CHECKING:
Expand All @@ -23,6 +25,8 @@
from .in_memory import InMemoryEmbeddingRetriever as InMemoryEmbeddingRetriever
from .multi_query_embedding_retriever import MultiQueryEmbeddingRetriever as MultiQueryEmbeddingRetriever
from .multi_query_text_retriever import MultiQueryTextRetriever as MultiQueryTextRetriever
from .multi_retriever import MultiRetriever as MultiRetriever
from .query_embedding_retriever import QueryEmbeddingRetriever as QueryEmbeddingRetriever
from .sentence_window_retriever import SentenceWindowRetriever as SentenceWindowRetriever

else:
Expand Down
204 changes: 204 additions & 0 deletions haystack/components/retrievers/multi_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any

from haystack import component, default_from_dict, default_to_dict
from haystack.components.retrievers.types.protocol import TextRetriever
from haystack.core.serialization import component_from_dict, component_to_dict, import_class_by_name
from haystack.dataclasses import Document
from haystack.utils.misc import _deduplicate_documents


@component
class MultiRetriever:
    """
    A component that runs multiple retrievers in parallel and combines their results.

    Each retriever is queried concurrently using a thread pool.
    The per-retriever results are concatenated in the order the retrievers were declared
    (or listed in `active_retrievers`), deduplicated, and returned as a single list of documents.

    ### Usage example

    ```python
    from haystack import Document
    from haystack.document_stores.in_memory import InMemoryDocumentStore
    from haystack.document_stores.types import DuplicatePolicy
    from haystack.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
    from haystack.components.retrievers import QueryEmbeddingRetriever, MultiRetriever
    from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
    from haystack.components.writers import DocumentWriter

    documents = [
        Document(content="Renewable energy is energy that is collected from renewable resources."),
        Document(content="Solar energy is a type of green energy that is harnessed from the sun."),
        Document(content="Wind energy is another type of green energy that is generated by wind turbines."),
    ]

    # Populate the document store
    doc_store = InMemoryDocumentStore()
    doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
    doc_writer = DocumentWriter(document_store=doc_store, policy=DuplicatePolicy.SKIP)
    doc_writer.run(documents=doc_embedder.run(documents)["documents"])

    # Run the multi-retriever with all retrievers
    retriever = MultiRetriever(
        retrievers={
            "bm25": InMemoryBM25Retriever(document_store=doc_store),
            "embedding": QueryEmbeddingRetriever(
                retriever=InMemoryEmbeddingRetriever(document_store=doc_store),
                query_embedder=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
            ),
        },
        top_k=3,
    )

    # Run all retrievers
    result = retriever.run(query="green energy sources")

    # Run only the BM25 retriever
    result = retriever.run(query="green energy sources", active_retrievers=["bm25"])

    for doc in result["documents"]:
        print(doc.content)
    ```
    """

    def __init__(
        self,
        *,
        retrievers: dict[str, TextRetriever],
        filters: dict[str, Any] | None = None,
        top_k: int = 10,
        max_workers: int = 4,
    ) -> None:
        """
        Create the MultiRetriever component.

        :param retrievers:
            A dictionary mapping names to retriever components to run in parallel.
        :param filters:
            A dictionary of filters to apply when retrieving documents.
        :param top_k:
            The maximum number of documents to return per retriever.
        :param max_workers:
            The maximum number of threads to use for parallel retrieval. Must be at least 1.

        :raises ValueError:
            If `max_workers` is smaller than 1.
        """
        # ThreadPoolExecutor rejects non-positive worker counts; fail at construction time instead of
        # on the first `run()` call.
        if max_workers < 1:
            raise ValueError(f"max_workers must be at least 1, got {max_workers}.")

        self.retrievers = retrievers
        self.filters = filters
        self.top_k = top_k
        self.max_workers = max_workers
        self._is_warmed_up = False

    def warm_up(self) -> None:
        """
        Warm up the retrievers if any has a warm_up method.

        Subsequent calls are no-ops once the first call has completed.
        """
        if self._is_warmed_up:
            return
        for retriever in self.retrievers.values():
            if hasattr(retriever, "warm_up") and callable(retriever.warm_up):
                retriever.warm_up()
        self._is_warmed_up = True

    @component.output_types(documents=list[Document])
    def run(
        self,
        query: str,
        filters: dict[str, Any] | None = None,
        top_k: int | None = None,
        *,
        active_retrievers: list[str] | None = None,
    ) -> dict[str, list[Document]]:
        """
        Runs retrievers in parallel on the given query and returns deduplicated results.

        :param query:
            The query to run the retrievers on.
        :param filters:
            The filters to apply to the retrievers. If not provided, the filters from the initialization of the
            component will be used. If those are also not provided, no filters will be applied.
        :param top_k:
            The number of documents to return per retriever. If not provided, the top_k from the initialization of
            the component will be used.
        :param active_retrievers:
            A list of retriever names to run. If not provided, all retrievers will be run.
            Names must match the keys provided in the `retrievers` dictionary at initialization.

        :returns:
            A dictionary with the keys:
            - "documents": A deduplicated list of retrieved documents.

        :raises ValueError:
            If any name in `active_retrievers` does not match a retriever name.
        :raises RuntimeError:
            If any of the retrievers raises while running; the original exception is chained as the cause.
        """
        if not self._is_warmed_up:
            self.warm_up()

        resolved_top_k = top_k if top_k is not None else self.top_k
        resolved_filters = filters if filters is not None else self.filters

        if active_retrievers is not None:
            unknown = set(active_retrievers) - self.retrievers.keys()
            if unknown:
                raise ValueError(
                    f"Unknown retriever name(s): {sorted(unknown)}. "
                    f"Available retrievers: {sorted(self.retrievers.keys())}"
                )
            retrievers_to_run = {name: self.retrievers[name] for name in active_retrievers}
        else:
            retrievers_to_run = self.retrievers

        # Results are stored per retriever name so that the combined list does not depend on which
        # thread happens to finish first.
        documents_by_name: dict[str, list[Document]] = {}
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_name = {
                executor.submit(retriever.run, query=query, filters=resolved_filters, top_k=resolved_top_k): name
                for name, retriever in retrievers_to_run.items()
            }
            # as_completed lets a failing retriever surface as soon as it fails rather than after
            # every earlier-submitted retriever has finished.
            for future in as_completed(future_to_name):
                name = future_to_name[future]
                try:
                    documents_by_name[name] = list(future.result().get("documents", []))
                except Exception as e:
                    raise RuntimeError(f"Retriever '{name}' failed: {e}") from e

        # Concatenate in declaration / `active_retrievers` order so the input to deduplication is
        # deterministic across runs.
        all_documents: list[Document] = [
            document for name in retrievers_to_run for document in documents_by_name[name]
        ]

        return {"documents": _deduplicate_documents(all_documents)}

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            retrievers={name: component_to_dict(obj=r, name=name) for name, r in self.retrievers.items()},
            filters=self.filters,
            top_k=self.top_k,
            max_workers=self.max_workers,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MultiRetriever":
        """
        Creates an instance of the component from a dictionary.

        The passed dictionary is not modified.

        :param data:
            Dictionary with the data to create the component.

        :returns:
            The deserialized component.

        :raises ImportError:
            If the class of one of the serialized retrievers cannot be imported.
        """
        init_parameters = data.get("init_parameters", {})
        retrievers_data = init_parameters.get("retrievers", {})
        if not retrievers_data:
            return default_from_dict(cls, data)

        retrievers: dict[str, Any] = {}
        for name, retriever_data in retrievers_data.items():
            try:
                imported_class = import_class_by_name(retriever_data["type"])
            except ImportError as e:
                raise ImportError(
                    f"Could not import class {retriever_data['type']} for retriever '{name}'. Error: {str(e)}"
                ) from e
            retrievers[name] = component_from_dict(cls=imported_class, data=retriever_data, name=name)

        # Build a new payload instead of writing the instances back into the caller's dictionary.
        resolved_data = {**data, "init_parameters": {**init_parameters, "retrievers": retrievers}}
        return default_from_dict(cls, resolved_data)
129 changes: 129 additions & 0 deletions haystack/components/retrievers/query_embedding_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.components.embedders.types.protocol import TextEmbedder
from haystack.components.retrievers.types import EmbeddingRetriever
from haystack.core.serialization import component_to_dict


@component
class QueryEmbeddingRetriever:
    """
    A component that retrieves documents using a query with an embedding-based retriever.

    This component takes a text query, converts it to an embedding using a query embedder, and then uses an
    embedding-based retriever to find relevant documents.
    The results are sorted by relevance score.

    ### Usage example

    ```python
    from haystack import Document
    from haystack.document_stores.in_memory import InMemoryDocumentStore
    from haystack.document_stores.types import DuplicatePolicy
    from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
    from haystack.components.retrievers import InMemoryEmbeddingRetriever, QueryEmbeddingRetriever
    from haystack.components.writers import DocumentWriter

    documents = [
        Document(content="Renewable energy is energy that is collected from renewable resources."),
        Document(content="Solar energy is a type of green energy that is harnessed from the sun."),
        Document(content="Wind energy is another type of green energy that is generated by wind turbines."),
        Document(content="Geothermal energy is heat that comes from the sub-surface of the earth."),
        Document(content="Biomass energy is produced from organic materials, such as plant and animal waste."),
        Document(content="Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources."),
    ]

    # Populate the document store
    doc_store = InMemoryDocumentStore()
    doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
    doc_writer = DocumentWriter(document_store=doc_store, policy=DuplicatePolicy.SKIP)
    documents = doc_embedder.run(documents)["documents"]
    doc_writer.run(documents=documents)

    # Run the retriever
    in_memory_retriever = InMemoryEmbeddingRetriever(document_store=doc_store, top_k=1)
    query_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
    retriever = QueryEmbeddingRetriever(retriever=in_memory_retriever, query_embedder=query_embedder)
    result = retriever.run(query="Geothermal energy")

    for doc in result["documents"]:
        print(f"Content: {doc.content}, Score: {doc.score}")
    # >> Content: Geothermal energy is heat that comes from the sub-surface of the earth., Score: 0.8509603046266574
    ```
    """

    def __init__(self, *, retriever: EmbeddingRetriever, query_embedder: TextEmbedder) -> None:
        """
        Initialize QueryEmbeddingRetriever.

        :param retriever: The embedding-based retriever to use for document retrieval.
        :param query_embedder: The query embedder to convert a text query to an embedding.
        """
        self.retriever = retriever
        self.query_embedder = query_embedder
        self._is_warmed_up = False

    def warm_up(self) -> None:
        """
        Warm up the query embedder and the retriever if any has a warm_up method.

        Subsequent calls are no-ops once the first call has completed.
        """
        if self._is_warmed_up:
            return
        for wrapped in (self.query_embedder, self.retriever):
            if hasattr(wrapped, "warm_up") and callable(wrapped.warm_up):
                wrapped.warm_up()
        self._is_warmed_up = True

    @component.output_types(documents=list[Document])
    def run(
        self, query: str, filters: dict[str, Any] | None = None, top_k: int | None = None
    ) -> dict[str, list[Document]]:
        """
        Retrieve documents using a single query.

        :param query: The query to retrieve documents for.
        :param filters: A dictionary of filters to apply when retrieving documents.
        :param top_k: The maximum number of documents to return.
        :returns:
            A dictionary containing:
            - `documents`: List of retrieved documents sorted by relevance score.
        """
        if not self._is_warmed_up:
            self.warm_up()

        embedding_result = self.query_embedder.run(text=query)
        result = self.retriever.run(query_embedding=embedding_result["embedding"], filters=filters, top_k=top_k)

        # Sort into a new list (highest score first, missing scores treated as 0.0) so the list object
        # handed back by the wrapped retriever is left untouched.
        docs: list[Document] = sorted(result["documents"], key=lambda doc: doc.score or 0.0, reverse=True)
        return {"documents": docs}

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            A dictionary representing the serialized component.
        """
        return default_to_dict(
            self,
            retriever=component_to_dict(obj=self.retriever, name="retriever"),
            query_embedder=component_to_dict(obj=self.query_embedder, name="query_embedder"),
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "QueryEmbeddingRetriever":
        """
        Deserializes the component from a dictionary.

        The nested `retriever` and `query_embedder` entries written by `to_dict` are rebuilt into component
        instances before the component itself is created. The passed dictionary is not modified.

        :param data: The dictionary to deserialize from.
        :returns:
            The deserialized component.
        :raises ImportError:
            If the class of the serialized retriever or query embedder cannot be imported.
        """
        # Imported locally so this method does not depend on names missing from the module-level imports.
        from haystack.core.serialization import component_from_dict, import_class_by_name

        init_parameters = dict(data.get("init_parameters", {}))
        for key in ("retriever", "query_embedder"):
            serialized = init_parameters.get(key)
            # Only touch entries that still look like serialized components; anything else is passed through.
            if not (isinstance(serialized, dict) and "type" in serialized):
                continue
            try:
                imported_class = import_class_by_name(serialized["type"])
            except ImportError as e:
                raise ImportError(f"Could not import class {serialized['type']} for '{key}'. Error: {str(e)}") from e
            init_parameters[key] = component_from_dict(cls=imported_class, data=serialized, name=key)

        return default_from_dict(cls, {**data, "init_parameters": init_parameters})
2 changes: 1 addition & 1 deletion pydoc/retrievers_api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ loaders:
- search_path: [../haystack/components/retrievers]
modules: ["auto_merging_retriever", "filter_retriever", "in_memory/bm25_retriever",
"in_memory/embedding_retriever", "multi_query_embedding_retriever", "multi_query_text_retriever",
"sentence_window_retriever"]
"multi_retriever", "query_embedding_retriever", "sentence_window_retriever"]
processors:
- type: filter
documented_only: true
Expand Down
20 changes: 20 additions & 0 deletions releasenotes/notes/add-multi-retriever-fc170115f9507fb3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
---
features:
- |
Added two new retriever components: ``MultiRetriever`` and ``QueryEmbeddingRetriever``.

``MultiRetriever`` is a generalisation of hybrid retrieval.
Where hybrid retrieval traditionally combines keyword search (BM25) with vector search, ``MultiRetriever`` lets
you compose any number of retrievers into a single component.

This allows users to:
- Combine multiple retrieval strategies without wiring each retriever individually in a Pipeline.
- Easily enable or disable specific retrievers at runtime using the ``active_retrievers`` parameter, avoiding complex ConditionalRouter setups.

All retrievers are queried in parallel and their results are deduplicated before being returned.

``QueryEmbeddingRetriever`` wraps an embedding-based retriever together with a query embedder into a single
self-contained component that follows the ``TextRetriever`` protocol.

This simplifies integration with ``MultiRetriever``, allowing both text-based and embedding-based retrievers
to be used together seamlessly in a single dict.
Loading
Loading