Skip to content

Commit 97ba58d

Browse files
author
nnethercott
committed
apply review suggestions, simplify top k fn
1 parent a1ce2f1 commit 97ba58d

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

src/reader.rs

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -392,7 +392,7 @@ impl<'t, D: Distance> Reader<'t, D> {
392392

393393
// Get k nearest neighbors
394394
let k = opt.count.min(nns_distances.len());
395-
let top_k = median_based_top_k(nns_distances, k, (OrderedFloat(f32::MAX), u32::MAX));
395+
let top_k = median_based_top_k(nns_distances, k);
396396
let mut output = Vec::with_capacity(top_k.len());
397397
for (OrderedFloat(dist), item) in top_k {
398398
output.push((item, D::normalized_distance(dist, self.dimensions)));
@@ -604,15 +604,16 @@ pub fn item_leaf<'a, D: Distance>(
604604
}
605605

606606
// Based on https://quickwit.io/blog/top-k-complexity, implemented in https://github.com/meilisearch/arroy/pull/129
607-
pub fn median_based_top_k<T>(v: Vec<T>, k: usize, mut threshold: T) -> Vec<T>
608-
where
609-
T: Ord + Copy,
610-
{
607+
pub fn median_based_top_k(
608+
v: Vec<(OrderedFloat<f32>, u32)>,
609+
k: usize,
610+
) -> Vec<(OrderedFloat<f32>, u32)> {
611+
let mut threshold = (OrderedFloat(f32::MAX), u32::MAX);
611612
let mut buffer = Vec::with_capacity(2 * k.max(1));
612613

613614
// prefill with no threshold checks
614615
let mut v = v.into_iter();
615-
buffer.extend((&mut v).take(k));
616+
buffer.extend((&mut v).take(2 * k));
616617

617618
for item in v {
618619
if item >= threshold {

src/tests/reader.rs

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -285,10 +285,11 @@ proptest! {
285285
(Just(v), k_strategy)
286286
})
287287
){
288-
let original: Vec<_> = original.into_iter().map(|item| OrderedFloat(item)).collect();
288+
let original: Vec<(OrderedFloat<f32>, u32)> =
289+
original.into_iter().enumerate().map(|(num, item)| (OrderedFloat(item), num as u32)).collect();
289290

290291
let u = binary_heap_based_top_k(original.clone(), k);
291-
let v = median_based_top_k(original, k, OrderedFloat(f32::MAX));
292+
let v = median_based_top_k(original, k);
292293

293294
assert_eq!(u, v);
294295
}

0 commit comments

Comments
 (0)