|
1 | | -use std::cmp::Reverse; |
2 | 1 | use std::collections::BinaryHeap; |
3 | 2 | use std::iter::repeat; |
4 | 3 | use std::marker; |
@@ -388,19 +387,16 @@ impl<'t, D: Distance> Reader<'t, D> { |
388 | 387 | unreachable!() |
389 | 388 | }; |
390 | 389 | let distance = D::built_distance(query_leaf, &leaf); |
391 | | - nns_distances.push(Reverse((OrderedFloat(distance), nn))); |
| 390 | + nns_distances.push((OrderedFloat(distance), nn)); |
392 | 391 | } |
393 | 392 |
|
394 | | - let mut sorted_nns = BinaryHeap::from(nns_distances); |
395 | | - let capacity = opt.count.min(sorted_nns.len()); |
396 | | - let mut output = Vec::with_capacity(capacity); |
397 | | - while let Some(Reverse((OrderedFloat(dist), item))) = sorted_nns.pop() { |
398 | | - if output.len() == capacity { |
399 | | - break; |
400 | | - } |
| 393 | + // Get k nearest neighbors |
| 394 | + let k = opt.count.min(nns_distances.len()); |
| 395 | + let top_k = median_based_top_k(nns_distances, k); |
| 396 | + let mut output = Vec::with_capacity(top_k.len()); |
| 397 | + for (OrderedFloat(dist), item) in top_k { |
401 | 398 | output.push((item, D::normalized_distance(dist, self.dimensions))); |
402 | 399 | } |
403 | | - |
404 | 400 | Ok(output) |
405 | 401 | } |
406 | 402 |
|
@@ -606,3 +602,39 @@ pub fn item_leaf<'a, D: Distance>( |
606 | 602 | None => Ok(None), |
607 | 603 | } |
608 | 604 | } |
| 605 | + |
| 606 | +// Based on https://quickwit.io/blog/top-k-complexity, implemented in https://github.com/meilisearch/arroy/pull/129 |
| 607 | +pub fn median_based_top_k( |
| 608 | + v: Vec<(OrderedFloat<f32>, u32)>, |
| 609 | + k: usize, |
| 610 | +) -> Vec<(OrderedFloat<f32>, u32)> { |
| 611 | + let mut threshold = (OrderedFloat(f32::MAX), u32::MAX); |
| 612 | + let mut buffer = Vec::with_capacity(2 * k.max(1)); |
| 613 | + |
| 614 | + // prefill with no threshold checks |
| 615 | + let mut v = v.into_iter(); |
| 616 | + buffer.extend((&mut v).take(2 * k)); |
| 617 | + |
| 618 | + for item in v { |
| 619 | + if item >= threshold { |
| 620 | + continue; |
| 621 | + } |
| 622 | + if buffer.len() == 2 * k { |
| 623 | + let (_, &mut median, _) = buffer.select_nth_unstable(k - 1); |
| 624 | + threshold = median; |
| 625 | + buffer.truncate(k); |
| 626 | + } |
| 627 | + |
| 628 | + // avoids buffer resizing from being inlined from vec.push() |
| 629 | + let uninit = buffer.spare_capacity_mut(); |
| 630 | + uninit[0].write(item); |
| 631 | + // SAFETY: we would have panicked above already |
| 632 | + unsafe { |
| 633 | + buffer.set_len(buffer.len() + 1); |
| 634 | + } |
| 635 | + } |
| 636 | + |
| 637 | + buffer.sort_unstable(); |
| 638 | + buffer.truncate(k); |
| 639 | + buffer |
| 640 | +} |
0 commit comments