Skip to content

Commit c668bb7

Browse files
authored
Merge pull request #129 from nnethercott/replace-minheap-for-median
Make QueryBuilder.by_item() faster with a better top k algorithm
2 parents f3b594a + 97ba58d commit c668bb7

File tree

3 files changed

+82
-10
lines changed

3 files changed

+82
-10
lines changed

src/reader.rs

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use std::cmp::Reverse;
21
use std::collections::BinaryHeap;
32
use std::iter::repeat;
43
use std::marker;
@@ -388,19 +387,16 @@ impl<'t, D: Distance> Reader<'t, D> {
388387
unreachable!()
389388
};
390389
let distance = D::built_distance(query_leaf, &leaf);
391-
nns_distances.push(Reverse((OrderedFloat(distance), nn)));
390+
nns_distances.push((OrderedFloat(distance), nn));
392391
}
393392

394-
let mut sorted_nns = BinaryHeap::from(nns_distances);
395-
let capacity = opt.count.min(sorted_nns.len());
396-
let mut output = Vec::with_capacity(capacity);
397-
while let Some(Reverse((OrderedFloat(dist), item))) = sorted_nns.pop() {
398-
if output.len() == capacity {
399-
break;
400-
}
393+
// Get k nearest neighbors
394+
let k = opt.count.min(nns_distances.len());
395+
let top_k = median_based_top_k(nns_distances, k);
396+
let mut output = Vec::with_capacity(top_k.len());
397+
for (OrderedFloat(dist), item) in top_k {
401398
output.push((item, D::normalized_distance(dist, self.dimensions)));
402399
}
403-
404400
Ok(output)
405401
}
406402

@@ -606,3 +602,39 @@ pub fn item_leaf<'a, D: Distance>(
606602
None => Ok(None),
607603
}
608604
}
605+
606+
// Based on https://quickwit.io/blog/top-k-complexity, implemented in https://github.com/meilisearch/arroy/pull/129
607+
pub fn median_based_top_k(
608+
v: Vec<(OrderedFloat<f32>, u32)>,
609+
k: usize,
610+
) -> Vec<(OrderedFloat<f32>, u32)> {
611+
let mut threshold = (OrderedFloat(f32::MAX), u32::MAX);
612+
let mut buffer = Vec::with_capacity(2 * k.max(1));
613+
614+
// prefill with no threshold checks
615+
let mut v = v.into_iter();
616+
buffer.extend((&mut v).take(2 * k));
617+
618+
for item in v {
619+
if item >= threshold {
620+
continue;
621+
}
622+
if buffer.len() == 2 * k {
623+
let (_, &mut median, _) = buffer.select_nth_unstable(k - 1);
624+
threshold = median;
625+
buffer.truncate(k);
626+
}
627+
628+
// avoids buffer resizing from being inlined from vec.push()
629+
let uninit = buffer.spare_capacity_mut();
630+
uninit[0].write(item);
631+
// SAFETY: we would have panicked above already
632+
unsafe {
633+
buffer.set_len(buffer.len() + 1);
634+
}
635+
}
636+
637+
buffer.sort_unstable();
638+
buffer.truncate(k);
639+
buffer
640+
}

src/tests/mod.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::cmp::Reverse;
2+
use std::collections::BinaryHeap;
13
use std::fmt;
24

35
use heed::types::LazyDecode;
@@ -101,3 +103,19 @@ fn create_database<D: Distance>() -> DatabaseHandle<D> {
101103
fn rng() -> StdRng {
102104
StdRng::from_seed(std::array::from_fn(|_| 42))
103105
}
106+
107+
/// Reference top-k implementation used to cross-check faster algorithms.
///
/// Returns the `k` smallest elements of `v` in ascending order. When `k` is
/// larger than `v.len()`, every element is returned (sorted).
fn binary_heap_based_top_k<T: Ord>(v: Vec<T>, k: usize) -> Vec<T> {
    // Never reserve or pop more than the input can provide; an unclamped
    // `Vec::with_capacity(k)` over-allocates (or panics) for a large `k`.
    let k = k.min(v.len());

    // `BinaryHeap` is a max-heap; wrapping the items in `Reverse` makes it
    // yield the smallest item first.
    let mut heap: BinaryHeap<Reverse<T>> = v.into_iter().map(Reverse).collect();

    let mut output = Vec::with_capacity(k);
    while output.len() < k {
        match heap.pop() {
            Some(Reverse(item)) => output.push(item),
            None => break,
        }
    }
    output
}

src/tests/reader.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
use std::fmt::Display;
22
use std::num::NonZeroUsize;
33

4+
use ordered_float::OrderedFloat;
5+
use proptest::collection::vec;
6+
use proptest::prelude::*;
47
use roaring::RoaringBitmap;
58

69
use super::*;
710
use crate::distance::Cosine;
811
use crate::distances::{Euclidean, Manhattan};
12+
use crate::reader::median_based_top_k;
913
use crate::{ItemId, Reader, Writer};
1014

1115
pub struct NnsRes(pub Option<Vec<(ItemId, f32)>>);
@@ -272,3 +276,21 @@ fn try_reading_in_a_non_built_database() {
272276
)
273277
"###);
274278
}
279+
280+
proptest! {
    // Property: for any non-empty list of floats and any `k` between 1 and
    // the list length, the median-based selection returns exactly what the
    // reference binary-heap selection returns.
    #[test]
    fn median_top_k_vs_binary_heap(
        (original, k) in vec(any::<f32>(), 1..1000).prop_flat_map(|floats| {
            let len = floats.len();
            (Just(floats), 1..=len)
        })
    ) {
        // Pair every float with its position so that entries are unique.
        let candidates: Vec<(OrderedFloat<f32>, u32)> = original
            .into_iter()
            .zip(0u32..)
            .map(|(value, id)| (OrderedFloat(value), id))
            .collect();

        let expected = binary_heap_based_top_k(candidates.clone(), k);
        let actual = median_based_top_k(candidates, k);

        assert_eq!(expected, actual);
    }
}

0 commit comments

Comments
 (0)