-
Notifications
You must be signed in to change notification settings - Fork 504
feat: add search encoder backend #3492
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
c8b2bd3
8a3527f
b2c3f60
51111ca
ae31d1b
2ce10fd
74458c5
05b0ba8
48143c0
7fbc60f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| from .default_backend_search import DefaultEncoderSearchBackend | ||
| from .faiss_search_backend import FaissEncoderSearchBackend | ||
| from .search_backend_protocol import IndexEncoderSearchProtocol | ||
|
|
||
| __all__ = [ | ||
| "DefaultEncoderSearchBackend", | ||
| "FaissEncoderSearchBackend", | ||
| "IndexEncoderSearchProtocol", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,99 @@ | ||
| import logging | ||
| from collections.abc import Callable | ||
|
|
||
| import torch | ||
|
|
||
| from mteb.types import Array, TopRankedDocumentsType | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class DefaultEncoderSearchBackend: | ||
| """Streaming backend for encoder-based search. | ||
| - Does not store the entire corpus in memory. | ||
| - Encodes and searches corpus in chunks. | ||
| """ | ||
|
|
||
| sub_corpus_embeddings: Array | None = None | ||
| idxs: list[str] | ||
|
|
||
| def add_document( | ||
| self, | ||
| embeddings: Array, | ||
| idxs: list[str], | ||
| ) -> None: | ||
| """Add all document embeddings and their IDs to the backend.""" | ||
| self.sub_corpus_embeddings = embeddings | ||
| self.idxs = idxs | ||
|
|
||
| def search( | ||
| self, | ||
| embeddings: Array, | ||
| top_k: int, | ||
| similarity_fn: Callable[[Array, Array], Array], | ||
| top_ranked: TopRankedDocumentsType | None = None, | ||
| query_idx_to_id: dict[int, str] | None = None, | ||
| ) -> tuple[list[list[float]], list[list[int]]]: | ||
| """Search through added corpus embeddings or rerank top-ranked documents.""" | ||
| if self.sub_corpus_embeddings is None: | ||
| raise ValueError("No corpus embeddings found. Did you call add_document()?") | ||
|
|
||
| if top_ranked is not None: | ||
| if query_idx_to_id is None: | ||
| raise ValueError("query_idx_to_id is required when using top_ranked.") | ||
|
|
||
| scores_all: list[list[float]] = [] | ||
| idxs_all: list[list[int]] = [] | ||
|
|
||
| doc_id_to_idx = {doc_id: i for i, doc_id in enumerate(self.idxs)} | ||
|
|
||
| for query_idx, query_emb in enumerate(embeddings): | ||
| query_id = query_idx_to_id[query_idx] | ||
| ranked_ids = top_ranked.get(query_id) | ||
| if not ranked_ids: | ||
| logger.warning(f"No top-ranked docs for query {query_id}") | ||
| scores_all.append([]) | ||
| idxs_all.append([]) | ||
| continue | ||
|
|
||
| candidate_idx = [doc_id_to_idx[doc_id] for doc_id in ranked_ids] | ||
| candidate_embs = self.sub_corpus_embeddings[candidate_idx] | ||
|
|
||
| scores = similarity_fn( | ||
| torch.as_tensor(query_emb).unsqueeze(0), | ||
| torch.as_tensor(candidate_embs), | ||
| ) | ||
|
|
||
| values, indices = torch.topk( | ||
| torch.as_tensor(scores), | ||
| k=min(top_k, len(candidate_idx)), | ||
| dim=1, | ||
| largest=True, | ||
| ) | ||
| scores_all.append(values.squeeze(0).cpu().tolist()) | ||
| idxs_all.append(indices.squeeze(0).cpu().tolist()) | ||
|
|
||
| return scores_all, idxs_all | ||
|
|
||
| scores = similarity_fn(embeddings, self.sub_corpus_embeddings) | ||
| self.sub_corpus_embeddings = None | ||
|
|
||
| cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk( | ||
| torch.tensor(scores), | ||
| min( | ||
| top_k + 1, | ||
| len(scores[1]) if len(scores) > 1 else len(scores[-1]), | ||
| ), | ||
| dim=1, | ||
| largest=True, | ||
| ) | ||
| return ( | ||
| cos_scores_top_k_values.cpu().tolist(), | ||
| cos_scores_top_k_idx.cpu().tolist(), | ||
| ) | ||
|
|
||
| def clear(self) -> None: | ||
| """Clear all stored documents and embeddings from the backend.""" | ||
| self.sub_corpus_embeddings = None | ||
| self.idxs = [] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,136 @@ | ||
| import logging | ||
| from collections.abc import Callable | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from mteb._requires_package import requires_package | ||
| from mteb.models.model_meta import ScoringFunction | ||
| from mteb.models.models_protocols import EncoderProtocol | ||
| from mteb.types import Array, TopRankedDocumentsType | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class FaissEncoderSearchBackend: | ||
KennethEnevoldsen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """FAISS-based backend for encoder-based search. | ||
|
|
||
| Supports both full-corpus retrieval and reranking (via `top_ranked`). | ||
|
|
||
| Notes: | ||
| - Stores *all* embeddings in memory (IndexFlatIP or IndexFlatL2). | ||
| - Expects embeddings to be normalized if cosine similarity is desired. | ||
| """ | ||
|
|
||
| _normalize: bool = False | ||
|
|
||
| def __init__(self, model: EncoderProtocol) -> None: | ||
| requires_package( | ||
| self, | ||
| "faiss", | ||
| "FAISS-based search", | ||
| install_instruction="pip install mteb[faiss-cpu]", | ||
| ) | ||
|
|
||
| import faiss | ||
| from faiss import IndexFlatIP, IndexFlatL2 | ||
|
|
||
| # https://github.com/facebookresearch/faiss/wiki/Faiss-indexes | ||
| if ( | ||
| model.mteb_model_meta.similarity_fn_name == "dot" | ||
| or model.mteb_model_meta.similarity_fn_name is ScoringFunction.DOT_PRODUCT | ||
| ): | ||
| self.index_type = IndexFlatL2 | ||
| else: | ||
| self.index_type = IndexFlatIP | ||
| self._normalize = True | ||
|
|
||
| self.idxs: list[str] = [] | ||
| self.index: faiss.Index | None = None | ||
|
|
||
| def add_document(self, embeddings: Array, idxs: list[str]) -> None: | ||
| """Add all document embeddings and their IDs to FAISS index.""" | ||
| import faiss | ||
|
|
||
| if isinstance(embeddings, torch.Tensor): | ||
| embeddings = embeddings.detach().cpu().numpy() | ||
|
|
||
| embeddings = embeddings.astype(np.float32) | ||
| self.idxs.extend(idxs) | ||
|
|
||
| if self._normalize: | ||
| faiss.normalize_L2(embeddings) | ||
|
|
||
| dim = embeddings.shape[1] | ||
| if self.index is None: | ||
| self.index = self.index_type(dim) | ||
|
|
||
| self.index.add(embeddings) | ||
| logger.info(f"FAISS index built with {len(idxs)} vectors of dim {dim}.") | ||
|
|
||
| def search( | ||
| self, | ||
| embeddings: Array, | ||
| top_k: int, | ||
| similarity_fn: Callable[[Array, Array], Array], | ||
| top_ranked: TopRankedDocumentsType | None = None, | ||
| query_idx_to_id: dict[int, str] | None = None, | ||
| ) -> tuple[list[list[float]], list[list[int]]]: | ||
| """Search using FAISS.""" | ||
| import faiss | ||
|
|
||
| if self.index is None: | ||
| raise ValueError("No index built. Call add_document() first.") | ||
|
|
||
| if isinstance(embeddings, torch.Tensor): | ||
| embeddings = embeddings.detach().cpu().numpy() | ||
|
|
||
| if self._normalize: | ||
| faiss.normalize_L2(embeddings) | ||
|
|
||
| if top_ranked is not None: | ||
| if query_idx_to_id is None: | ||
| raise ValueError("query_idx_to_id must be provided when reranking.") | ||
|
|
||
| doc_id_to_idx = {doc_id: i for i, doc_id in enumerate(self.idxs)} | ||
| scores_all: list[list[float]] = [] | ||
| idxs_all: list[list[int]] = [] | ||
|
|
||
| for query_idx, query_emb in enumerate(embeddings): | ||
| query_id = query_idx_to_id[query_idx] | ||
| ranked_ids = top_ranked.get(query_id) | ||
| if not ranked_ids: | ||
| logger.warning(f"No top-ranked documents for query {query_id}") | ||
| scores_all.append([]) | ||
| idxs_all.append([]) | ||
| continue | ||
|
|
||
| candidate_indices = [doc_id_to_idx[doc_id] for doc_id in ranked_ids] | ||
| d = self.index.d | ||
| candidate_embs = np.zeros((len(candidate_indices), d), dtype=np.float32) | ||
| for j, idx in enumerate(candidate_indices): | ||
| candidate_embs[j] = self.index.reconstruct(idx) | ||
|
|
||
| scores = similarity_fn( | ||
| torch.as_tensor(query_emb).unsqueeze(0), | ||
| torch.as_tensor(candidate_embs), | ||
| ) | ||
|
|
||
| values, indices = torch.topk( | ||
| torch.as_tensor(scores), | ||
| k=min(top_k, len(candidate_indices)), | ||
| dim=1, | ||
| largest=True, | ||
| ) | ||
| scores_all.append(values.squeeze(0).cpu().tolist()) | ||
| idxs_all.append(indices.squeeze(0).cpu().tolist()) | ||
|
|
||
| return scores_all, idxs_all | ||
|
|
||
| documents, ids = self.index.search(embeddings.astype(np.float32), top_k) | ||
| return documents.tolist(), ids.tolist() | ||
|
|
||
| def clear(self) -> None: | ||
| """Clear all stored documents and embeddings from the backend.""" | ||
| self.index = None | ||
| self.idxs = [] | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,50 @@ | ||||||
| from collections.abc import Callable | ||||||
| from typing import Protocol | ||||||
|
|
||||||
| from mteb.types import Array, TopRankedDocumentsType | ||||||
|
|
||||||
|
|
||||||
| class IndexEncoderSearchProtocol(Protocol): | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can specify that this is only for index and only for encoder, because this can be confused that |
||||||
| """Protocol for search backends used in encoder-based retrieval.""" | ||||||
|
|
||||||
| def add_document( | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| self, | ||||||
| embeddings: Array, | ||||||
| idxs: list[str], | ||||||
| ) -> None: | ||||||
| """Add documents to the search backend. | ||||||
| Args: | ||||||
| embeddings: Embeddings of the documents to add. | ||||||
| idxs: IDs of the documents to add. | ||||||
| """ | ||||||
|
|
||||||
| def search( | ||||||
| self, | ||||||
| embeddings: Array, | ||||||
| top_k: int, | ||||||
| similarity_fn: Callable[[Array, Array], Array], | ||||||
| top_ranked: TopRankedDocumentsType | None = None, | ||||||
| query_idx_to_id: dict[int, str] | None = None, | ||||||
| ) -> tuple[list[list[float]], list[list[int]]]: | ||||||
| """Search through added corpus embeddings or rerank top-ranked documents. | ||||||
| Supports both full-corpus and reranking search modes: | ||||||
| - Full-corpus mode: `top_ranked=None`, uses added corpus embeddings. | ||||||
| - Reranking mode: `top_ranked` contains mapping {query_id: [doc_ids]}. | ||||||
| Args: | ||||||
| embeddings: Query embeddings, shape (num_queries, dim). | ||||||
| top_k: Number of top results to return. | ||||||
| similarity_fn: Function to compute similarity between query and corpus. | ||||||
| top_ranked: Mapping of query_id -> list of candidate doc_ids. Used for reranking. | ||||||
| query_idx_to_id: Mapping of query index -> query_id. Used for reranking. | ||||||
| Returns: | ||||||
| A tuple (top_k_values, top_k_indices), for each query: | ||||||
| - top_k_values: List of top-k similarity scores. | ||||||
| - top_k_indices: List of indices of the top-k documents in the added corpus. | ||||||
| """ | ||||||
|
|
||||||
| def clear(self) -> None: | ||||||
| """Clear all stored documents and embeddings from the backend.""" | ||||||
Uh oh!
There was an error while loading. Please reload this page.