9 changes: 9 additions & 0 deletions mteb/models/search_encoder_index/__init__.py
@@ -0,0 +1,9 @@
from .default_backend_search import DefaultEncoderSearchBackend
from .faiss_search_backend import FaissEncoderSearchBackend
from .search_backend_protocol import IndexEncoderSearchProtocol

__all__ = [
"DefaultEncoderSearchBackend",
"FaissEncoderSearchBackend",
"IndexEncoderSearchProtocol",
]
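
The two backends are interchangeable behind the shared protocol. A minimal sketch of how a caller might select one; the `build_backend` helper and `use_faiss` flag are illustrative assumptions, not part of this diff:

from mteb.models.search_encoder_index import (
    DefaultEncoderSearchBackend,
    FaissEncoderSearchBackend,
    IndexEncoderSearchProtocol,
)


def build_backend(model, use_faiss: bool) -> IndexEncoderSearchProtocol:
    # FaissEncoderSearchBackend reads the similarity function from the model's
    # metadata; DefaultEncoderSearchBackend takes no constructor arguments.
    if use_faiss:
        return FaissEncoderSearchBackend(model)
    return DefaultEncoderSearchBackend()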
99 changes: 99 additions & 0 deletions mteb/models/search_encoder_index/default_backend_search.py
@@ -0,0 +1,99 @@
import logging
from collections.abc import Callable

import torch

from mteb.types import Array, TopRankedDocumentsType

logger = logging.getLogger(__name__)


class DefaultEncoderSearchBackend:
"""Streaming backend for encoder-based search.
- Does not store the entire corpus in memory.
- Encodes and searches corpus in chunks.
"""

sub_corpus_embeddings: Array | None = None
idxs: list[str]

def add_document(
self,
embeddings: Array,
idxs: list[str],
) -> None:
"""Add all document embeddings and their IDs to the backend."""
self.sub_corpus_embeddings = embeddings
self.idxs = idxs

def search(
self,
embeddings: Array,
top_k: int,
similarity_fn: Callable[[Array, Array], Array],
top_ranked: TopRankedDocumentsType | None = None,
query_idx_to_id: dict[int, str] | None = None,
) -> tuple[list[list[float]], list[list[int]]]:
"""Search through added corpus embeddings or rerank top-ranked documents."""
if self.sub_corpus_embeddings is None:
raise ValueError("No corpus embeddings found. Did you call add_document()?")

if top_ranked is not None:
if query_idx_to_id is None:
raise ValueError("query_idx_to_id is required when using top_ranked.")

scores_all: list[list[float]] = []
idxs_all: list[list[int]] = []

doc_id_to_idx = {doc_id: i for i, doc_id in enumerate(self.idxs)}

for query_idx, query_emb in enumerate(embeddings):
query_id = query_idx_to_id[query_idx]
ranked_ids = top_ranked.get(query_id)
if not ranked_ids:
logger.warning(f"No top-ranked docs for query {query_id}")
scores_all.append([])
idxs_all.append([])
continue

candidate_idx = [doc_id_to_idx[doc_id] for doc_id in ranked_ids]
candidate_embs = self.sub_corpus_embeddings[candidate_idx]

scores = similarity_fn(
torch.as_tensor(query_emb).unsqueeze(0),
torch.as_tensor(candidate_embs),
)

values, indices = torch.topk(
torch.as_tensor(scores),
k=min(top_k, len(candidate_idx)),
dim=1,
largest=True,
)
scores_all.append(values.squeeze(0).cpu().tolist())
idxs_all.append(indices.squeeze(0).cpu().tolist())

return scores_all, idxs_all

        # Full-corpus mode: score every query against the current corpus chunk.
        scores = torch.as_tensor(similarity_fn(embeddings, self.sub_corpus_embeddings))
        # Release the chunk so the next add_document() call starts fresh.
        self.sub_corpus_embeddings = None

        # Retrieve top_k + 1 hits so callers can drop the query document itself
        # if it happens to be part of the corpus.
        top_k_values, top_k_idx = torch.topk(
            scores,
            min(top_k + 1, scores.shape[1]),
            dim=1,
            largest=True,
        )
        return (
            top_k_values.cpu().tolist(),
            top_k_idx.cpu().tolist(),
        )

def clear(self) -> None:
"""Clear all stored documents and embeddings from the backend."""
self.sub_corpus_embeddings = None
self.idxs = []
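
A minimal usage sketch for the streaming backend; the random embeddings and the cosine `similarity_fn` below are illustrative assumptions, not part of this diff:

import numpy as np
import torch

from mteb.models.search_encoder_index import DefaultEncoderSearchBackend


def cosine_similarity(queries, corpus):
    # Hypothetical similarity_fn: normalize both sides, then inner product.
    q = torch.nn.functional.normalize(torch.as_tensor(queries, dtype=torch.float32), dim=1)
    c = torch.nn.functional.normalize(torch.as_tensor(corpus, dtype=torch.float32), dim=1)
    return q @ c.T


backend = DefaultEncoderSearchBackend()
backend.add_document(np.random.rand(100, 32), idxs=[f"doc-{i}" for i in range(100)])
scores, indices = backend.search(
    np.random.rand(4, 32), top_k=10, similarity_fn=cosine_similarity
)
# scores[q] and indices[q] hold the top hits for query q; indices point into
# the idxs list passed to add_document().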
136 changes: 136 additions & 0 deletions mteb/models/search_encoder_index/faiss_search_backend.py
@@ -0,0 +1,136 @@
import logging
from collections.abc import Callable

import numpy as np
import torch

from mteb._requires_package import requires_package
from mteb.models.model_meta import ScoringFunction
from mteb.models.models_protocols import EncoderProtocol
from mteb.types import Array, TopRankedDocumentsType

logger = logging.getLogger(__name__)


class FaissEncoderSearchBackend:
"""FAISS-based backend for encoder-based search.

Supports both full-corpus retrieval and reranking (via `top_ranked`).

Notes:
- Stores *all* embeddings in memory (IndexFlatIP or IndexFlatL2).
- Expects embeddings to be normalized if cosine similarity is desired.
"""

_normalize: bool = False

def __init__(self, model: EncoderProtocol) -> None:
requires_package(
self,
"faiss",
"FAISS-based search",
install_instruction="pip install mteb[faiss-cpu]",
)

        import faiss
        from faiss import IndexFlatIP

        # https://github.com/facebookresearch/faiss/wiki/Faiss-indexes
        # A flat inner-product index covers both scoring functions: raw dot
        # product as-is, and cosine similarity once vectors are L2-normalized.
        # (IndexFlatL2 would rank by Euclidean distance, not by dot product.)
        self.index_type = IndexFlatIP
        if not (
            model.mteb_model_meta.similarity_fn_name == "dot"
            or model.mteb_model_meta.similarity_fn_name is ScoringFunction.DOT_PRODUCT
        ):
            self._normalize = True

        self.idxs: list[str] = []
        self.index: faiss.Index | None = None

def add_document(self, embeddings: Array, idxs: list[str]) -> None:
"""Add all document embeddings and their IDs to FAISS index."""
import faiss

        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.detach().cpu().numpy()

        # FAISS requires contiguous float32 input.
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
self.idxs.extend(idxs)

if self._normalize:
faiss.normalize_L2(embeddings)

dim = embeddings.shape[1]
if self.index is None:
self.index = self.index_type(dim)

self.index.add(embeddings)
logger.info(f"FAISS index built with {len(idxs)} vectors of dim {dim}.")

def search(
self,
embeddings: Array,
top_k: int,
similarity_fn: Callable[[Array, Array], Array],
top_ranked: TopRankedDocumentsType | None = None,
query_idx_to_id: dict[int, str] | None = None,
) -> tuple[list[list[float]], list[list[int]]]:
"""Search using FAISS."""
import faiss

if self.index is None:
raise ValueError("No index built. Call add_document() first.")

        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.detach().cpu().numpy()

        # FAISS requires contiguous float32 input for normalize_L2 and search.
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)

        if self._normalize:
            faiss.normalize_L2(embeddings)

if top_ranked is not None:
if query_idx_to_id is None:
raise ValueError("query_idx_to_id must be provided when reranking.")

doc_id_to_idx = {doc_id: i for i, doc_id in enumerate(self.idxs)}
scores_all: list[list[float]] = []
idxs_all: list[list[int]] = []

for query_idx, query_emb in enumerate(embeddings):
query_id = query_idx_to_id[query_idx]
ranked_ids = top_ranked.get(query_id)
if not ranked_ids:
logger.warning(f"No top-ranked documents for query {query_id}")
scores_all.append([])
idxs_all.append([])
continue

candidate_indices = [doc_id_to_idx[doc_id] for doc_id in ranked_ids]
d = self.index.d
candidate_embs = np.zeros((len(candidate_indices), d), dtype=np.float32)
for j, idx in enumerate(candidate_indices):
candidate_embs[j] = self.index.reconstruct(idx)

scores = similarity_fn(
torch.as_tensor(query_emb).unsqueeze(0),
torch.as_tensor(candidate_embs),
)

values, indices = torch.topk(
torch.as_tensor(scores),
k=min(top_k, len(candidate_indices)),
dim=1,
largest=True,
)
scores_all.append(values.squeeze(0).cpu().tolist())
idxs_all.append(indices.squeeze(0).cpu().tolist())

return scores_all, idxs_all

        # FAISS returns (scores, indices); clamp k so it never exceeds the
        # number of indexed vectors.
        scores, ids = self.index.search(embeddings, min(top_k, self.index.ntotal))
        return scores.tolist(), ids.tolist()

def clear(self) -> None:
"""Clear all stored documents and embeddings from the backend."""
self.index = None
self.idxs = []
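
A minimal usage sketch for the FAISS backend; the `SimpleNamespace` stand-in for an `EncoderProtocol` model is a hypothetical assumption, since the constructor only reads `mteb_model_meta.similarity_fn_name`:

from types import SimpleNamespace

import numpy as np

from mteb.models.search_encoder_index import FaissEncoderSearchBackend

# Hypothetical model stub: "cosine" selects the normalized inner-product path.
encoder = SimpleNamespace(mteb_model_meta=SimpleNamespace(similarity_fn_name="cosine"))

backend = FaissEncoderSearchBackend(encoder)
backend.add_document(
    np.random.rand(1000, 32).astype(np.float32),
    idxs=[f"doc-{i}" for i in range(1000)],
)
scores, ids = backend.search(
    np.random.rand(4, 32).astype(np.float32),
    top_k=10,
    similarity_fn=None,  # only used in reranking mode; here FAISS scores internally
)
backend.clear()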
50 changes: 50 additions & 0 deletions mteb/models/search_encoder_index/search_backend_protocol.py
@@ -0,0 +1,50 @@
from collections.abc import Callable
from typing import Protocol

from mteb.types import Array, TopRankedDocumentsType


class IndexEncoderSearchProtocol(Protocol):
Contributor (review comment): Suggested change: rename `IndexEncoderSearchProtocol` to `EncoderSearchProtocol`.

Member (author) reply: I'd keep "Index" and "Encoder" in the name to make clear this protocol is specifically for index-based encoder search; a generic `EncoderSearchProtocol` could (probably) be confused with something `SentenceTransformerEncoderWrapper` would implement.

"""Protocol for search backends used in encoder-based retrieval."""

Contributor (review comment): Suggested change: rename `add_document` to `add_documents`.

    def add_document(
self,
embeddings: Array,
idxs: list[str],
) -> None:
"""Add documents to the search backend.
Args:
embeddings: Embeddings of the documents to add.
idxs: IDs of the documents to add.
"""

def search(
self,
embeddings: Array,
top_k: int,
similarity_fn: Callable[[Array, Array], Array],
top_ranked: TopRankedDocumentsType | None = None,
query_idx_to_id: dict[int, str] | None = None,
) -> tuple[list[list[float]], list[list[int]]]:
"""Search through added corpus embeddings or rerank top-ranked documents.
Supports both full-corpus and reranking search modes:
- Full-corpus mode: `top_ranked=None`, uses added corpus embeddings.
- Reranking mode: `top_ranked` contains mapping {query_id: [doc_ids]}.
Args:
embeddings: Query embeddings, shape (num_queries, dim).
top_k: Number of top results to return.
similarity_fn: Function to compute similarity between query and corpus.
top_ranked: Mapping of query_id -> list of candidate doc_ids. Used for reranking.
query_idx_to_id: Mapping of query index -> query_id. Used for reranking.
Returns:
A tuple (top_k_values, top_k_indices), for each query:
- top_k_values: List of top-k similarity scores.
- top_k_indices: List of indices of the top-k documents in the added corpus.
"""

def clear(self) -> None:
"""Clear all stored documents and embeddings from the backend."""