
Commit a7575cc

[ENH] Add chroma_bm25 embedding function to python (#5806)
## Description of changes

_Summarize the changes made by this PR._

- Improvements & Bug fixes
  - Add the `chroma_bm25` embedding function to the Python client. It matches the Rust BM25 implementation and is 100% compatible with the JS client's `chroma_bm25` embedding function.
- New functionality
  - ...

## Test plan

_How are these changes tested?_

- Added new tests to ensure the results from this embedding function match those of the Rust and JS implementations.
- Manually created a collection in staging with this client, then called `getCollection` from JS to confirm the two clients are compatible in both directions; documents embed correctly as expected.
- [x] Tests pass locally with `pytest` for Python, `yarn test` for JS, `cargo test` for Rust

## Migration plan

_Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_

## Observability plan

_What is the plan to instrument and monitor this change?_

## Documentation Changes

_Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent 1064fac commit a7575cc
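
For orientation (this is not part of the commit), a minimal usage sketch of the new embedding function, based on how the tests below call it; the sample strings are illustrative only:

```python
from chromadb.utils.embedding_functions import ChromaBm25EmbeddingFunction

# Defaults mirror the Rust/JS implementations: k=1.2, b=0.75,
# avg_doc_length=256.0, token_max_length=40.
embedder = ChromaBm25EmbeddingFunction()

# Documents become sparse vectors: sorted hashed-token indices plus BM25 weights.
[doc_vector] = embedder(["BM25 is great for sparse retrieval tasks"])
print(doc_vector.indices, doc_vector.values)

# Queries go through embed_query, which applies the same encoding.
[query_vector] = embedder.embed_query(["retrieve BM25 docs"])
```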

File tree: 7 files changed, +574 -0 lines changed

Lines changed: 139 additions & 0 deletions
import math

import pytest

from chromadb.utils.embedding_functions.chroma_bm25_embedding_function import (
    DEFAULT_CHROMA_BM25_STOPWORDS,
    ChromaBm25EmbeddingFunction,
)


def _is_sorted(values: list[int]) -> bool:
    return all(values[i] >= values[i - 1] for i in range(1, len(values)))


def test_comprehensive_tokenization_matches_reference() -> None:
    embedder = ChromaBm25EmbeddingFunction()
    embedding = embedder(
        [
            "Usain Bolt's top speed reached ~27.8 mph (44.72 km/h)",
        ]
    )[0]

    expected_indices = [
        230246813,
        395514983,
        458027949,
        488165615,
        729632045,
        734978415,
        997512866,
        1114505193,
        1381820790,
        1501587190,
        1649421877,
        1837285388,
    ]
    expected_value = 1.6391153

    assert embedding.indices == expected_indices
    for value in embedding.values:
        assert value == pytest.approx(expected_value, abs=1e-5)


def test_matches_rust_reference_values() -> None:
    embedder = ChromaBm25EmbeddingFunction()
    embedding = embedder(
        [
            "The space-time continuum WARPS near massive objects...",
        ]
    )[0]

    expected_indices = [
        90097469,
        519064992,
        737893654,
        1110755108,
        1950894484,
        2031641008,
        2058513491,
    ]
    expected_value = 1.660867

    assert embedding.indices == expected_indices
    for value in embedding.values:
        assert value == pytest.approx(expected_value, abs=1e-5)


def test_generates_embeddings_for_multiple_documents() -> None:
    embedder = ChromaBm25EmbeddingFunction()
    texts = [
        "Usain Bolt's top speed reached ~27.8 mph (44.72 km/h)",
        "The space-time continuum WARPS near massive objects...",
        "BM25 is great for sparse retrieval tasks",
    ]

    embeddings = embedder(texts)

    assert len(embeddings) == len(texts)
    for embedding in embeddings:
        assert embedding.indices
        assert len(embedding.indices) == len(embedding.values)
        assert _is_sorted(embedding.indices)
        for value in embedding.values:
            assert value > 0
            assert math.isfinite(value)


def test_embed_query_matches_call() -> None:
    embedder = ChromaBm25EmbeddingFunction()
    query = "retrieve BM25 docs"

    query_embedding = embedder.embed_query([query])[0]
    doc_embedding = embedder([query])[0]

    assert query_embedding.indices == doc_embedding.indices
    assert query_embedding.values == doc_embedding.values


def test_config_round_trip() -> None:
    embedder = ChromaBm25EmbeddingFunction()
    config = embedder.get_config()

    assert config["k"] == pytest.approx(1.2, abs=1e-9)
    assert config["b"] == pytest.approx(0.75, abs=1e-9)
    assert config["avg_doc_length"] == pytest.approx(256.0, abs=1e-9)
    assert config["token_max_length"] == 40
    assert "stopwords" not in config

    custom_stopwords = DEFAULT_CHROMA_BM25_STOPWORDS[:10]
    rebuilt = ChromaBm25EmbeddingFunction.build_from_config(
        {
            **config,
            "stopwords": custom_stopwords,
        }
    )

    rebuilt_config = rebuilt.get_config()
    assert rebuilt_config["stopwords"] == custom_stopwords
    assert rebuilt_config["token_max_length"] == config["token_max_length"]
    assert rebuilt_config["k"] == pytest.approx(config["k"], abs=1e-9)
    assert rebuilt_config["b"] == pytest.approx(config["b"], abs=1e-9)
    assert rebuilt_config["avg_doc_length"] == pytest.approx(
        config["avg_doc_length"], abs=1e-9
    )


def test_validate_config_update_rejects_unknown_keys() -> None:
    embedder = ChromaBm25EmbeddingFunction()

    with pytest.raises(ValueError):
        embedder.validate_config_update(embedder.get_config(), {"unknown": 123})


def test_validate_config_update_allows_known_keys() -> None:
    embedder = ChromaBm25EmbeddingFunction()

    embedder.validate_config_update(
        embedder.get_config(), {"k": 1.1, "stopwords": ["custom"]}
    )

chromadb/test/ef/test_ef.py

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ def test_get_builtins_holds() -> None:
         "Bm25EmbeddingFunction",
         "ChromaCloudQwenEmbeddingFunction",
         "ChromaCloudSpladeEmbeddingFunction",
+        "ChromaBm25EmbeddingFunction",
     }

     assert expected_builtins == embedding_functions.get_builtins()

chromadb/utils/embedding_functions/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -83,6 +83,9 @@
 from chromadb.utils.embedding_functions.chroma_cloud_splade_embedding_function import (
     ChromaCloudSpladeEmbeddingFunction,
 )
+from chromadb.utils.embedding_functions.chroma_bm25_embedding_function import (
+    ChromaBm25EmbeddingFunction,
+)


 # Get all the class names for backward compatibility
@@ -116,6 +119,7 @@
     "Bm25EmbeddingFunction",
     "ChromaCloudQwenEmbeddingFunction",
     "ChromaCloudSpladeEmbeddingFunction",
+    "ChromaBm25EmbeddingFunction",
 }


@@ -157,6 +161,7 @@ def get_builtins() -> Set[str]:
     "fastembed_sparse": FastembedSparseEmbeddingFunction,
     "bm25": Bm25EmbeddingFunction,
     "chroma-cloud-splade": ChromaCloudSpladeEmbeddingFunction,
+    "chroma_bm25": ChromaBm25EmbeddingFunction,
 }


@@ -273,6 +278,7 @@ def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction:
     "Bm25EmbeddingFunction",
     "ChromaCloudQwenEmbeddingFunction",
     "ChromaCloudSpladeEmbeddingFunction",
+    "ChromaBm25EmbeddingFunction",
     "register_embedding_function",
     "config_to_embedding_function",
     "known_embedding_functions",

chromadb/utils/embedding_functions/bm25_embedding_function.py

Lines changed: 6 additions & 0 deletions
@@ -5,6 +5,7 @@
 )
 from typing import Dict, Any, TypedDict, Optional
 from typing import cast, Literal
+import warnings
 from chromadb.utils.embedding_functions.schemas import validate_config_schema
 from chromadb.utils.sparse_embedding_utils import normalize_sparse_vector

@@ -45,6 +46,11 @@ def __init__(
             query_config (dict, optional): Configuration for the query, can be "task"
             **kwargs: Additional arguments to pass to the Bm25 model.
         """
+        warnings.warn(
+            "Bm25EmbeddingFunction is deprecated. Please use ChromaBm25EmbeddingFunction instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         try:
             from fastembed.sparse.bm25 import Bm25
         except ImportError:
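
With this change, constructing the legacy fastembed-based `Bm25EmbeddingFunction` now emits a `DeprecationWarning`. A rough illustration of how that could be asserted (not part of this PR, and it assumes `fastembed` is installed so the constructor can get past its import):

```python
import pytest

from chromadb.utils.embedding_functions import Bm25EmbeddingFunction


def test_bm25_embedding_function_warns_deprecation() -> None:
    # The warning added in this commit points users at ChromaBm25EmbeddingFunction.
    with pytest.warns(DeprecationWarning, match="ChromaBm25EmbeddingFunction"):
        Bm25EmbeddingFunction()
```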
Lines changed: 135 additions & 0 deletions
from __future__ import annotations

from collections import Counter
from typing import Any, Dict, Iterable, List, Optional, TypedDict

from chromadb.api.types import Documents, SparseEmbeddingFunction, SparseVectors
from chromadb.base_types import SparseVector
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.utils.embedding_functions.schemas.bm25_tokenizer import (
    Bm25Tokenizer,
    DEFAULT_CHROMA_BM25_STOPWORDS as _DEFAULT_STOPWORDS,
    get_english_stemmer,
    Murmur3AbsHasher,
)

NAME = "chroma_bm25"

DEFAULT_K = 1.2
DEFAULT_B = 0.75
DEFAULT_AVG_DOC_LENGTH = 256.0
DEFAULT_TOKEN_MAX_LENGTH = 40

DEFAULT_CHROMA_BM25_STOPWORDS: List[str] = list(_DEFAULT_STOPWORDS)


class ChromaBm25Config(TypedDict, total=False):
    k: float
    b: float
    avg_doc_length: float
    token_max_length: int
    stopwords: List[str]


class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
    def __init__(
        self,
        k: float = DEFAULT_K,
        b: float = DEFAULT_B,
        avg_doc_length: float = DEFAULT_AVG_DOC_LENGTH,
        token_max_length: int = DEFAULT_TOKEN_MAX_LENGTH,
        stopwords: Optional[Iterable[str]] = None,
    ) -> None:
        """Initialize the BM25 sparse embedding function."""

        self.k = float(k)
        self.b = float(b)
        self.avg_doc_length = float(avg_doc_length)
        self.token_max_length = int(token_max_length)

        if stopwords is not None:
            self.stopwords: Optional[List[str]] = [str(word) for word in stopwords]
            stopword_list: Iterable[str] = self.stopwords
        else:
            self.stopwords = None
            stopword_list = DEFAULT_CHROMA_BM25_STOPWORDS

        stemmer = get_english_stemmer()
        self._tokenizer = Bm25Tokenizer(stemmer, stopword_list, self.token_max_length)
        self._hasher = Murmur3AbsHasher()

    def _encode(self, text: str) -> SparseVector:
        tokens = self._tokenizer.tokenize(text)

        if not tokens:
            return SparseVector(indices=[], values=[])

        doc_len = float(len(tokens))
        counts = Counter(self._hasher.hash(token) for token in tokens)

        indices = sorted(counts.keys())
        values: List[float] = []
        for idx in indices:
            tf = float(counts[idx])
            denominator = tf + self.k * (
                1 - self.b + (self.b * doc_len) / self.avg_doc_length
            )
            score = tf * (self.k + 1) / denominator
            values.append(score)

        return SparseVector(indices=indices, values=values)

    def __call__(self, input: Documents) -> SparseVectors:
        sparse_vectors: SparseVectors = []

        if not input:
            return sparse_vectors

        for document in input:
            sparse_vectors.append(self._encode(document))

        return sparse_vectors

    def embed_query(self, input: Documents) -> SparseVectors:
        return self.__call__(input)

    @staticmethod
    def name() -> str:
        return NAME

    @staticmethod
    def build_from_config(
        config: Dict[str, Any]
    ) -> "SparseEmbeddingFunction[Documents]":
        return ChromaBm25EmbeddingFunction(
            k=config.get("k", DEFAULT_K),
            b=config.get("b", DEFAULT_B),
            avg_doc_length=config.get("avg_doc_length", DEFAULT_AVG_DOC_LENGTH),
            token_max_length=config.get("token_max_length", DEFAULT_TOKEN_MAX_LENGTH),
            stopwords=config.get("stopwords"),
        )

    def get_config(self) -> Dict[str, Any]:
        config: Dict[str, Any] = {
            "k": self.k,
            "b": self.b,
            "avg_doc_length": self.avg_doc_length,
            "token_max_length": self.token_max_length,
        }

        if self.stopwords is not None:
            config["stopwords"] = list(self.stopwords)

        return config

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        mutable_keys = {"k", "b", "avg_doc_length", "token_max_length", "stopwords"}
        for key in new_config:
            if key not in mutable_keys:
                raise ValueError(f"Updating '{key}' is not supported for {NAME}")

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        validate_config_schema(config, NAME)
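
For reference, each weight produced by `_encode` above is the IDF-free BM25 term-frequency saturation term, `tf * (k + 1) / (tf + k * (1 - b + b * doc_len / avg_doc_length))`, so it depends only on the token's count and the document length. A quick arithmetic check (my own, assuming the first test sentence tokenizes into 12 tokens, which is consistent with its 12 expected indices and identical values) reproduces the reference value used in the tests:

```python
# BM25 term weight with the defaults, tf = 1, and a 12-token document.
k, b, avg_doc_length = 1.2, 0.75, 256.0
tf, doc_len = 1.0, 12.0

denominator = tf + k * (1 - b + (b * doc_len) / avg_doc_length)
score = tf * (k + 1) / denominator
print(score)  # ~1.63911525, matching expected_value = 1.6391153 within the test's 1e-5 tolerance
```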
