Skip to content

Commit 2d896de

Browse files
authored
[ENH] Unimplement Hash + Eq for KnnQuery (#5397)
## Description of changes _Summarize the changes made by this PR._ - Improvements & Bug fixes - Removed `Hash + Eq` implementation for `KnnQuery`. Now we assume the expression traversal is always in DFS order and we only keep a `Vec<Vec<RecordMeasure>>` in the same order instead of a `HashMap<KnnQuery, Vec<RecordMeasure>>` - New functionality - N/A ## Test plan _How are these changes tested?_ - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Migration plan _Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_ ## Observability plan _What is the plan to instrument and monitor this change?_ ## Documentation Changes _Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent 2dd9df6 commit 2d896de

File tree

4 files changed

+73
-144
lines changed

4 files changed

+73
-144
lines changed

rust/types/src/execution/operator.rs

Lines changed: 6 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
use core::mem::discriminant;
21
use serde::{de::Error, Deserialize, Deserializer, Serialize};
32
use serde_json::Value;
43
use std::{
54
cmp::Ordering,
65
collections::{BinaryHeap, HashSet},
7-
hash::{Hash, Hasher},
6+
hash::Hash,
87
};
98
use thiserror::Error;
109
use utoipa::ToSchema;
@@ -660,31 +659,6 @@ pub enum QueryVector {
660659
Sparse(SparseVector),
661660
}
662661

663-
impl Eq for QueryVector {}
664-
665-
impl Hash for QueryVector {
666-
fn hash<H: Hasher>(&self, state: &mut H) {
667-
discriminant(self).hash(state);
668-
match self {
669-
QueryVector::Dense(embedding) => {
670-
let bits = embedding
671-
.iter()
672-
.map(|val| (val + 0.0).to_bits())
673-
.collect::<Vec<_>>();
674-
bits.hash(state);
675-
}
676-
QueryVector::Sparse(embedding) => {
677-
let mut sorted_bits = embedding
678-
.iter()
679-
.map(|(index, val)| (index, (val + 0.0).to_bits()))
680-
.collect::<Vec<_>>();
681-
sorted_bits.sort_unstable();
682-
sorted_bits.hash(state);
683-
}
684-
}
685-
}
686-
}
687-
688662
impl TryFrom<chroma_proto::QueryVector> for QueryVector {
689663
type Error = QueryConversionError;
690664

@@ -721,7 +695,7 @@ impl TryFrom<QueryVector> for chroma_proto::QueryVector {
721695
}
722696
}
723697

724-
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
698+
#[derive(Clone, Debug, PartialEq)]
725699
pub struct KnnQuery {
726700
pub embedding: QueryVector,
727701
pub key: String,
@@ -773,7 +747,7 @@ impl Rank {
773747
128
774748
}
775749

776-
pub fn knn_queries(&self) -> HashSet<KnnQuery> {
750+
pub fn knn_queries(&self) -> Vec<KnnQuery> {
777751
match self {
778752
Rank::Absolute(rank) | Rank::Exponentiation(rank) | Rank::Logarithm(rank) => {
779753
rank.knn_queries()
@@ -787,18 +761,18 @@ impl Rank {
787761
| Rank::Minimum(ranks)
788762
| Rank::Multiplication(ranks)
789763
| Rank::Summation(ranks) => ranks.iter().flat_map(Rank::knn_queries).collect(),
790-
Rank::Value(_) => HashSet::new(),
764+
Rank::Value(_) => Vec::new(),
791765
Rank::Knn {
792766
embedding,
793767
key,
794768
limit,
795769
default: _,
796770
ordinal: _,
797-
} => HashSet::from_iter([KnnQuery {
771+
} => vec![KnnQuery {
798772
embedding: embedding.clone(),
799773
key: key.clone(),
800774
limit: *limit,
801-
}]),
775+
}],
802776
}
803777
}
804778
}
@@ -1245,33 +1219,6 @@ mod tests {
12451219
}
12461220
}
12471221

1248-
#[test]
1249-
fn test_query_vector_equality_and_hash() {
1250-
use std::collections::HashSet;
1251-
1252-
let dense1 = QueryVector::Dense(vec![0.1, 0.2, 0.3]);
1253-
let dense2 = QueryVector::Dense(vec![0.1, 0.2, 0.3]);
1254-
let dense3 = QueryVector::Dense(vec![0.1, 0.2, 0.4]);
1255-
1256-
// Test equality
1257-
assert_eq!(dense1, dense2);
1258-
assert_ne!(dense1, dense3);
1259-
1260-
// Test hash - equal vectors should have same hash
1261-
let mut set = HashSet::new();
1262-
set.insert(dense1.clone());
1263-
assert!(set.contains(&dense2));
1264-
assert!(!set.contains(&dense3));
1265-
1266-
// Test sparse vectors
1267-
let sparse1 = QueryVector::Sparse(SparseVector::new(vec![0, 5], vec![0.1, 0.5]));
1268-
let sparse2 = QueryVector::Sparse(SparseVector::new(vec![0, 5], vec![0.1, 0.5]));
1269-
assert_eq!(sparse1, sparse2);
1270-
1271-
set.insert(sparse1.clone());
1272-
assert!(set.contains(&sparse2));
1273-
}
1274-
12751222
#[test]
12761223
fn test_filter_json_serialization() {
12771224
// Test basic filter serialization

rust/worker/src/execution/operators/rank.rs

Lines changed: 55 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,9 @@ use std::{
66
use async_trait::async_trait;
77
use chroma_error::{ChromaError, ErrorCodes};
88
use chroma_system::Operator;
9-
use chroma_types::operator::{KnnQuery, Rank, RecordMeasure};
9+
use chroma_types::operator::{Rank, RecordMeasure};
1010
use thiserror::Error;
1111

12-
struct RankProvider<'me> {
13-
knn_results: &'me HashMap<KnnQuery, Vec<RecordMeasure>>,
14-
}
15-
1612
// NOTE: `RankDomain` represents evaluated scores for records
1713
// - `support`: scores of specific records
1814
// - `default`: scores of records not specified in `support`
@@ -101,8 +97,15 @@ impl RankDomain {
10197
}
10298
}
10399

104-
impl RankProvider<'_> {
105-
fn eval(&self, rank: Rank) -> RankDomain {
100+
struct RankProvider<R> {
101+
knn_result_iter: R,
102+
}
103+
104+
impl<R> RankProvider<R>
105+
where
106+
R: Iterator<Item = Vec<RecordMeasure>>,
107+
{
108+
fn eval(&mut self, rank: Rank) -> RankDomain {
106109
match rank {
107110
Rank::Absolute(rank) => self.eval(*rank).map(f32::abs),
108111
Rank::Division { left, right } => {
@@ -129,30 +132,22 @@ impl RankProvider<'_> {
129132
RankDomain::merge(accumulate_domain, domain, f32::mul)
130133
}),
131134
Rank::Knn {
132-
embedding,
133-
key,
134-
limit,
135+
embedding: _,
136+
key: _,
137+
limit: _,
135138
default,
136139
ordinal,
137140
} => {
138-
let knn_query = KnnQuery {
139-
embedding,
140-
key,
141-
limit,
142-
};
143141
let support = self
144-
.knn_results
145-
.get(&knn_query)
146-
.map(|records| {
147-
records
148-
.iter()
149-
.enumerate()
150-
.map(|(index, &RecordMeasure { offset_id, measure })| {
151-
(offset_id, if ordinal { index as f32 } else { measure })
152-
})
153-
.collect()
142+
.knn_result_iter
143+
.next()
144+
.unwrap_or_default()
145+
.into_iter()
146+
.enumerate()
147+
.map(|(index, RecordMeasure { offset_id, measure })| {
148+
(offset_id, if ordinal { index as f32 } else { measure })
154149
})
155-
.unwrap_or_default();
150+
.collect();
156151
RankDomain { support, default }
157152
}
158153
Rank::Subtraction { left, right } => {
@@ -169,9 +164,10 @@ impl RankProvider<'_> {
169164
}
170165
}
171166

167+
// NOTE: We assume that the provided vector of knn results are in the DFS order of Rank expression.
172168
#[derive(Clone, Debug)]
173169
pub struct RankInput {
174-
pub knn_results: HashMap<KnnQuery, Vec<RecordMeasure>>,
170+
pub knn_results: Vec<Vec<RecordMeasure>>,
175171
}
176172

177173
#[derive(Clone, Debug)]
@@ -194,8 +190,9 @@ impl Operator<RankInput, RankOutput> for Rank {
194190
type Error = RankError;
195191

196192
async fn run(&self, input: &RankInput) -> Result<RankOutput, RankError> {
197-
let rank_provider = RankProvider {
198-
knn_results: &input.knn_results,
193+
let knn_results = input.knn_results.clone();
194+
let mut rank_provider = RankProvider {
195+
knn_result_iter: knn_results.into_iter(),
199196
};
200197
let rank_domain = rank_provider.eval(self.clone());
201198
let mut ranks = rank_domain
@@ -211,34 +208,31 @@ impl Operator<RankInput, RankOutput> for Rank {
211208

212209
#[cfg(test)]
213210
mod tests {
211+
use chroma_types::operator::KnnQuery;
212+
214213
use super::*;
215214

216215
#[tokio::test]
217216
async fn test_rank_with_knn_results() {
218-
// Setup KNN results
219-
let mut knn_results = HashMap::new();
220217
let query = KnnQuery {
221218
embedding: chroma_types::operator::QueryVector::Dense(vec![0.1, 0.2, 0.3]),
222219
key: String::new(),
223220
limit: 3,
224221
};
225-
knn_results.insert(
226-
query.clone(),
227-
vec![
228-
RecordMeasure {
229-
offset_id: 1,
230-
measure: 0.9,
231-
},
232-
RecordMeasure {
233-
offset_id: 2,
234-
measure: 0.7,
235-
},
236-
RecordMeasure {
237-
offset_id: 3,
238-
measure: 0.5,
239-
},
240-
],
241-
);
222+
let knn_results = vec![vec![
223+
RecordMeasure {
224+
offset_id: 1,
225+
measure: 0.9,
226+
},
227+
RecordMeasure {
228+
offset_id: 2,
229+
measure: 0.7,
230+
},
231+
RecordMeasure {
232+
offset_id: 3,
233+
measure: 0.5,
234+
},
235+
]];
242236

243237
// Test simple KNN rank
244238
let rank = Rank::Knn {
@@ -258,8 +252,6 @@ mod tests {
258252

259253
#[tokio::test]
260254
async fn test_rank_arithmetic_operations() {
261-
// Setup two KNN queries
262-
let mut knn_results = HashMap::new();
263255
let query1 = KnnQuery {
264256
embedding: chroma_types::operator::QueryVector::Dense(vec![0.1]),
265257
key: String::new(),
@@ -273,9 +265,7 @@ mod tests {
273265
key: "sparse".to_string(),
274266
limit: 2,
275267
};
276-
277-
knn_results.insert(
278-
query1.clone(),
268+
let mut knn_results = vec![
279269
vec![
280270
RecordMeasure {
281271
offset_id: 1,
@@ -286,9 +276,6 @@ mod tests {
286276
measure: 0.6,
287277
},
288278
],
289-
);
290-
knn_results.insert(
291-
query2.clone(),
292279
vec![
293280
RecordMeasure {
294281
offset_id: 1,
@@ -299,7 +286,7 @@ mod tests {
299286
measure: 0.2,
300287
},
301288
],
302-
);
289+
];
303290

304291
// Test summation
305292
let rank = Rank::Summation(vec![
@@ -328,6 +315,7 @@ mod tests {
328315
assert_eq!(output.ranks[0].measure, 1.2);
329316

330317
// Test multiplication with constant
318+
knn_results.pop();
331319
let rank = Rank::Multiplication(vec![
332320
Rank::Knn {
333321
embedding: query1.embedding.clone(),
@@ -347,26 +335,21 @@ mod tests {
347335

348336
#[tokio::test]
349337
async fn test_rank_min_max_functions() {
350-
let mut knn_results = HashMap::new();
351338
let query = KnnQuery {
352339
embedding: chroma_types::operator::QueryVector::Dense(vec![0.1]),
353340
key: String::new(),
354341
limit: 2,
355342
};
356-
357-
knn_results.insert(
358-
query.clone(),
359-
vec![
360-
RecordMeasure {
361-
offset_id: 1,
362-
measure: 0.8,
363-
},
364-
RecordMeasure {
365-
offset_id: 2,
366-
measure: 0.3,
367-
},
368-
],
369-
);
343+
let knn_results = vec![vec![
344+
RecordMeasure {
345+
offset_id: 1,
346+
measure: 0.8,
347+
},
348+
RecordMeasure {
349+
offset_id: 2,
350+
measure: 0.3,
351+
},
352+
]];
370353

371354
// Test max
372355
let rank = Rank::Maximum(vec![

rust/worker/src/execution/orchestration/rank.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@ use chroma_system::{
66
OrchestratorContext, PanicError, TaskError, TaskMessage, TaskResult,
77
};
88
use chroma_types::{
9-
operator::{KnnQuery, Limit, Rank, RecordMeasure, SearchPayloadResult, Select},
9+
operator::{Limit, Rank, RecordMeasure, SearchPayloadResult, Select},
1010
CollectionAndSegments,
1111
};
12-
use std::collections::HashMap;
1312
use thiserror::Error;
1413
use tokio::sync::oneshot::{error::RecvError, Sender};
1514
use tracing::Span;
@@ -101,7 +100,7 @@ pub struct RankOrchestrator {
101100
queue: usize,
102101

103102
// Input data
104-
knn_results: HashMap<KnnQuery, Vec<RecordMeasure>>,
103+
knn_results: Vec<Vec<RecordMeasure>>,
105104
rank: Rank,
106105
limit: Limit,
107106
select: Select,
@@ -120,7 +119,7 @@ impl RankOrchestrator {
120119
blockfile_provider: BlockfileProvider,
121120
dispatcher: ComponentHandle<Dispatcher>,
122121
queue: usize,
123-
knn_results: HashMap<KnnQuery, Vec<RecordMeasure>>,
122+
knn_results: Vec<Vec<RecordMeasure>>,
124123
rank: Rank,
125124
limit: Limit,
126125
select: Select,

0 commit comments

Comments
 (0)