|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import logging |
| 4 | +from typing import TYPE_CHECKING |
| 5 | + |
| 6 | +import numpy as np |
| 7 | + |
| 8 | +from mteb._create_dataloaders import _create_text_queries_dataloader |
| 9 | +from mteb._requires_package import requires_package |
| 10 | +from mteb.models.model_meta import ModelMeta |
| 11 | + |
| 12 | +if TYPE_CHECKING: |
| 13 | + from mteb.abstasks.task_metadata import TaskMetadata |
| 14 | + from mteb.models.models_protocols import SearchProtocol |
| 15 | + from mteb.types import ( |
| 16 | + CorpusDatasetType, |
| 17 | + EncodeKwargs, |
| 18 | + QueryDatasetType, |
| 19 | + RetrievalOutputType, |
| 20 | + TopRankedDocumentsType, |
| 21 | + ) |
| 22 | + |
| 23 | +logger = logging.getLogger(__name__) |
| 24 | + |
| 25 | + |
| 26 | +def _composite_prior( |
| 27 | + query_tfs: np.ndarray, |
| 28 | + doc_lengths: np.ndarray, |
| 29 | + avg_dl: float, |
| 30 | +) -> np.ndarray: |
| 31 | +    """Composite prior from Section 4.2 of the Bayesian BM25 paper. |
| 32 | + |
| 33 | + Combines term frequency prior (Def 4.2.1) and field norm prior |
| 34 | + (Def 4.2.2) to produce document-level prior probabilities. |
| 35 | + |
| 36 | + Args: |
| 37 | + query_tfs: Total query term frequency per candidate document. |
| 38 | + doc_lengths: Token count per candidate document. |
| 39 | + avg_dl: Average document length across the corpus. |
| 40 | + |
| 41 | + Returns: |
| 42 | + Prior probabilities clipped to [0.1, 0.9]. |
| 43 | + """ |
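|  | +    # Illustrative worked example (assumed values, not from the paper): |
|  | +    # for a candidate with query term frequency 5, document length 100 |
|  | +    # and avg_dl 100: |
|  | +    #   p_tf   = 0.2 + 0.7 * min(1, 5 / 10)                 = 0.55 |
|  | +    #   p_norm = 0.3 + 0.6 * (1 - min(1, abs(1 - 1) * 0.5)) = 0.90 |
|  | +    #   prior  = 0.7 * 0.55 + 0.3 * 0.90                    = 0.655 |
|  | +    # which already lies inside the [0.1, 0.9] clipping range. |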
| 44 | + p_tf = 0.2 + 0.7 * np.minimum(1.0, query_tfs / 10.0) |
| 45 | + |
| 46 | + if avg_dl > 0: |
| 47 | + norm_ratio = doc_lengths / avg_dl |
| 48 | + else: |
| 49 | + norm_ratio = np.ones_like(doc_lengths, dtype=np.float64) |
| 50 | + p_norm = 0.3 + 0.6 * (1.0 - np.minimum(1.0, np.abs(norm_ratio - 1.0) * 0.5)) |
| 51 | + |
| 52 | + prior = 0.7 * p_tf + 0.3 * p_norm |
| 53 | + return np.clip(prior, 0.1, 0.9) |
| 54 | + |
| 55 | + |
| 56 | +def bb25_loader(model_name: str, **kwargs) -> SearchProtocol: |
| 57 | + requires_package(bb25_loader, "bm25s", model_name, "pip install mteb[bm25s]") |
| 58 | + import bm25s |
| 59 | + import Stemmer |
| 60 | + |
| 61 | + class BB25Search: |
| 62 | + """Bayesian BM25 search using bm25s as the BM25 backend. |
| 63 | + |
| 64 | + Bayesian BM25 transforms traditional BM25 scores into calibrated |
| 65 | + probability estimates in [0, 1] through Bayesian inference with a |
| 66 | + sigmoid likelihood model and composite prior design. |
| 67 | + |
| 68 | + With the default prior_weight=0.0, the prior is a flat 0.5 for |
| 69 | + all documents. Since the sigmoid likelihood is strictly monotonic, |
| 70 | + this preserves BM25 rankings exactly while outputting calibrated |
| 71 | + probabilities suitable for hybrid search score fusion. |
| 72 | + |
| 73 | + Setting prior_weight > 0 enables the Composite Prior (Section 4.2 |
| 74 | + of the paper), which re-adjusts rankings based on document-level |
| 75 | + evidence from query term frequency and document length signals. |
| 76 | + |
| 77 | + Architecture: |
| 78 | + 1. bm25s handles fast BM25 indexing and top-k retrieval. |
| 79 | + 2. Retrieved candidates are re-scored with Bayesian posterior: |
| 80 | + - Dynamic beta = median(BM25 scores) per query |
| 81 | + - Sigmoid likelihood: sigma(alpha * (score - beta)) |
| 82 | + - prior = 0.5 + prior_weight * (composite_prior - 0.5) |
| 83 | + - Posterior via Bayes' rule |
| 84 | + """ |
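|  | +        # Illustrative numbers (assumed, not from the paper): if the |
|  | +        # sigmoid likelihood for a document is 0.8 and its prior is 0.6, |
|  | +        # Bayes' rule gives |
|  | +        #   posterior = 0.8 * 0.6 / (0.8 * 0.6 + 0.2 * 0.4) ~= 0.857. |
|  | +        # With the default flat prior of 0.5 the posterior equals the |
|  | +        # likelihood, so the BM25 ordering is preserved exactly. |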
| 85 | + |
| 86 | + retriever: bm25s.BM25 |
| 87 | + corpus_idx_to_id: dict[int, str] |
| 88 | + |
| 89 | + def __init__( |
| 90 | + self, |
| 91 | + previous_results: str | None = None, |
| 92 | + stopwords: str = "en", |
| 93 | + stemmer_language: str | None = "english", |
| 94 | + k1: float = 1.5, |
| 95 | + b: float = 0.75, |
| 96 | + alpha: float = 1.0, |
| 97 | + prior_weight: float = 0.0, |
| 98 | + **kwargs, |
| 99 | + ): |
| 100 | + self.k1 = k1 |
| 101 | + self.b = b |
| 102 | + self.alpha = alpha |
| 103 | + self.prior_weight = prior_weight |
| 104 | + self.stopwords = stopwords |
| 105 | + self.stemmer = ( |
| 106 | + Stemmer.Stemmer(stemmer_language) if stemmer_language else None |
| 107 | + ) |
| 108 | + |
| 109 | + def _encode(self, texts: list[str]): |
| 110 | + """Tokenize texts using bm25s. Not to be confused with EncoderProtocol.encode().""" |
| 111 | + return bm25s.tokenize(texts, stopwords=self.stopwords, stemmer=self.stemmer) |
| 112 | + |
| 113 | + def index( |
| 114 | + self, |
| 115 | + corpus: CorpusDatasetType, |
| 116 | + *, |
| 117 | + task_metadata: TaskMetadata, |
| 118 | + hf_split: str, |
| 119 | + hf_subset: str, |
| 120 | + encode_kwargs: EncodeKwargs, |
| 121 | + num_proc: int | None = None, |
| 122 | + ) -> None: |
| 123 | + logger.info("Encoding Corpus...") |
| 124 | + corpus_texts = [ |
| 125 | + "\n".join([doc.get("title", ""), doc["text"]]) for doc in corpus |
| 126 | + ] |
| 127 | + encoded_corpus = self._encode(corpus_texts) |
| 128 | + |
| 129 | + logger.info( |
| 130 | + f"Indexing Corpus... {len(encoded_corpus.ids):,} documents, " |
| 131 | + f"{len(encoded_corpus.vocab):,} vocab" |
| 132 | + ) |
| 133 | + |
| 134 | + self.retriever = bm25s.BM25(k1=self.k1, b=self.b) |
| 135 | + self.retriever.index(encoded_corpus) |
| 136 | + self.corpus_idx_to_id = {i: row["id"] for i, row in enumerate(corpus)} |
| 137 | + |
| 138 | + if self.prior_weight > 0: |
| 139 | + # Per-document token IDs for composite prior computation. |
| 140 | + # Stored as compact numpy int32 arrays to minimize memory. |
| 141 | + self.corpus_token_ids = [ |
| 142 | + np.array(doc_ids, dtype=np.int32) for doc_ids in encoded_corpus.ids |
| 143 | + ] |
| 144 | + self.corpus_vocab = dict(encoded_corpus.vocab) |
| 145 | + self.doc_lengths = np.array( |
| 146 | + [len(ids) for ids in encoded_corpus.ids], dtype=np.float64 |
| 147 | + ) |
| 148 | + self.avg_dl = ( |
| 149 | + float(self.doc_lengths.mean()) if len(self.doc_lengths) > 0 else 0.0 |
| 150 | + ) |
| 151 | + |
| 152 | + logger.info(f"Indexed {len(self.corpus_idx_to_id):,} documents") |
| 153 | + |
| 154 | + def search( |
| 155 | + self, |
| 156 | + queries: QueryDatasetType, |
| 157 | + *, |
| 158 | + task_metadata: TaskMetadata, |
| 159 | + hf_split: str, |
| 160 | + hf_subset: str, |
| 161 | + top_k: int, |
| 162 | + encode_kwargs: EncodeKwargs, |
| 163 | + top_ranked: TopRankedDocumentsType | None = None, |
| 164 | + num_proc: int | None = None, |
| 165 | + ) -> RetrievalOutputType: |
| 166 | + logger.info("Encoding Queries...") |
| 167 | + query_ids = list(queries["id"]) |
| 168 | + results: RetrievalOutputType = {qid: {} for qid in query_ids} |
| 169 | + queries_loader = _create_text_queries_dataloader(queries) |
| 170 | + queries_texts = [text for batch in queries_loader for text in batch["text"]] |
| 171 | + |
| 172 | + query_tokenized = self._encode(queries_texts) |
| 173 | + |
| 174 | + logger.info(f"Retrieving Results... {len(queries):,} queries") |
| 175 | + |
| 176 | + queries_results, queries_scores = self.retriever.retrieve( |
| 177 | + query_tokenized, |
| 178 | + k=min(top_k, len(self.corpus_idx_to_id)), |
| 179 | + ) |
| 180 | + |
| 181 | + use_prior = self.prior_weight > 0 |
| 182 | + if use_prior: |
| 183 | + query_id_to_str = {v: k for k, v in query_tokenized.vocab.items()} |
| 184 | + |
| 185 | + for qi, qid in enumerate(query_ids): |
| 186 | + doc_indices = queries_results[qi] |
| 187 | + bm25_scores = queries_scores[qi].astype(np.float64) |
| 188 | + |
| 189 | + query_documents = ( |
| 190 | + top_ranked[qid] if top_ranked and qid in top_ranked else None |
| 191 | + ) |
| 192 | + |
| 193 | + doc_id_to_score: dict[str, float] = {} |
| 194 | + |
| 195 | + # Separate positive-score candidates for Bayesian re-scoring |
| 196 | + positive_mask = bm25_scores > 0 |
| 197 | + positive_indices = np.where(positive_mask)[0] |
| 198 | + |
| 199 | + if len(positive_indices) > 0: |
| 200 | + cand_doc_indices = doc_indices[positive_indices] |
| 201 | + cand_bm25_scores = bm25_scores[positive_indices] |
| 202 | + |
| 203 | + # Dynamic beta = median of BM25 scores for this query |
| 204 | + beta = float(np.median(cand_bm25_scores)) |
| 205 | + |
| 206 | + # Dynamic alpha scaling for query-level score distribution |
| 207 | + # invariance. The paper defines alpha as sigmoid steepness; |
| 208 | + # dividing by std(scores) keeps the effective steepness |
| 209 | + # consistent across queries whose BM25 ranges vary widely, |
| 210 | + # preventing sigmoid saturation on high-scoring queries. |
| 211 | + score_std = float(np.std(cand_bm25_scores)) |
| 212 | + alpha_eff = ( |
| 213 | + self.alpha / score_std if score_std > 1e-10 else self.alpha |
| 214 | + ) |
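|  | +                    # Illustrative (assumed scores): candidate BM25 scores |
|  | +                    # of [12.0, 8.0, 5.0] give std ~= 2.87 and beta = 8.0, |
|  | +                    # so with alpha=1.0 the effective steepness is ~0.35. |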
| 215 | + |
| 216 | + # Sigmoid likelihood (monotonic -- preserves BM25 ranking) |
| 217 | + x = np.clip(alpha_eff * (cand_bm25_scores - beta), -500, 500) |
| 218 | + likelihood = 1.0 / (1.0 + np.exp(-x)) |
| 219 | + |
| 220 | + if use_prior: |
| 221 | + # Map query token IDs to corpus vocab IDs |
| 222 | + query_term_strs = [ |
| 223 | + query_id_to_str[tid] for tid in query_tokenized.ids[qi] |
| 224 | + ] |
| 225 | + query_corpus_ids = np.array( |
| 226 | + [ |
| 227 | + self.corpus_vocab[t] |
| 228 | + for t in query_term_strs |
| 229 | + if t in self.corpus_vocab |
| 230 | + ], |
| 231 | + dtype=np.int32, |
| 232 | + ) |
| 233 | + |
| 234 | + n_cand = len(cand_doc_indices) |
| 235 | + query_tfs = np.zeros(n_cand, dtype=np.float64) |
| 236 | + cand_doc_lengths = np.zeros(n_cand, dtype=np.float64) |
| 237 | + |
| 238 | + has_query_terms = len(query_corpus_ids) > 0 |
| 239 | + for ci in range(n_cand): |
| 240 | + doc_idx = cand_doc_indices[ci] |
| 241 | + cand_doc_lengths[ci] = self.doc_lengths[doc_idx] |
| 242 | + if has_query_terms: |
| 243 | + doc_ids = self.corpus_token_ids[doc_idx] |
| 244 | + query_tfs[ci] = np.isin(doc_ids, query_corpus_ids).sum() |
| 245 | + |
| 246 | + composite = _composite_prior( |
| 247 | + query_tfs, cand_doc_lengths, self.avg_dl |
| 248 | + ) |
| 249 | + # Interpolate between flat prior (0.5) and composite |
| 250 | + prior = 0.5 + self.prior_weight * (composite - 0.5) |
| 251 | + posterior = (likelihood * prior) / ( |
| 252 | + likelihood * prior + (1.0 - likelihood) * (1.0 - prior) |
| 253 | + ) |
| 254 | + else: |
| 255 | + # prior=0.5 => posterior = likelihood |
| 256 | + posterior = likelihood |
| 257 | + |
| 258 | + for ci in range(len(cand_doc_indices)): |
| 259 | + doc_id = self.corpus_idx_to_id[cand_doc_indices[ci]] |
| 260 | + if query_documents is None or doc_id in query_documents: |
| 261 | + doc_id_to_score[doc_id] = float(posterior[ci]) |
| 262 | + |
| 263 | + # Include zero-score documents with score 0.0 |
| 264 | + for vi in np.where(~positive_mask)[0]: |
| 265 | + doc_id = self.corpus_idx_to_id[doc_indices[vi]] |
| 266 | + if query_documents is None or doc_id in query_documents: |
| 267 | + doc_id_to_score[doc_id] = 0.0 |
| 268 | + |
| 269 | + results[qid] = doc_id_to_score |
| 270 | + |
| 271 | + return results |
| 272 | + |
| 273 | + return BB25Search(**kwargs) |
| 274 | + |
| 275 | + |
| 276 | +bb25_model = ModelMeta( |
| 277 | + loader=bb25_loader, |
| 278 | + name="baseline/bb25", |
| 279 | + model_type=["sparse"], |
| 280 | + languages=None, |
| 281 | + open_weights=True, |
| 282 | + revision="0_1_1", |
| 283 | + release_date="2026-02-06", |
| 284 | + n_parameters=None, |
| 285 | + n_embedding_parameters=None, |
| 286 | + memory_usage_mb=None, |
| 287 | + embed_dim=None, |
| 288 | + license=None, |
| 289 | + max_tokens=None, |
| 290 | + reference="https://github.com/instructkr/bb25", |
| 291 | + similarity_fn_name=None, |
| 292 | + framework=[], |
| 293 | + use_instructions=False, |
| 294 | + public_training_code="https://github.com/instructkr/bb25", |
| 295 | + public_training_data=None, |
| 296 | + training_datasets=None, |
| 297 | + citation="""@software{jeong2026bayesianbm25, |
| 298 | + title={Bayesian BM25: A Probabilistic Framework for Hybrid Text and Vector Search}, |
| 299 | + author={Jeong, Jaepil}, |
| 300 | + year={2026}, |
| 301 | + doi={10.5281/zenodo.18414941}, |
| 302 | + url={https://doi.org/10.5281/zenodo.18414941}, |
| 303 | +} |
| 304 | +@software{jeong2026neural, |
| 305 | + title={From Bayesian Inference to Neural Computation: The Analytical Emergence of Neural Network Structure from Probabilistic Relevance Estimation}, |
| 306 | + author={Jeong, Jaepil}, |
| 307 | + year={2026}, |
| 308 | + doi={10.5281/zenodo.18512411}, |
| 309 | + url={https://doi.org/10.5281/zenodo.18512411}, |
| 310 | +}""", |
| 311 | +) |
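|  | + |
|  | +# Minimal usage sketch (assumed: mteb's get_model / get_tasks / MTEB API, |
|  | +# and that get_model forwards extra kwargs such as prior_weight to the |
|  | +# loader; names may differ between mteb versions): |
|  | +# |
|  | +#     import mteb |
|  | +# |
|  | +#     model = mteb.get_model("baseline/bb25", prior_weight=0.0) |
|  | +#     tasks = mteb.get_tasks(tasks=["NFCorpus"]) |
|  | +#     results = mteb.MTEB(tasks=tasks).run(model) |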