Skip to content

Commit fc31e6a

Browse files
authored
[ENH] Update sparse vector similarity metric (#5406)
## Description of changes _Summarize the changes made by this PR._ - Improvements & Bug fixes - Change sparse vector metric from inner product to one minus inner product. This is consistent with our dense vector similarity metric. - Updates `Rank` operator so that the results are returned in increasing score. Smaller score means higher similarity. - Removes `SparseKnnMerge` implementation as it is unnecessary now - New functionality - N/A ## Test plan _How are these changes tested?_ - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Migration plan _Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_ ## Observability plan _What is the plan to instrument and monitor this change?_ ## Documentation Changes _Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent 2d896de commit fc31e6a

File tree

10 files changed

+79
-198
lines changed

10 files changed

+79
-198
lines changed

rust/worker/benches/spann.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ fn calculate_recall<'a>(
212212
}
213213
// Now merge.
214214
let knn_input = KnnMergeInput {
215-
batch_distances: merge_list,
215+
batch_measures: merge_list,
216216
};
217217
let knn_operator = Merge { k: k as u32 };
218218
let knn_output = knn_operator
@@ -245,7 +245,7 @@ fn calculate_recall<'a>(
245245
.expect("Error running operator");
246246
let mut recall = 0;
247247
for bf_record in bf_output.records.iter() {
248-
for spann_record in knn_output.distances.iter() {
248+
for spann_record in knn_output.measures.iter() {
249249
if bf_record.offset_id == spann_record.offset_id {
250250
recall += 1;
251251
}

rust/worker/src/execution/operators/knn_merge.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ use thiserror::Error;
2121
2222
#[derive(Debug)]
2323
pub struct KnnMergeInput {
24-
pub batch_distances: Vec<Vec<RecordMeasure>>,
24+
pub batch_measures: Vec<Vec<RecordMeasure>>,
2525
}
2626

2727
#[derive(Debug, Default)]
2828
pub struct KnnMergeOutput {
29-
pub distances: Vec<RecordMeasure>,
29+
pub measures: Vec<RecordMeasure>,
3030
}
3131

3232
#[derive(Error, Debug)]
@@ -47,12 +47,12 @@ impl Operator<KnnMergeInput, KnnMergeOutput> for Merge {
4747
// Reversing because similarity is in ascending order,
4848
// while Merge takes element in descending order
4949
let reversed_distances = input
50-
.batch_distances
50+
.batch_measures
5151
.iter()
5252
.map(|batch| batch.iter().map(|m| Reverse(m.clone())).collect())
5353
.collect();
5454
Ok(KnnMergeOutput {
55-
distances: self
55+
measures: self
5656
.merge(reversed_distances)
5757
.into_iter()
5858
.map(|Reverse(distance)| distance)
@@ -74,7 +74,7 @@ mod tests {
7474
/// - Second: 1, 3, ..., 99
7575
fn setup_knn_merge_input() -> KnnMergeInput {
7676
KnnMergeInput {
77-
batch_distances: vec![
77+
batch_measures: vec![
7878
(1..=100)
7979
.filter_map(|offset_id| {
8080
(offset_id % 3 == 1).then_some(RecordMeasure {
@@ -116,7 +116,7 @@ mod tests {
116116

117117
assert_eq!(
118118
knn_merge_output
119-
.distances
119+
.measures
120120
.iter()
121121
.map(|record| record.offset_id)
122122
.collect::<Vec<_>>(),

rust/worker/src/execution/operators/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,4 @@ pub mod repair_log_offsets;
2525
pub mod select;
2626
pub mod source_record_segment;
2727
pub mod sparse_index_knn;
28-
pub mod sparse_knn_merge;
2928
pub mod sparse_log_knn;

rust/worker/src/execution/operators/rank.rs

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,6 @@ impl Operator<RankInput, RankOutput> for Rank {
201201
.map(|(offset_id, measure)| RecordMeasure { offset_id, measure })
202202
.collect::<Vec<_>>();
203203
ranks.sort_unstable();
204-
ranks.reverse();
205204
Ok(RankOutput { ranks })
206205
}
207206
}
@@ -246,8 +245,9 @@ mod tests {
246245

247246
let output = rank.run(&input).await.expect("Rank should succeed");
248247
assert_eq!(output.ranks.len(), 3);
249-
assert_eq!(output.ranks[0].offset_id, 1);
250-
assert_eq!(output.ranks[0].measure, 0.9);
248+
// After removing .reverse(), results are in ascending order by measure
249+
assert_eq!(output.ranks[0].offset_id, 3);
250+
assert_eq!(output.ranks[0].measure, 0.5);
251251
}
252252

253253
#[tokio::test]
@@ -310,7 +310,11 @@ mod tests {
310310
};
311311

312312
let output = rank.run(&input).await.expect("Rank should succeed");
313-
// Record 1 appears in both: 0.8 + 0.4 = 1.2
313+
// Summation results:
314+
// Only Record 1 appears in both lists: 0.8 + 0.4 = 1.2
315+
// Records 2 and 3 are filtered out since they don't appear in both lists
316+
// and both Knn operations have default: None
317+
assert_eq!(output.ranks.len(), 1);
314318
assert_eq!(output.ranks[0].offset_id, 1);
315319
assert_eq!(output.ranks[0].measure, 1.2);
316320

@@ -329,8 +333,10 @@ mod tests {
329333
let input = RankInput { knn_results };
330334

331335
let output = rank.run(&input).await.expect("Rank should succeed");
332-
assert_eq!(output.ranks[0].offset_id, 1);
333-
assert_eq!(output.ranks[0].measure, 0.4); // 0.8 * 0.5
336+
// Results are in ascending order, so the record with the lowest measure comes first
337+
// After multiplication by 0.5: record 1 = 0.8 * 0.5 = 0.4, record 2 = 0.6 * 0.5 = 0.3
338+
assert_eq!(output.ranks[0].offset_id, 2);
339+
assert_eq!(output.ranks[0].measure, 0.3); // 0.6 * 0.5
334340
}
335341

336342
#[tokio::test]
@@ -367,10 +373,11 @@ mod tests {
367373
};
368374

369375
let output = rank.run(&input).await.expect("Rank should succeed");
370-
assert_eq!(output.ranks[0].offset_id, 1);
371-
assert_eq!(output.ranks[0].measure, 0.8); // max(0.8, 0.5) = 0.8
372-
assert_eq!(output.ranks[1].offset_id, 2);
373-
assert_eq!(output.ranks[1].measure, 0.5); // max(0.3, 0.5) = 0.5
376+
// Results are in ascending order
377+
assert_eq!(output.ranks[0].offset_id, 2);
378+
assert_eq!(output.ranks[0].measure, 0.5); // max(0.3, 0.5) = 0.5
379+
assert_eq!(output.ranks[1].offset_id, 1);
380+
assert_eq!(output.ranks[1].measure, 0.8); // max(0.8, 0.5) = 0.8
374381

375382
// Test min
376383
let rank = Rank::Minimum(vec![
@@ -386,9 +393,10 @@ mod tests {
386393
let input = RankInput { knn_results };
387394

388395
let output = rank.run(&input).await.expect("Rank should succeed");
389-
assert_eq!(output.ranks[0].offset_id, 1);
390-
assert_eq!(output.ranks[0].measure, 0.5); // min(0.8, 0.5) = 0.5
391-
assert_eq!(output.ranks[1].offset_id, 2);
392-
assert_eq!(output.ranks[1].measure, 0.3); // min(0.3, 0.5) = 0.3
396+
// Results are in ascending order
397+
assert_eq!(output.ranks[0].offset_id, 2);
398+
assert_eq!(output.ranks[0].measure, 0.3); // min(0.3, 0.5) = 0.3
399+
assert_eq!(output.ranks[1].offset_id, 1);
400+
assert_eq!(output.ranks[1].measure, 0.5); // min(0.8, 0.5) = 0.5
393401
}
394402
}

rust/worker/src/execution/operators/sparse_index_knn.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ impl Operator<SparseIndexKnnInput, SparseIndexKnnOutput> for SparseIndexKnn {
6868
.into_iter()
6969
.map(|score| RecordMeasure {
7070
offset_id: score.offset,
71-
measure: score.score,
71+
// NOTE: We use `1 - query · document` as similarity metrics
72+
measure: 1.0 - score.score,
7273
})
7374
.collect(),
7475
})

rust/worker/src/execution/operators/sparse_knn_merge.rs

Lines changed: 0 additions & 133 deletions
This file was deleted.

0 commit comments

Comments
 (0)