
Commit 9f79843

[ENH] Add sync point to test_filtering + fix issues (#2388)
## Description of changes

*Summarize the changes made by this PR.*

- Improvements & Bug fixes
  - Adds a sync point to test_filtering
  - Modifies test_filtering in cluster mode to not emit documents/where_clauses of length < 3 or containing the characters "_" and "%"
  - Increases the deadline of tests from 45 secs to 90 secs, since waiting for compaction to finish could end up taking longer than 45 secs
  - Fixes an lte comparison bug with f32 metadata values
  - Fixes the version syncing logic to not race with compaction, by getting the initial version before the add
  - Suppresses the health check warning for filtering too much
  - Fixes a replace_block bug in the sparse index
  - Fixes a split bug in the bf writer

## Test plan

- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Documentation Changes

None
1 parent 9cce6b1 commit 9f79843
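
The sync point added here follows a simple pattern: record the collection's version before the `add`, then (in cluster mode) wait for compaction to bump that version before asserting on query results. Below is a minimal sketch of that pattern, using the `wait_for_version_increase` helper imported in `test_filtering.py`; the wrapper function and the fixture names (`api`, `coll`, `record_set`) are illustrative, not part of the commit.

```python
from chromadb.test.utils.wait_for_version_increase import wait_for_version_increase


def add_and_sync(api, coll, collection_name, record_set, should_compact):
    # Read the version *before* the add so a fast compaction cannot race us.
    initial_version = coll.get_model()["version"]

    coll.add(**record_set)

    # In cluster mode, only wait when a compaction is expected and the
    # collection is large enough for one to trigger (the tests use > 10 ids).
    if should_compact and len(record_set["ids"]) > 10:
        wait_for_version_increase(api, collection_name, initial_version)
```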

File tree

12 files changed: +155, -27 lines changed


chromadb/test/property/strategies.py

Lines changed: 52 additions & 6 deletions

@@ -99,11 +99,16 @@ class Record(TypedDict):
 # TODO: support empty strings everywhere
 sql_alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_"
 safe_text = st.text(alphabet=sql_alphabet, min_size=1)
+sql_alphabet_minus_underscore = (
+    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-"
+)
+safe_text_min_size_3 = st.text(alphabet=sql_alphabet_minus_underscore, min_size=3)
 tenant_database_name = st.text(alphabet=sql_alphabet, min_size=3)
 
 # Workaround for FastAPI json encoding peculiarities
 # https://github.com/tiangolo/fastapi/blob/8ac8d70d52bb0dd9eb55ba4e22d3e383943da05c/fastapi/encoders.py#L104
 safe_text = safe_text.filter(lambda s: not s.startswith("_sa"))
+safe_text_min_size_3 = safe_text_min_size_3.filter(lambda s: not s.startswith("_sa"))
 tenant_database_name = tenant_database_name.filter(lambda s: not s.startswith("_sa"))
 
 safe_integers = st.integers(
@@ -316,10 +321,21 @@ def collections(
     if has_documents is None:
         has_documents = draw(st.booleans())
     assert has_documents is not None
-    if has_documents and add_filterable_data:
-        known_document_keywords = draw(st.lists(safe_text, min_size=5, max_size=5))
+    # For cluster tests, we want to avoid generating documents and where_document
+    # clauses of length < 3. We also don't want them to contain certain special
+    # characters like _ and % that implicitly involve searching for a regex in sqlite.
+    if not NOT_CLUSTER_ONLY:
+        if has_documents and add_filterable_data:
+            known_document_keywords = draw(
+                st.lists(safe_text_min_size_3, min_size=5, max_size=5)
+            )
+        else:
+            known_document_keywords = []
     else:
-        known_document_keywords = []
+        if has_documents and add_filterable_data:
+            known_document_keywords = draw(st.lists(safe_text, min_size=5, max_size=5))
+        else:
+            known_document_keywords = []
 
     if not has_documents:
         has_embeddings = True
@@ -375,6 +391,27 @@ def metadata(
 @st.composite
 def document(draw: st.DrawFn, collection: Collection) -> types.Document:
     """Strategy for generating documents that could be a part of the given collection"""
+    # For cluster tests, we want to avoid generating documents of length < 3.
+    # We also don't want them to contain certain special
+    # characters like _ and % that implicitly involve searching for a regex in sqlite.
+    if not NOT_CLUSTER_ONLY:
+        # Blacklist certain unicode characters that affect sqlite processing.
+        # For example, the null (/x00) character makes sqlite stop processing a string.
+        # Also, blacklist _ and % for cluster tests.
+        blacklist_categories = ("Cc", "Cs", "Pc", "Po")
+        if collection.known_document_keywords:
+            known_words_st = st.sampled_from(collection.known_document_keywords)
+        else:
+            known_words_st = st.text(
+                min_size=3,
+                alphabet=st.characters(blacklist_categories=blacklist_categories), # type: ignore
+            )
+
+        random_words_st = st.text(
+            min_size=3, alphabet=st.characters(blacklist_categories=blacklist_categories) # type: ignore
+        )
+        words = draw(st.lists(st.one_of(known_words_st, random_words_st), min_size=1))
+        return " ".join(words)
 
     # Blacklist certain unicode characters that affect sqlite processing.
     # For example, the null (/x00) character makes sqlite stop processing a string.
@@ -531,10 +568,19 @@ def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where:
 @st.composite
 def where_doc_clause(draw: st.DrawFn, collection: Collection) -> types.WhereDocument:
     """Generate a where_document filter that could be used against the given collection"""
-    if collection.known_document_keywords:
-        word = draw(st.sampled_from(collection.known_document_keywords))
+    # For cluster tests, we want to avoid generating where_document
+    # clauses of length < 3. We also don't want them to contain certain special
+    # characters like _ and % that implicitly involve searching for a regex in sqlite.
+    if not NOT_CLUSTER_ONLY:
+        if collection.known_document_keywords:
+            word = draw(st.sampled_from(collection.known_document_keywords))
+        else:
+            word = draw(safe_text_min_size_3)
     else:
-        word = draw(safe_text)
+        if collection.known_document_keywords:
+            word = draw(st.sampled_from(collection.known_document_keywords))
+        else:
+            word = draw(safe_text)
 
     # This is hacky, but the distributed system does not support $not_contains
     # so we need to avoid generating these operators for now in that case.
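
The excluded characters matter because `_` and `%` are wildcards in SQL `LIKE` patterns (the comments above describe this loosely as a regex search), so a generated keyword containing them can match more documents than a literal substring lookup would. A small standalone illustration using Python's built-in `sqlite3`; the table and data here are made up for the example, not taken from Chroma:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE docs (body TEXT)")
conn.executemany("INSERT INTO docs VALUES (?)", [("abc",), ("a_c",), ("axc",)])

# The intent is a literal search for the keyword "a_c", but "_" matches any
# single character in LIKE, so "abc" and "axc" match as well.
rows = conn.execute("SELECT body FROM docs WHERE body LIKE ?", ("%a_c%",)).fetchall()
print(rows)  # [('abc',), ('a_c',), ('axc',)]
```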

chromadb/test/property/test_add.py

Lines changed: 5 additions & 1 deletion

@@ -188,7 +188,11 @@ def test_add_large(
     ):
         coll.add(*batch)
 
-    if not NOT_CLUSTER_ONLY and should_compact:
+    if (
+        not NOT_CLUSTER_ONLY
+        and should_compact
+        and len(normalized_record_set["ids"]) > 10
+    ):
         initial_version = coll.get_model()["version"]
         # Wait for the model to be updated, since the record set is larger, add some additional time
         wait_for_version_increase(

chromadb/test/property/test_filtering.py

Lines changed: 50 additions & 4 deletions

@@ -20,6 +20,7 @@
 import logging
 import random
 import re
+from chromadb.test.utils.wait_for_version_increase import wait_for_version_increase
 
 
 def _filter_where_clause(clause: Where, metadata: Optional[Metadata]) -> bool:
@@ -175,18 +176,26 @@ def _filter_embedding_set(
 
 
 @settings(
+    deadline=90000,
     suppress_health_check=[
         HealthCheck.function_scoped_fixture,
         HealthCheck.large_base_example,
-    ]
+        HealthCheck.filter_too_much,
+    ],
 ) # type: ignore
 @given(
     collection=collection_st,
     record_set=recordset_st,
     filters=st.lists(strategies.filters(collection_st, recordset_st), min_size=1),
+    should_compact=st.booleans(),
 )
 def test_filterable_metadata_get(
-    caplog, api: ServerAPI, collection: strategies.Collection, record_set, filters
+    caplog,
+    api: ServerAPI,
+    collection: strategies.Collection,
+    record_set,
+    filters,
+    should_compact: bool,
 ) -> None:
     caplog.set_level(logging.ERROR)
 
@@ -197,25 +206,38 @@ def test_filterable_metadata_get(
         embedding_function=collection.embedding_function,
     )
 
+    initial_version = coll.get_model()["version"]
+
     coll.add(**record_set)
+
+    if not NOT_CLUSTER_ONLY:
+        # Only wait for compaction if the size of the collection is
+        # some minimal size
+        if should_compact and len(invariants.wrap(record_set["ids"])) > 10:
+            # Wait for the model to be updated
+            wait_for_version_increase(api, collection.name, initial_version)
+
     for filter in filters:
         result_ids = coll.get(**filter)["ids"]
         expected_ids = _filter_embedding_set(record_set, filter)
         assert sorted(result_ids) == sorted(expected_ids)
 
 
 @settings(
+    deadline=90000,
     suppress_health_check=[
         HealthCheck.function_scoped_fixture,
         HealthCheck.large_base_example,
-    ]
+        HealthCheck.filter_too_much,
+    ],
 ) # type: ignore
 @given(
     collection=collection_st,
     record_set=recordset_st,
     filters=st.lists(strategies.filters(collection_st, recordset_st), min_size=1),
     limit=st.integers(min_value=1, max_value=10),
     offset=st.integers(min_value=0, max_value=10),
+    should_compact=st.booleans(),
 )
 def test_filterable_metadata_get_limit_offset(
     caplog,
@@ -225,6 +247,7 @@ def test_filterable_metadata_get_limit_offset(
     filters,
     limit,
     offset,
+    should_compact: bool,
 ) -> None:
     caplog.set_level(logging.ERROR)
 
@@ -240,7 +263,17 @@
         embedding_function=collection.embedding_function,
     )
 
+    initial_version = coll.get_model()["version"]
+
     coll.add(**record_set)
+
+    if not NOT_CLUSTER_ONLY:
+        # Only wait for compaction if the size of the collection is
+        # some minimal size
+        if should_compact and len(invariants.wrap(record_set["ids"])) > 10:
+            # Wait for the model to be updated
+            wait_for_version_increase(api, collection.name, initial_version)
+
     for filter in filters:
         # add limit and offset to filter
         filter["limit"] = limit
@@ -251,10 +284,12 @@
 
 
 @settings(
+    deadline=90000,
     suppress_health_check=[
         HealthCheck.function_scoped_fixture,
         HealthCheck.large_base_example,
-    ]
+        HealthCheck.filter_too_much,
+    ],
 )
 @given(
     collection=collection_st,
@@ -263,13 +298,15 @@ def test_filterable_metadata_get_limit_offset(
         strategies.filters(collection_st, recordset_st, include_all_ids=True),
         min_size=1,
     ),
+    should_compact=st.booleans(),
 )
 def test_filterable_metadata_query(
     caplog: pytest.LogCaptureFixture,
     api: ServerAPI,
     collection: strategies.Collection,
     record_set: strategies.RecordSet,
     filters: List[strategies.Filter],
+    should_compact: bool,
 ) -> None:
     caplog.set_level(logging.ERROR)
 
@@ -279,9 +316,18 @@
         metadata=collection.metadata, # type: ignore
         embedding_function=collection.embedding_function,
     )
+    initial_version = coll.get_model()["version"]
    normalized_record_set = invariants.wrap_all(record_set)
 
     coll.add(**record_set)
+
+    if not NOT_CLUSTER_ONLY:
+        # Only wait for compaction if the size of the collection is
+        # some minimal size
+        if should_compact and len(invariants.wrap(record_set["ids"])) > 10:
+            # Wait for the model to be updated
+            wait_for_version_increase(api, collection.name, initial_version)
+
     total_count = len(normalized_record_set["ids"])
     # Pick a random vector
     random_query: Embedding

rust/worker/src/blockstore/arrow/block/delta.rs

Lines changed: 6 additions & 1 deletion

@@ -128,6 +128,7 @@ impl BlockDelta {
         let mut blocks_to_split = Vec::new();
         blocks_to_split.push(self.clone());
         let mut output = Vec::new();
+        let mut first_iter = true;
         // iterate over all blocks to split until its empty
         while !blocks_to_split.is_empty() {
             let curr_block = blocks_to_split.pop().unwrap();
@@ -168,7 +169,11 @@
                 builder: new_delta,
                 id: Uuid::new_v4(),
             };
-
+            if first_iter {
+                first_iter = false;
+            } else {
+                output.push((curr_block.builder.get_key(0).clone(), curr_block));
+            }
             if new_block.get_size::<K, V>() > MAX_BLOCK_SIZE {
                 blocks_to_split.push(new_block);
             } else {

rust/worker/src/blockstore/arrow/block/delta_storage.rs

Lines changed: 4 additions & 4 deletions

@@ -267,7 +267,7 @@ impl StringValueStorage {
     }
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub(super) struct UInt32Storage {
     pub(super) storage: Arc<RwLock<BTreeMap<CompositeKey, u32>>>,
 }
@@ -355,7 +355,7 @@ impl UInt32Storage {
     }
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub(super) struct Int32ArrayStorage {
     pub(super) storage: Arc<RwLock<BTreeMap<CompositeKey, Int32Array>>>,
 }
@@ -464,7 +464,7 @@ impl Int32ArrayStorage {
     }
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub(super) struct RoaringBitmapStorage {
     pub(super) storage: Arc<RwLock<BTreeMap<CompositeKey, Vec<u8>>>>,
 }
@@ -561,7 +561,7 @@ impl RoaringBitmapStorage {
     }
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub(super) struct DataRecordStorage {
     pub(super) id_storage: Arc<RwLock<BTreeMap<CompositeKey, String>>>,
     pub(super) embedding_storage: Arc<RwLock<BTreeMap<CompositeKey, Vec<f32>>>>,

rust/worker/src/blockstore/arrow/blockfile.rs

Lines changed: 13 additions & 1 deletion

@@ -275,14 +275,21 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
         let target_block_id = self.sparse_index.get_target_block_id(&search_key);
         let block = self.get_block(target_block_id).await;
         let res = match block {
-            Some(block) => block.get(prefix, key),
+            Some(block) => block.get(prefix, key.clone()),
             None => {
+                tracing::error!("Block with id {:?} not found", target_block_id);
                 return Err(Box::new(ArrowBlockfileError::BlockNotFound));
             }
         };
         match res {
             Some(value) => Ok(value),
             None => {
+                tracing::error!(
+                    "Key {:?}/{:?} not found in block {:?}",
+                    prefix,
+                    key,
+                    target_block_id
+                );
                 return Err(Box::new(BlockfileError::NotFoundError));
             }
         }
@@ -309,6 +316,7 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
                 block_offset += b.len();
             }
             None => {
+                tracing::error!("Block id {:?} not found", uuid);
                 return Err(Box::new(ArrowBlockfileError::BlockNotFound));
             }
         }
@@ -320,6 +328,10 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me
                 return Ok((prefix, key, value));
             }
             _ => {
+                tracing::error!(
+                    "Value not found at index {:?} for block",
+                    index - block_offset,
+                );
                 return Err(Box::new(BlockfileError::NotFoundError));
             }
         }

rust/worker/src/blockstore/arrow/sparse_index.rs

Lines changed: 6 additions & 1 deletion

@@ -341,8 +341,13 @@ impl SparseIndex {
         forward.remove(&old_start_key);
         if old_start_key == SparseIndexDelimiter::Start {
             forward.insert(SparseIndexDelimiter::Start, new_block_id);
+            reverse.insert(new_block_id, SparseIndexDelimiter::Start);
         } else {
-            forward.insert(SparseIndexDelimiter::Key(new_start_key), new_block_id);
+            forward.insert(
+                SparseIndexDelimiter::Key(new_start_key.clone()),
+                new_block_id,
+            );
+            reverse.insert(new_block_id, SparseIndexDelimiter::Key(new_start_key));
         }
     }
 }
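
The hunk above is the replace_block fix called out in the description: the reverse map (block id → start key) is now updated alongside the forward map (start key → block id). Below is a loose Python illustration of the invariant being restored; it is not the Rust implementation, and the dict-based maps and the function signature are stand-ins for this sketch only.

```python
def replace_block(forward, reverse, old_block_id, new_block_id, new_start_key):
    """Swap old_block_id for new_block_id while keeping both maps in sync."""
    old_start_key = reverse.pop(old_block_id)
    forward.pop(old_start_key, None)
    forward[new_start_key] = new_block_id
    reverse[new_block_id] = new_start_key  # updating the reverse map was the missing step

    # Invariant: forward and reverse remain exact inverses of each other.
    assert all(reverse[block_id] == key for key, block_id in forward.items())
```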

rust/worker/src/execution/operators/count_records.rs

Lines changed: 1 addition & 0 deletions

@@ -78,6 +78,7 @@ impl Operator<CountRecordsInput, CountRecordsOutput> for CountRecordsOperator {
             Err(e) => {
                 match *e {
                     RecordSegmentReaderCreationError::UninitializedSegment => {
+                        tracing::info!("[CountQueryOrchestrator] Record segment is uninitialized");
                         // This means there no compaction has occured.
                         // So we can just traverse the log records
                         // and count the number of records.
