diff --git a/mteb/models/model_implementations/random_baseline.py b/mteb/models/model_implementations/random_baseline.py
index f8bac508e2..562b54914c 100644
--- a/mteb/models/model_implementations/random_baseline.py
+++ b/mteb/models/model_implementations/random_baseline.py
@@ -8,6 +8,10 @@
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models.model_meta import ModelMeta
+from mteb.similarity_functions import (
+    select_pairwise_similarity,
+    select_similarity,
+)
 from mteb.types._encoder_io import Array, BatchedInput, PromptType
@@ -155,15 +159,9 @@ def similarity(
         Returns:
             Cosine similarity matrix between the two sets of embeddings
         """
-        norm1 = np.linalg.norm(
-            embeddings1.reshape(-1, self.embedding_dim), axis=1, keepdims=True
-        )
-        norm2 = np.linalg.norm(
-            embeddings2.reshape(-1, self.embedding_dim), axis=1, keepdims=True
+        return select_similarity(
+            embeddings1, embeddings2, self.mteb_model_meta.similarity_fn_name
         )
-        normalized1 = embeddings1 / (norm1 + 1e-10)
-        normalized2 = embeddings2 / (norm2 + 1e-10)
-        return np.dot(normalized1, normalized2.T)
 
     def similarity_pairwise(
         self,
@@ -179,17 +177,9 @@
         Returns:
             Cosine similarity for each pair of embeddings
         """
-        norm1 = np.linalg.norm(
-            embeddings1.reshape(-1, self.embedding_dim), axis=1, keepdims=True
-        )
-        norm2 = np.linalg.norm(
-            embeddings2.reshape(-1, self.embedding_dim), axis=1, keepdims=True
+        return select_pairwise_similarity(
+            embeddings1, embeddings2, self.mteb_model_meta.similarity_fn_name
         )
-        normalized1 = embeddings1 / (norm1 + 1e-10)
-        normalized2 = embeddings2 / (norm2 + 1e-10)
-        normalized1 = np.asarray(normalized1)
-        normalized2 = np.asarray(normalized2)
-        return np.sum(normalized1 * normalized2, axis=1)
 
 
 random_encoder_baseline = ModelMeta(
diff --git a/mteb/models/search_encoder_index/__init__.py b/mteb/models/search_encoder_index/__init__.py
new file mode 100644
index 0000000000..4b4b67613e
--- /dev/null
+++ b/mteb/models/search_encoder_index/__init__.py
@@ -0,0 +1,8 @@
+from .search_backend_protocol import IndexEncoderSearchProtocol
+from .search_indexes import FaissSearchIndex, StreamingSearchIndex
+
+__all__ = [
+    "FaissSearchIndex",
+    "IndexEncoderSearchProtocol",
+    "StreamingSearchIndex",
+]
diff --git a/mteb/models/search_encoder_index/search_backend_protocol.py b/mteb/models/search_encoder_index/search_backend_protocol.py
new file mode 100644
index 0000000000..a91ef5cd92
--- /dev/null
+++ b/mteb/models/search_encoder_index/search_backend_protocol.py
@@ -0,0 +1,50 @@
+from collections.abc import Callable
+from typing import Protocol
+
+from mteb.types import Array, TopRankedDocumentsType
+
+
+class IndexEncoderSearchProtocol(Protocol):
+    """Protocol for search backends used in encoder-based retrieval."""
+
+    def add_document(
+        self,
+        embeddings: Array,
+        idxs: list[str],
+    ) -> None:
+        """Add documents to the search backend.
+
+        Args:
+            embeddings: Embeddings of the documents to add.
+            idxs: IDs of the documents to add.
+        """
+
+    def search(
+        self,
+        embeddings: Array,
+        top_k: int,
+        similarity_fn: Callable[[Array, Array], Array],
+        top_ranked: TopRankedDocumentsType | None = None,
+        query_idx_to_id: dict[int, str] | None = None,
+    ) -> tuple[list[list[float]], list[list[int]]]:
+        """Search through added corpus embeddings or rerank top-ranked documents.
+
+        Supports both full-corpus and reranking search modes:
+        - Full-corpus mode: `top_ranked=None`, uses added corpus embeddings.
+        - Reranking mode: `top_ranked` contains mapping {query_id: [doc_ids]}.
+
+        Args:
+            embeddings: Query embeddings, shape (num_queries, dim).
+            top_k: Number of top results to return.
+            similarity_fn: Function to compute similarity between query and corpus.
+            top_ranked: Mapping of query_id -> list of candidate doc_ids. Used for reranking.
+            query_idx_to_id: Mapping of query index -> query_id. Used for reranking.
+
+        Returns:
+            A tuple (top_k_values, top_k_indices) containing, for each query:
+            - top_k_values: List of top-k similarity scores.
+            - top_k_indices: List of indices of the top-k documents in the added corpus.
+        """
+
+    def clear(self) -> None:
+        """Clear all stored documents and embeddings from the backend."""
diff --git a/mteb/models/search_encoder_index/search_indexes/__init__.py b/mteb/models/search_encoder_index/search_indexes/__init__.py
new file mode 100644
index 0000000000..c886bc9655
--- /dev/null
+++ b/mteb/models/search_encoder_index/search_indexes/__init__.py
@@ -0,0 +1,7 @@
+from .faiss_search_index import FaissSearchIndex
+from .streaming_search_index import StreamingSearchIndex
+
+__all__ = [
+    "FaissSearchIndex",
+    "StreamingSearchIndex",
+]
diff --git a/mteb/models/search_encoder_index/search_indexes/faiss_search_index.py b/mteb/models/search_encoder_index/search_indexes/faiss_search_index.py
new file mode 100644
index 0000000000..c3552ccfa6
--- /dev/null
+++ b/mteb/models/search_encoder_index/search_indexes/faiss_search_index.py
@@ -0,0 +1,157 @@
+import logging
+from collections.abc import Callable
+
+import numpy as np
+import torch
+
+from mteb._requires_package import requires_package
+from mteb.models.model_meta import ScoringFunction
+from mteb.models.models_protocols import EncoderProtocol
+from mteb.types import Array, TopRankedDocumentsType
+
+logger = logging.getLogger(__name__)
+
+
+class FaissSearchIndex:
+    """FAISS-based backend for encoder-based search.
+
+    Supports both full-corpus retrieval and reranking (via `top_ranked`).
+
+    Notes:
+        - Stores *all* embeddings in memory (IndexFlatIP or IndexFlatL2).
+        - For cosine similarity, embeddings are L2-normalized internally
+          before they are indexed and searched.
+    """
+
+    _normalize: bool = False
+
+    def __init__(self, model: EncoderProtocol) -> None:
+        requires_package(
+            self,
+            "faiss",
+            "FAISS-based search",
+            install_instruction="pip install mteb[faiss-cpu]",
+        )
+
+        import faiss
+        from faiss import IndexFlatIP, IndexFlatL2
+
+        # https://github.com/facebookresearch/faiss/wiki/Faiss-indexes
+        if model.mteb_model_meta.similarity_fn_name is ScoringFunction.DOT_PRODUCT:
+            self.index_type = IndexFlatIP
+        elif model.mteb_model_meta.similarity_fn_name is ScoringFunction.COSINE:
+            self.index_type = IndexFlatIP
+            self._normalize = True
+        elif model.mteb_model_meta.similarity_fn_name is ScoringFunction.EUCLIDEAN:
+            self.index_type = IndexFlatL2
+        else:
+            raise ValueError(
+                f"FAISS backend does not support similarity function {model.mteb_model_meta.similarity_fn_name}. "
+                f"Available: {ScoringFunction.DOT_PRODUCT}, {ScoringFunction.COSINE}, {ScoringFunction.EUCLIDEAN}."
+ ) + + self.idxs: list[str] = [] + self.index: faiss.Index | None = None + + def add_document(self, embeddings: Array, idxs: list[str]) -> None: + """Add all document embeddings and their IDs to FAISS index.""" + import faiss + + if isinstance(embeddings, torch.Tensor): + embeddings = embeddings.detach().cpu().numpy() + + embeddings = embeddings.astype(np.float32) + self.idxs.extend(idxs) + + if self._normalize: + faiss.normalize_L2(embeddings) + + dim = embeddings.shape[1] + if self.index is None: + self.index = self.index_type(dim) + + self.index.add(embeddings) + logger.info(f"FAISS index built with {len(idxs)} vectors of dim {dim}.") + + def search( + self, + embeddings: Array, + top_k: int, + similarity_fn: Callable[[Array, Array], Array], + top_ranked: TopRankedDocumentsType | None = None, + query_idx_to_id: dict[int, str] | None = None, + ) -> tuple[list[list[float]], list[list[int]]]: + """Search using FAISS.""" + import faiss + + if self.index is None: + raise ValueError("No index built. Call add_document() first.") + + if isinstance(embeddings, torch.Tensor): + embeddings = embeddings.detach().cpu().numpy() + + if self._normalize: + faiss.normalize_L2(embeddings) + + if top_ranked is not None: + if query_idx_to_id is None: + raise ValueError("query_idx_to_id must be provided when reranking.") + + similarities, ids = self._reranking( + embeddings, + top_k, + top_ranked=top_ranked, + query_idx_to_id=query_idx_to_id, + ) + else: + similarities, ids = self.index.search(embeddings.astype(np.float32), top_k) + similarities = similarities.tolist() + ids = ids.tolist() + + if issubclass(self.index_type, faiss.IndexFlatL2): + similarities = -np.sqrt(np.maximum(similarities, 0)) + + return similarities, ids + + def _reranking( + self, + embeddings: Array, + top_k: int, + top_ranked: TopRankedDocumentsType | None = None, + query_idx_to_id: dict[int, str] | None = None, + ) -> tuple[list[list[float]], list[list[int]]]: + doc_id_to_idx = {doc_id: i for i, doc_id in enumerate(self.idxs)} + scores_all: list[list[float]] = [] + idxs_all: list[list[int]] = [] + + for query_idx, query_emb in enumerate(embeddings): + query_id = query_idx_to_id[query_idx] + ranked_ids = top_ranked.get(query_id) + if not ranked_ids: + logger.warning(f"No top-ranked documents for query {query_id}") + scores_all.append([]) + idxs_all.append([]) + continue + + candidate_indices = [doc_id_to_idx[doc_id] for doc_id in ranked_ids] + d = self.index.d + candidate_embs = np.vstack( + [self.index.reconstruct(idx) for idx in candidate_indices] + ) + sub_reranking_index = self.index_type(d) + sub_reranking_index.add(candidate_embs) + + # Search returns scores and indices in one call + scores, local_indices = sub_reranking_index.search( + query_emb.reshape(1, -1).astype(np.float32), + min(top_k, len(candidate_indices)), + ) + # faiss will output 2d arrays even for single query + scores_all.append(scores[0].tolist()) + idxs_all.append(local_indices[0].tolist()) + + return scores_all, idxs_all + + def clear(self) -> None: + """Clear all stored documents and embeddings from the backend.""" + self.index = None + self.idxs = [] diff --git a/mteb/models/search_encoder_index/search_indexes/streaming_search_index.py b/mteb/models/search_encoder_index/search_indexes/streaming_search_index.py new file mode 100644 index 0000000000..5a2c1f3b05 --- /dev/null +++ b/mteb/models/search_encoder_index/search_indexes/streaming_search_index.py @@ -0,0 +1,99 @@ +import logging +from collections.abc import Callable + +import torch + +from mteb.types 
import Array, TopRankedDocumentsType + +logger = logging.getLogger(__name__) + + +class StreamingSearchIndex: + """Streaming backend for encoder-based search. + + - Does not store the entire corpus in memory. + - Encodes and searches corpus in chunks. + """ + + sub_corpus_embeddings: Array | None = None + idxs: list[str] + + def add_document( + self, + embeddings: Array, + idxs: list[str], + ) -> None: + """Add all document embeddings and their IDs to the backend.""" + self.sub_corpus_embeddings = embeddings + self.idxs = idxs + + def search( + self, + embeddings: Array, + top_k: int, + similarity_fn: Callable[[Array, Array], Array], + top_ranked: TopRankedDocumentsType | None = None, + query_idx_to_id: dict[int, str] | None = None, + ) -> tuple[list[list[float]], list[list[int]]]: + """Search through added corpus embeddings or rerank top-ranked documents.""" + if self.sub_corpus_embeddings is None: + raise ValueError("No corpus embeddings found. Did you call add_document()?") + + if top_ranked is not None: + if query_idx_to_id is None: + raise ValueError("query_idx_to_id is required when using top_ranked.") + + scores_all: list[list[float]] = [] + idxs_all: list[list[int]] = [] + + doc_id_to_idx = {doc_id: i for i, doc_id in enumerate(self.idxs)} + + for query_idx, query_emb in enumerate(embeddings): + query_id = query_idx_to_id[query_idx] + ranked_ids = top_ranked.get(query_id) + if not ranked_ids: + logger.warning(f"No top-ranked docs for query {query_id}") + scores_all.append([]) + idxs_all.append([]) + continue + + candidate_idx = [doc_id_to_idx[doc_id] for doc_id in ranked_ids] + candidate_embs = self.sub_corpus_embeddings[candidate_idx] + + scores = similarity_fn( + torch.as_tensor(query_emb).unsqueeze(0), + torch.as_tensor(candidate_embs), + ) + + values, indices = torch.topk( + torch.as_tensor(scores), + k=min(top_k, len(candidate_idx)), + dim=1, + largest=True, + ) + scores_all.append(values.squeeze(0).cpu().tolist()) + idxs_all.append(indices.squeeze(0).cpu().tolist()) + + return scores_all, idxs_all + + scores = similarity_fn(embeddings, self.sub_corpus_embeddings) + self.sub_corpus_embeddings = None + + cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk( + torch.tensor(scores), + min( + top_k + 1, + len(scores[1]) if len(scores) > 1 else len(scores[-1]), + ), + dim=1, + largest=True, + ) + return ( + cos_scores_top_k_values.cpu().tolist(), + cos_scores_top_k_idx.cpu().tolist(), + ) + + def clear(self) -> None: + """Clear all stored documents and embeddings from the backend.""" + self.sub_corpus_embeddings = None + self.idxs = [] diff --git a/mteb/models/search_wrappers.py b/mteb/models/search_wrappers.py index fa492420c0..a771ee4601 100644 --- a/mteb/models/search_wrappers.py +++ b/mteb/models/search_wrappers.py @@ -2,7 +2,6 @@ import logging from typing import Any -import torch from datasets import Dataset from torch.utils.data import DataLoader @@ -21,6 +20,8 @@ ) from .models_protocols import CrossEncoderProtocol, EncoderProtocol +from .search_encoder_index import StreamingSearchIndex +from .search_encoder_index.search_backend_protocol import IndexEncoderSearchProtocol logger = logging.getLogger(__name__) @@ -28,13 +29,19 @@ class SearchEncoderWrapper: """Wrapper for Encoder models to be used in search tasks.""" - corpus_chunk_size = 50_000 task_corpus: CorpusDatasetType | None - def __init__(self, model: EncoderProtocol): + def __init__( + self, + model: EncoderProtocol, + corpus_chunk_size: int = 50_000, + index_backend: IndexEncoderSearchProtocol = 
StreamingSearchIndex(), + ) -> None: self.model = model self.task_corpus = None self.mteb_model_meta = model.mteb_model_meta + self.corpus_chunk_size = corpus_chunk_size + self.index_backend = index_backend def index( self, @@ -129,6 +136,7 @@ def search( # Reset the task corpus dataloader to None to free up memory self.task_corpus = None + self.index_backend.clear() results = {qid: {} for qid in query_idx_to_id.values()} for qid in result_heaps: @@ -173,23 +181,13 @@ def _full_corpus_search( prompt_type=PromptType.document, **encode_kwargs, ) - + self.index_backend.add_document(sub_corpus_embeddings, sub_corpus_ids) # Compute similarities using either cosine-similarity or dot product logger.info("Computing Similarities...") - scores = self.model.similarity(query_embeddings, sub_corpus_embeddings) - - # get top-k values - cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk( - torch.tensor(scores), - min( - top_k + 1, - len(scores[1]) if len(scores) > 1 else len(scores[-1]), - ), - dim=1, - largest=True, + + cos_scores_top_k_values, cos_scores_top_k_idx = self.index_backend.search( + query_embeddings, top_k, self.model.similarity ) - cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist() - cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist() for query_itr in range(len(query_embeddings)): query_id = query_idx_to_id[query_itr] @@ -217,13 +215,8 @@ def _rerank_documents( hf_split: str, encode_kwargs: dict[str, Any], ) -> dict[str, list[tuple[float, str]]]: - """Rerank documents based on pre-ranked documents. - - Returns: - A dictionary mapping query IDs to a list of tuples, each containing a relevance score and a document ID. - """ + """Rerank documents using backend's search with top_ranked support.""" result_heaps = {qid: [] for qid in query_idx_to_id.values()} - doc_id_to_idx = {doc["id"]: idx for idx, doc in enumerate(self.task_corpus)} all_doc_embeddings = self.model.encode( create_dataloader( @@ -238,52 +231,24 @@ def _rerank_documents( prompt_type=PromptType.document, **encode_kwargs, ) + all_doc_ids = [doc["id"] for doc in self.task_corpus] + self.index_backend.add_document(all_doc_embeddings, all_doc_ids) + + cos_scores_top_k_values, cos_scores_top_k_idx = self.index_backend.search( + query_embeddings, + top_k, + similarity_fn=self.model.similarity, + top_ranked=top_ranked, + query_idx_to_id=query_idx_to_id, + ) - # Process each query - for query_idx, query_embedding in enumerate(query_embeddings): - query_id = query_idx_to_id[query_idx] - if query_id not in top_ranked: - logger.warning(f"No pre-ranked documents found for query {query_id}") - continue - - ranked_ids = top_ranked[query_id] - doc_indices = torch.tensor([doc_id_to_idx[doc_id] for doc_id in ranked_ids]) - query_doc_embeddings = torch.as_tensor(all_doc_embeddings[doc_indices]) - - # Ensure query embedding is on the correct device and has correct shape - query_embedding = torch.as_tensor(query_embedding).unsqueeze(0) - - scores = self.model.similarity( - query_embedding, - query_doc_embeddings, - ) - scores = torch.as_tensor(scores) - - # Handle NaN values - is_nan = torch.isnan(scores) - if is_nan.sum() > 0: - raise ValueError( - f"NaN values detected in the similarity scores: {is_nan.sum()}" - ) - - # Compute top-k scores - scores_top_k_values, scores_top_k_idx = torch.topk( - scores, - min(top_k, len(ranked_ids)), - dim=1, - largest=True, - ) - - # Move results back to CPU for heap operations - scores_top_k_values = scores_top_k_values.cpu() - scores_top_k_idx = scores_top_k_idx.cpu() - - # Build 
result heap - for doc_idx, score in zip( - scores_top_k_idx[0].tolist(), - scores_top_k_values[0].tolist(), + for query_itr, query_id in query_idx_to_id.items(): + ranked_ids = top_ranked.get(query_id, []) + for score, idx in zip( + cos_scores_top_k_values[query_itr], + cos_scores_top_k_idx[query_itr], ): - corpus_id = ranked_ids[doc_idx] + corpus_id = ranked_ids[idx] heapq.heappush(result_heaps[query_id], (score, corpus_id)) return result_heaps diff --git a/mteb/similarity_functions.py b/mteb/similarity_functions.py index b8e2f5e9ad..1624a034d1 100644 --- a/mteb/similarity_functions.py +++ b/mteb/similarity_functions.py @@ -1,6 +1,7 @@ import torch from mteb.models import EncoderProtocol +from mteb.models.model_meta import ScoringFunction from mteb.types import Array @@ -38,6 +39,54 @@ def compute_pairwise_similarity( return pairwise_cos_sim(embedding1, embedding2) +def select_similarity( + embedding1: Array, + embedding2: Array, + similarity_fn: ScoringFunction, +) -> Array: + """Compute similarity between two sets of embeddings using the specified similarity function. + + Args: + embedding1: The first set of embeddings. + embedding2: The second set of embeddings. + similarity_fn: The similarity function to use (COSINE, DOT_PRODUCT, EUCLIDEAN). + + Returns: + Array: The computed similarity scores. + """ + if similarity_fn is ScoringFunction.COSINE: + return cos_sim(embedding1, embedding2) + elif similarity_fn is ScoringFunction.DOT_PRODUCT: + return dot_score(embedding1, embedding2) + elif similarity_fn is ScoringFunction.EUCLIDEAN: + return euclidean_sim(embedding1, embedding2) + raise ValueError(f"Unsupported similarity function: {similarity_fn}") + + +def select_pairwise_similarity( + embedding1: Array, + embedding2: Array, + similarity_fn: ScoringFunction, +) -> Array: + """Compute pairwise similarity between two sets of embeddings using the specified similarity function. + + Args: + embedding1: The first set of embeddings. + embedding2: The second set of embeddings. + similarity_fn: The similarity function to use (COSINE, DOT_PRODUCT, EUCLIDEAN). + + Returns: + Array: The computed pairwise similarity scores. + """ + if similarity_fn is ScoringFunction.COSINE: + return pairwise_cos_sim(embedding1, embedding2) + elif similarity_fn is ScoringFunction.DOT_PRODUCT: + return pairwise_dot_score(embedding1, embedding2) + elif similarity_fn is ScoringFunction.EUCLIDEAN: + return pairwise_euclidean_sim(embedding1, embedding2) + raise ValueError(f"Unsupported similarity function: {similarity_fn}") + + def _normalize_embeddings(embeddings: Array) -> torch.Tensor: """Normalizes the embeddings matrix, so that each sentence embedding has unit length. 
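
Reviewer note (illustrative, not part of the patch): the new `select_similarity` / `select_pairwise_similarity` helpers dispatch on `ScoringFunction`, which is what lets the random baseline and both search backends share a single similarity code path. A minimal sketch of the expected behaviour, assuming numpy inputs are accepted wherever the `Array` type is (it covers both numpy arrays and torch tensors):

```python
import numpy as np

from mteb.models.model_meta import ScoringFunction
from mteb.similarity_functions import select_pairwise_similarity, select_similarity

queries = np.random.rand(2, 8).astype(np.float32)  # (num_queries, dim)
docs = np.random.rand(5, 8).astype(np.float32)  # (num_docs, dim)

# Dispatches to cos_sim: a full (2, 5) query-document similarity matrix.
matrix = select_similarity(queries, docs, ScoringFunction.COSINE)

# The pairwise variant compares row i of the first input with row i of the
# second, so both inputs must have the same number of rows.
pairs = select_pairwise_similarity(queries, queries, ScoringFunction.DOT_PRODUCT)
```

Any other `ScoringFunction` value falls through to the trailing `raise ValueError`, mirroring the FAISS backend's constructor check.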
diff --git a/tests/test_search_index/__init__.py b/tests/test_search_index/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/test_search_index/test_search_index.py b/tests/test_search_index/test_search_index.py
new file mode 100644
index 0000000000..3a586c19f6
--- /dev/null
+++ b/tests/test_search_index/test_search_index.py
@@ -0,0 +1,76 @@
+import json
+from copy import deepcopy
+from pathlib import Path
+
+import pytest
+
+import mteb
+from mteb.abstasks import AbsTaskRetrieval
+from mteb.models import SearchEncoderWrapper
+from mteb.models.model_meta import ScoringFunction
+from mteb.models.search_encoder_index import FaissSearchIndex, StreamingSearchIndex
+from tests.mock_tasks import (
+    MockRerankingTask,
+    MockRetrievalTask,
+)
+
+
+@pytest.mark.parametrize(
+    "task",
+    [
+        MockRetrievalTask(),
+        MockRerankingTask(),
+    ],
+)
+@pytest.mark.parametrize(
+    "similarity",
+    [ScoringFunction.DOT_PRODUCT, ScoringFunction.COSINE, ScoringFunction.EUCLIDEAN],
+)
+def test_retrieval_backends(
+    task: AbsTaskRetrieval, similarity: ScoringFunction, tmp_path: Path
+):
+    """Test different retrieval backends for retrieval and reranking tasks."""
+    model = mteb.get_model("baseline/random-encoder-baseline")
+    model_meta = deepcopy(model.mteb_model_meta)
+    model_meta.similarity_fn_name = similarity
+    model.mteb_model_meta = model_meta
+
+    python_backend = SearchEncoderWrapper(model, index_backend=StreamingSearchIndex())
+    faiss_backend = SearchEncoderWrapper(model, index_backend=FaissSearchIndex(model))
+
+    python_backend_predictions = tmp_path / "python_backend_predictions"
+    faiss_backend_predictions = tmp_path / "faiss_backend_predictions"
+
+    python_results = mteb.evaluate(
+        python_backend,
+        task,
+        prediction_folder=python_backend_predictions,
+        cache=None,
+    )
+    faiss_results = mteb.evaluate(
+        faiss_backend,
+        task,
+        prediction_folder=faiss_backend_predictions,
+        cache=None,
+    )
+
+    assert (
+        python_results.task_results[0].get_score()
+        == faiss_results.task_results[0].get_score()
+    )
+
+    with task._predictions_path(python_backend_predictions).open() as f:
+        full_python_predictions = json.load(f)
+    python_predictions = full_python_predictions["default"]["test"]
+
+    with task._predictions_path(faiss_backend_predictions).open() as f:
+        full_faiss_predictions = json.load(f)
+    faiss_predictions = full_faiss_predictions["default"]["test"]
+
+    for python_pred_key, faiss_pred_key in zip(
+        sorted(python_predictions.keys()), sorted(faiss_predictions.keys())
+    ):
+        assert python_pred_key == faiss_pred_key
+        assert python_predictions[python_pred_key] == pytest.approx(
+            faiss_predictions[faiss_pred_key]
+        )
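
Reviewer note (illustrative, not part of the patch): because `IndexEncoderSearchProtocol` is a structural `typing.Protocol`, a custom backend only needs matching `add_document`/`search`/`clear` methods; no inheritance is required. Below is a minimal full-corpus-only sketch; the class name `BruteForceNumpyIndex` is invented for this example, and it assumes `similarity_fn` returns something `np.asarray` can consume (e.g. a CPU torch tensor or numpy array):

```python
from __future__ import annotations

from collections.abc import Callable

import numpy as np

from mteb.types import Array, TopRankedDocumentsType


class BruteForceNumpyIndex:
    """Toy backend: keeps every added chunk in memory and scores exhaustively."""

    def __init__(self) -> None:
        self._chunks: list[np.ndarray] = []
        self.idxs: list[str] = []

    def add_document(self, embeddings: Array, idxs: list[str]) -> None:
        # Accumulate chunks; the wrapper may call this once per corpus chunk.
        self._chunks.append(np.asarray(embeddings))
        self.idxs.extend(idxs)

    def search(
        self,
        embeddings: Array,
        top_k: int,
        similarity_fn: Callable[[Array, Array], Array],
        top_ranked: TopRankedDocumentsType | None = None,
        query_idx_to_id: dict[int, str] | None = None,
    ) -> tuple[list[list[float]], list[list[int]]]:
        if top_ranked is not None:
            raise NotImplementedError("This sketch only supports full-corpus search.")
        corpus = np.concatenate(self._chunks, axis=0)
        # similarity_fn is the model's own similarity; coerce the result to
        # numpy in case it is a torch tensor.
        scores = np.asarray(similarity_fn(embeddings, corpus))
        k = min(top_k, scores.shape[1])
        top_idx = np.argsort(-scores, axis=1)[:, :k]  # descending by score
        top_val = np.take_along_axis(scores, top_idx, axis=1)
        return top_val.tolist(), top_idx.tolist()

    def clear(self) -> None:
        self._chunks = []
        self.idxs = []
```

Such a backend would plug into the wrapper exactly like the built-in ones in the test above, e.g. `SearchEncoderWrapper(model, index_backend=BruteForceNumpyIndex())`.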