diff --git a/README.md b/README.md index 63cf212..fc88512 100644 --- a/README.md +++ b/README.md @@ -31,38 +31,32 @@ use heed::EnvOpenOptions; use rand::{rngs::StdRng, SeedableRng}; fn main() -> Result<()> { - const DIM: usize = 3; - let vecs: Vec<[f32; DIM]> = vec![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]; - let env = unsafe { EnvOpenOptions::new() - .map_size(1024 * 1024 * 1024 * 1) // 1GiB + .map_size(1024 * 1024 * 1024) // 1GiB .open("./") } .unwrap(); - let mut wtxn = env.write_txn().unwrap(); + let mut wtxn = env.write_txn()?; let db: Database = env.create_database(&mut wtxn, None)?; - let writer: Writer = Writer::new(db, 0, DIM); + let writer: Writer = Writer::new(db, 0, 3); - // insert into lmdb - writer.add_item(&mut wtxn, 0, &vecs[0])?; - writer.add_item(&mut wtxn, 1, &vecs[1])?; - writer.add_item(&mut wtxn, 2, &vecs[2])?; + // build + writer.add_item(&mut wtxn, 0, &[1.0, 0.0, 0.0])?; + writer.add_item(&mut wtxn, 0, &[0.0, 1.0, 0.0])?; - // ...and build hnsw let mut rng = StdRng::seed_from_u64(42); - let mut builder = writer.builder(&mut rng); builder.ef_construction(100).build::<16,32>(&mut wtxn)?; wtxn.commit()?; - // search hnsw using a new lmdb read transaction + // search let rtxn = env.read_txn()?; let reader = Reader::::open(&rtxn, 0, db)?; let query = vec![0.0, 1.0, 0.0]; - let nns = reader.nns(1).ef_search(10).by_vector(&rtxn, &query)?; + let nns = reader.nns(1).ef_search(10).by_vector(&rtxn, &query)?.into_nns(); dbg!("{:?}", &nns); Ok(()) @@ -81,7 +75,6 @@ db = hannoy.Database(tmp_dir, Metric.COSINE) with db.writer(3, m=4, ef=10) as writer: writer.add_item(0, [1.0, 0.0, 0.0]) writer.add_item(1, [0.0, 1.0, 0.0]) - writer.add_item(2, [0.0, 0.0, 1.0]) reader = db.reader() nns = reader.by_vec([0.0, 1.0, 0.0], n=2) diff --git a/src/lib.rs b/src/lib.rs index 311e6fe..6c1f1c2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -101,7 +101,7 @@ use key::{Key, Prefix, PrefixCodec}; use metadata::{Metadata, MetadataCodec}; use node::{Node, NodeCodec}; use node_id::{NodeId, NodeMode}; -pub use reader::{QueryBuilder, Reader}; +pub use reader::{QueryBuilder, Reader, Searched}; pub use roaring::RoaringBitmapCodec; pub use writer::{HannoyBuilder, Writer}; diff --git a/src/python.rs b/src/python.rs index a23e2b0..03cef6e 100644 --- a/src/python.rs +++ b/src/python.rs @@ -385,7 +385,7 @@ impl PyReader { }; } - let neighbours = match &self.dyn_reader { + let found = match &self.dyn_reader { DynReader::Cosine(reader) => hnsw_search!(reader, &query)?, DynReader::Euclidean(reader) => hnsw_search!(reader, &query)?, DynReader::Manhattan(reader) => hnsw_search!(reader, &query)?, @@ -394,7 +394,7 @@ impl PyReader { DynReader::BqManhattan(reader) => hnsw_search!(reader, &query)?, DynReader::Hamming(reader) => hnsw_search!(reader, &query)?, }; - Ok(neighbours) + Ok(found.into_nns()) } } diff --git a/src/reader.rs b/src/reader.rs index 7505466..969b295 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -33,6 +33,31 @@ const LINEAR_SEARCH_THRESHOLD: u64 = 1000; /// to zero to make sure we test the HNSW algorithm. const LINEAR_SEARCH_THRESHOLD: u64 = 0; +/// Container storing nearest neighbour search result +#[derive(Debug)] +pub struct Searched { + /// The nearest neighbours for the performed query + pub nns: Vec<(ItemId, f32)>, + /// A bool indicating whether or not the search terminated early + pub did_cancel: bool, +} + +impl Searched { + pub(crate) fn new(nns: Vec<(ItemId, f32)>, did_cancel: bool) -> Self { + Searched { nns, did_cancel } + } + + /// Indicates if the search terminated early + pub fn did_cancel(&self) -> bool { + self.did_cancel + } + + /// Consumes `self` and returns vector of nearest neighbours + pub fn into_nns(self) -> Vec<(ItemId, f32)> { + self.nns + } +} + /// Options used to make a query against an hannoy [`Reader`]. pub struct QueryBuilder<'a, D: Distance> { reader: &'a Reader, @@ -53,8 +78,44 @@ impl<'a, D: Distance> QueryBuilder<'a, D> { /// # let (reader, rtxn): (Reader, heed::RoTxn) = todo!(); /// reader.nns(20).by_item(&rtxn, 5); /// ``` - pub fn by_item(&self, rtxn: &RoTxn, item: ItemId) -> Result>> { - self.reader.nns_by_item(rtxn, item, self) + pub fn by_item(&self, rtxn: &RoTxn, item: ItemId) -> Result> { + self.reader.nns_by_item(rtxn, item, self, || false).map(|res| match res { + Some(Completion::Done(items)) => Some(Searched::new(items, false)), + Some(Completion::Cancelled(_)) => { + unreachable!("cancellation only possible using by_item_with_cancellation") + } + None => None, + }) + } + + /// Returns as many nearest neighbours to the query as possible before `cancel_fn` evaluates to + /// true, and indicates whether or not search terminated early. + /// + /// See also [`Self::by_vector_with_cancellation`]. + /// + /// # Examples + /// + /// ```no_run + /// # use hannoy::{Reader, distances::Euclidean, Searched}; + /// # let (reader, rtxn): (Reader, heed::RoTxn) = todo!(); + /// use std::time::{Instant, Duration}; + /// + /// let later = Instant::now().checked_add(Duration::from_secs(1)).unwrap(); + /// let cancel_fn = || Instant::now() > later; + /// let Searched{ nns, did_cancel } = reader.nns(20).by_item_with_cancellation(&rtxn, 5, cancel_fn)?.unwrap(); + /// # Ok::<(), hannoy::Error>(()) + /// ``` + pub fn by_item_with_cancellation( + &self, + rtxn: &RoTxn, + item: ItemId, + cancel_fn: impl Fn() -> bool, + ) -> Result> { + self.reader.nns_by_item(rtxn, item, self, cancel_fn).map(|res| match res { + Some(Completion::Done(done)) => Some(Searched::new(done, false)), + Some(Completion::Cancelled(cancelled)) => Some(Searched::new(cancelled, true)), + None => None, + }) } /// Returns the closest items from the provided `vector`. @@ -68,7 +129,7 @@ impl<'a, D: Distance> QueryBuilder<'a, D> { /// # let (reader, rtxn): (Reader, heed::RoTxn) = todo!(); /// reader.nns(20).by_vector(&rtxn, &[1.25854, -0.75598, 0.58524]); /// ``` - pub fn by_vector(&self, rtxn: &RoTxn, vector: &'a [f32]) -> Result> { + pub fn by_vector(&self, rtxn: &RoTxn, vector: &'a [f32]) -> Result { if vector.len() != self.reader.dimensions() { return Err(Error::InvalidVecDimension { expected: self.reader.dimensions(), @@ -78,7 +139,52 @@ impl<'a, D: Distance> QueryBuilder<'a, D> { let vector = UnalignedVector::from_slice(vector); let item = Item { header: D::new_header(&vector), vector }; - self.reader.nns_by_vec(rtxn, &item, self) + + let cancel_fn = || false; + let neighbours = + self.reader.nns_by_vec(rtxn, &item, self, cancel_fn).map(|res| res.into_inner())?; + + Ok(Searched::new(neighbours, false)) + } + + /// Returns as many nearest neighbours to the query as possible before `cancel_fn` evaluates to + /// true, and indicates whether or not search terminated early. + /// + /// See also [`Self::by_item_with_cancellation`]. + /// + /// # Examples + /// + /// ```no_run + /// # use hannoy::{Reader, distances::Euclidean, Searched}; + /// # let (reader, rtxn): (Reader, heed::RoTxn) = todo!(); + /// use std::time::{Instant, Duration}; + /// + /// let later = Instant::now().checked_add(Duration::from_secs(1)).unwrap(); + /// let cancel_fn = || Instant::now() > later; + /// let Searched{ nns, did_cancel } = reader.nns(20).by_vector_with_cancellation(&rtxn, &[1.25854, -0.75598, 0.58524], cancel_fn)?; + /// # Ok::<(), hannoy::Error>(()) + /// ``` + pub fn by_vector_with_cancellation( + &self, + rtxn: &RoTxn, + vector: &'a [f32], + cancel_fn: impl Fn() -> bool, + ) -> Result { + if vector.len() != self.reader.dimensions() { + return Err(Error::InvalidVecDimension { + expected: self.reader.dimensions(), + received: vector.len(), + }); + } + + let vector = UnalignedVector::from_slice(vector); + let item = Item { header: D::new_header(&vector), vector }; + + let nns = self.reader.nns_by_vec(rtxn, &item, self, cancel_fn)?; + match nns { + Completion::Done(done) => Ok(Searched::new(done, false)), + Completion::Cancelled(cancelled) => Ok(Searched::new(cancelled, true)), + } } /// Specify a subset of candidates to inspect. Filters out everything else. @@ -114,11 +220,120 @@ impl<'a, D: Distance> QueryBuilder<'a, D> { } } +enum Completion { + Done(T), + Cancelled(T), +} +impl Completion { + pub fn into_inner(self) -> T { + match self { + Completion::Done(inner) => inner, + Completion::Cancelled(inner) => inner, + } + } + pub fn map(self, op: impl FnOnce(T) -> U) -> Completion { + match self { + Self::Done(inner) => Completion::Done(op(inner)), + Self::Cancelled(inner) => Completion::Cancelled(op(inner)), + } + } +} + +struct Visitor<'a> { + pub eps: Vec, + pub level: usize, + pub ef: usize, + pub candidates: Option<&'a RoaringBitmap>, +} +impl<'a> Visitor<'a> { + pub fn new( + eps: Vec, + level: usize, + ef: usize, + candidates: Option<&'a RoaringBitmap>, + ) -> Self { + Self { eps, level, ef, candidates } + } + + /// Iteratively traverse a given level of the HNSW graph, updating the search path history. + /// Returns a Min-Max heap of size ef nearest neighbours to the query in that layer. + #[allow(clippy::too_many_arguments)] + pub fn visit( + &self, + query: &Item, + reader: &Reader, + rtxn: &RoTxn, + path: &mut RoaringBitmap, + cancel_fn: &impl Fn() -> bool, + ) -> Result>> { + use Completion::*; + + let mut search_queue = BinaryHeap::new(); + let mut res = MinMaxHeap::with_capacity(self.ef); + + // Register all entry points as visited and populate candidates + for &ep in &self.eps[..] { + let ve = get_item(reader.database, reader.index, rtxn, ep)?.unwrap(); + let dist = D::distance(query, &ve); + + search_queue.push((Reverse(OrderedFloat(dist)), ep)); + path.insert(ep); + + if self.candidates.is_none_or(|c| c.contains(ep)) { + res.push((OrderedFloat(dist), ep)); + } + } + + // Stop occurs either once we've done at least ef searches and notice no improvements, or + // when we've exhausted the search queue. + while let Some(&(Reverse(OrderedFloat(f)), _)) = search_queue.peek() { + if cancel_fn() { + return Ok(Cancelled(res)); + } + let f_max = res.peek_max().map(|&(OrderedFloat(d), _)| d).unwrap_or(f32::MAX); + if f > f_max { + break; + } + let (_, c) = search_queue.pop().unwrap(); + + let Links { links } = get_links(rtxn, reader.database, reader.index, c, self.level)? + .expect("Links must exist"); + + for point in links.iter() { + if !path.insert(point) { + continue; + } + let dist = D::distance( + query, + &get_item(reader.database, reader.index, rtxn, point)?.unwrap(), + ); + + // The search queue can take points that aren't included in the (optional) + // candidates bitmap, but the final result must *not* include them. + if res.len() < self.ef || dist < f_max { + search_queue.push((Reverse(OrderedFloat(dist)), point)); + if let Some(c) = self.candidates { + if !c.contains(point) { + continue; + } + } + if res.len() == self.ef { + let _ = res.push_pop_max((OrderedFloat(dist), point)); + } else { + res.push((OrderedFloat(dist), point)); + } + } + } + } + Ok(Done(res)) + } +} + /// A reader over the hannoy hnsw graph #[derive(Debug)] pub struct Reader { - database: Database, - index: u16, + pub(crate) database: Database, + pub(crate) index: u16, entry_points: Vec, max_level: usize, dimensions: usize, @@ -356,93 +571,27 @@ impl Reader { QueryBuilder { reader: self, candidates: None, count, ef: DEFAULT_EF_SEARCH } } - /// Iteratively traverse a given level of the HNSW graph, updating the search path history. - /// Returns a Min-Max heap of size ef nearest neighbours to the query in that layer. - #[allow(clippy::too_many_arguments)] - fn walk_layer( - &self, - query: &Item, - eps: &[ItemId], - level: usize, - ef: usize, - candidates: Option<&RoaringBitmap>, - path: &mut RoaringBitmap, - rtxn: &RoTxn, - ) -> Result> { - let mut search_queue = BinaryHeap::new(); - let mut res = MinMaxHeap::with_capacity(ef); - - // Register all entry points as visited and populate candidates - for &ep in eps { - let ve = get_item(self.database, self.index, rtxn, ep)?.unwrap(); - let dist = D::distance(query, &ve); - - search_queue.push((Reverse(OrderedFloat(dist)), ep)); - path.insert(ep); - - if candidates.is_none_or(|c| c.contains(ep)) { - res.push((OrderedFloat(dist), ep)); - } - } - - // Stop occurs either once we've done at least ef searches and notice no improvements, or - // when we've exhausted the search queue. - while let Some(&(Reverse(OrderedFloat(f)), _)) = search_queue.peek() { - let f_max = res.peek_max().map(|&(OrderedFloat(d), _)| d).unwrap_or(f32::MAX); - if f > f_max { - break; - } - let (_, c) = search_queue.pop().unwrap(); - - let Links { links } = - get_links(rtxn, self.database, self.index, c, level)?.expect("Links must exist"); - - for point in links.iter() { - if !path.insert(point) { - continue; - } - let dist = - D::distance(query, &get_item(self.database, self.index, rtxn, point)?.unwrap()); - - // The search queue can take points that aren't included in the (optional) - // candidates bitmap, but the final result must *not* include them. - if res.len() < ef || dist < f_max { - search_queue.push((Reverse(OrderedFloat(dist)), point)); - if let Some(c) = candidates { - if !c.contains(point) { - continue; - } - } - if res.len() == ef { - let _ = res.push_pop_max((OrderedFloat(dist), point)); - } else { - res.push((OrderedFloat(dist), point)); - } - } - } - } - Ok(res) - } - fn nns_by_vec( &self, rtxn: &RoTxn, query: &Item, opt: &QueryBuilder, - ) -> Result> { + cancel_fn: impl Fn() -> bool, + ) -> Result>> { + use Completion::*; + // If we will never find any candidates, return an empty vector if opt.candidates.is_some_and(|c| self.item_ids().is_disjoint(c)) { - return Ok(Vec::new()); + return Ok(Done(Vec::new())); } // If the number of candidates is less than a given threshold, perform linear search if let Some(candidates) = opt.candidates.filter(|c| c.len() < LINEAR_SEARCH_THRESHOLD) { - let mut nns = self.brute_force_search(query, rtxn, candidates)?; - nns.truncate(opt.count); - return Ok(nns); + return self.brute_force_search(query, rtxn, candidates, opt.count, cancel_fn); } - self.hnsw_search(query, rtxn, opt) + // exhaustive search + self.hnsw_search(query, rtxn, opt, cancel_fn) } /// Directly retrieves items in the candidate list and ranks them by distance to the query. @@ -451,10 +600,18 @@ impl Reader { query: &Item, rtxn: &RoTxn, candidates: &RoaringBitmap, - ) -> Result> { + count: usize, + cancel_fn: impl Fn() -> bool, + ) -> Result>> { + use Completion::*; + let mut item_distances = Vec::with_capacity(candidates.len() as usize); for item_id in candidates { + if cancel_fn() { + return Ok(Cancelled(item_distances)); + } + let Some(vector) = self.item_vector(rtxn, item_id)? else { continue }; let vector = UnalignedVector::from_vec(vector); let item = Item { header: D::new_header(&vector), vector }; @@ -462,7 +619,9 @@ impl Reader { item_distances.push((item_id, distance)); } item_distances.sort_by_key(|(_, dist)| OrderedFloat(*dist)); - Ok(item_distances) + item_distances.truncate(count); + + Ok(Done(item_distances)) } /// Hnsw search according to arXiv:1603.09320. @@ -471,25 +630,55 @@ impl Reader { /// is controlled by `opt.ef`. Since the graph is not necessarily acyclic, search may become /// "trapped" in a local sub-graph with fewer elements than `opt.count` - to account for this /// we run an expensive exhaustive search at the end if fewer nns were returned. + /// + /// To break out of search early, users may wish to provide a `cancel_fn` which terminates the + /// execution of the hnsw search and returns partial results so far. fn hnsw_search( &self, query: &Item, rtxn: &RoTxn, opt: &QueryBuilder, - ) -> Result> { - let mut eps = self.entry_points.clone(); - let mut seen = RoaringBitmap::new(); + cancel_fn: impl Fn() -> bool, + ) -> Result>> { + use Completion::*; + + let cancel_fn = &cancel_fn; + let mut visitor = Visitor::new(self.entry_points.clone(), self.max_level, 1, None); - for lvl in (1..=self.max_level).rev() { - let neighbours = self.walk_layer(query, &eps, lvl, 1, None, &mut seen, rtxn)?; + let mut path = RoaringBitmap::new(); + for _ in (1..=self.max_level).rev() { + let neighbours = visitor.visit(query, self, rtxn, &mut path, &|| false)?.into_inner(); let closest = neighbours.peek_min().map(|(_, n)| n).expect("No neighbor was found"); - eps = vec![*closest]; + + visitor.eps = vec![*closest]; + visitor.level -= 1; } // clear visited set as we only care about level 0 - seen.clear(); - let ef = opt.ef.max(opt.count); + path.clear(); + debug_assert!(visitor.level == 0); + + visitor.ef = opt.ef.max(opt.count); + visitor.candidates = opt.candidates; + + macro_rules! return_if_cancelled { + ($completion: expr) => { + match $completion { + Completion::Done(done) => done, + cancelled => { + return Ok(cancelled.map(|mut found| { + found + .drain_asc() + .map(|(OrderedFloat(f), i)| (i, f)) + .take(opt.count) + .collect() + })) + } + } + }; + } + let mut neighbours = - self.walk_layer(query, &eps, 0, ef, opt.candidates, &mut seen, rtxn)?; + return_if_cancelled!(visitor.visit(query, self, rtxn, &mut path, cancel_fn)?); // If we still don't have enough nns (e.g. search encountered cyclic subgraphs) then do exhaustive // search over remaining unseen items. @@ -502,19 +691,16 @@ impl Reader { while let Some((key, _)) = cursor.next().transpose()? { let id = key.node.item; - if seen.contains(id) { + if path.contains(id) { continue; } - let more_nns = self.walk_layer( - query, - &[id], - 0, - opt.count - neighbours.len(), - opt.candidates, - &mut seen, - rtxn, - )?; + visitor.eps = vec![id]; + visitor.ef = opt.count - neighbours.len(); + + let more_nns = + return_if_cancelled!(visitor.visit(query, self, rtxn, &mut path, cancel_fn)?); + neighbours.extend(more_nns.into_iter()); if neighbours.len() >= opt.count { break; @@ -522,7 +708,9 @@ impl Reader { } } - Ok(neighbours.drain_asc().map(|(OrderedFloat(f), i)| (i, f)).take(opt.count).collect()) + let found = + neighbours.drain_asc().map(|(OrderedFloat(f), i)| (i, f)).take(opt.count).collect(); + Ok(Done(found)) } /// Returns the nearest points to the item id, not including the point itself. @@ -531,12 +719,17 @@ impl Reader { /// `&[item]` instead of the hnsw entrypoints. Since search starts in the true neighbourhood of /// the item fewer comparisons are needed to retrieve the nearest neighbours, making it more /// efficient than simply calling `Reader.nns_by_vec` with the associated vector. + #[allow(clippy::type_complexity)] fn nns_by_item( &self, rtxn: &RoTxn, item: ItemId, opt: &QueryBuilder, - ) -> Result>> { + cancel_fn: impl Fn() -> bool, + ) -> Result>>> { + use Completion::*; + let cancel_fn = &cancel_fn; + // If we will never find any candidates, return none if opt.candidates.is_some_and(|c| self.item_ids().is_disjoint(c)) { return Ok(None); @@ -548,19 +741,36 @@ impl Reader { // If the number of candidates is less than a given threshold, perform linear search if let Some(candidates) = opt.candidates.filter(|c| c.len() < LINEAR_SEARCH_THRESHOLD) { - let mut nns = self.brute_force_search(&query, rtxn, candidates)?; - nns.truncate(opt.count); + let nns = self.brute_force_search(&query, rtxn, candidates, opt.count, cancel_fn)?; return Ok(Some(nns)); } // Search over all items except `item` let ef = opt.ef.max(opt.count); - let mut seen = RoaringBitmap::new(); + let mut path = RoaringBitmap::new(); let mut candidates = opt.candidates.unwrap_or_else(|| self.item_ids()).clone(); candidates.remove(item); + let mut visitor = Visitor::new(vec![item], 0, ef, Some(&candidates)); + + macro_rules! return_if_cancelled { + ($completion: expr) => { + match $completion { + Completion::Done(done) => done, + cancelled => { + return Ok(Some(cancelled.map(|mut found| { + found + .drain_asc() + .map(|(OrderedFloat(f), i)| (i, f)) + .take(opt.count) + .collect() + }))) + } + } + }; + } let mut neighbours = - self.walk_layer(&query, &[item], 0, ef, Some(&candidates), &mut seen, rtxn)?; + return_if_cancelled!(visitor.visit(&query, self, rtxn, &mut path, cancel_fn)?); // If we still don't have enough nns (e.g. search encountered cyclic subgraphs) then do exhaustive // search over remaining unseen items. @@ -573,19 +783,16 @@ impl Reader { while let Some((key, _)) = cursor.next().transpose()? { let id = key.node.item; - if seen.contains(id) { + if path.contains(id) { continue; } - let more_nns = self.walk_layer( - &query, - &[id], - 0, - opt.count - neighbours.len(), - opt.candidates, - &mut seen, - rtxn, - )?; + // update walker + visitor.eps = vec![id]; + visitor.ef = opt.count - neighbours.len(); + + let more_nns = + return_if_cancelled!(visitor.visit(&query, self, rtxn, &mut path, cancel_fn)?); neighbours.extend(more_nns.into_iter()); if neighbours.len() >= opt.count { break; @@ -593,9 +800,9 @@ impl Reader { } } - let found = + let found: Vec<_> = neighbours.drain_asc().map(|(OrderedFloat(f), i)| (i, f)).take(opt.count).collect(); - Ok(Some(found)) + Ok(Some(Done(found))) } /// NOTE: a [`crate::Reader`] can't be opened unless updates are commited through a build ! diff --git a/src/tests/reader.rs b/src/tests/reader.rs index b97ea29..fc81158 100644 --- a/src/tests/reader.rs +++ b/src/tests/reader.rs @@ -63,13 +63,15 @@ fn search_on_candidates_has_right_num() { let c: [u32; 10] = std::array::from_fn(|_| thread_rng().gen::() % 1000); let candidates = RoaringBitmap::from_iter(c); - let found = reader.nns(10).candidates(&candidates).by_vector(&rtxn, &query).unwrap(); + let _found = reader.nns(10).candidates(&candidates).by_vector(&rtxn, &query).unwrap(); + let found = _found.into_nns(); assert_eq!(&RoaringBitmap::from_iter(found.into_iter().map(|(i, _)| i)), &candidates); // search with 1 candidate let c: [u32; 1] = std::array::from_fn(|_| thread_rng().gen::() % 1000); let candidates = RoaringBitmap::from_iter(c); - let found = reader.nns(1).candidates(&candidates).by_vector(&rtxn, &query).unwrap(); + let _found = reader.nns(1).candidates(&candidates).by_vector(&rtxn, &query).unwrap(); + let found = _found.into_nns(); assert_eq!(&RoaringBitmap::from_iter(found.into_iter().map(|(i, _)| i)), &candidates); } } @@ -87,7 +89,8 @@ fn all_items_are_reachable(n: usize) { assert_eq!(reader.item_ids().len(), n as u64); assert!((0..n as u32).all(|i| reader.contains_item(&rtxn, i).unwrap())); - let found = reader.nns(n).ef_search(n).by_vector(&rtxn, &[0.0; DIM]).unwrap(); + let _found = reader.nns(n).ef_search(n).by_vector(&rtxn, &[0.0; DIM]).unwrap(); + let found = _found.into_nns(); assert_eq!(&RoaringBitmap::from_iter(found.into_iter().map(|(id, _)| id)), reader.item_ids()) } @@ -113,7 +116,7 @@ fn search_by_item_does_not_contain_item() { let reader = crate::Reader::::open(&rtxn, 0, database).unwrap(); - let found = reader.nns(10).by_item(&rtxn, 0).unwrap().unwrap(); + let found = reader.nns(10).by_item(&rtxn, 0).unwrap().unwrap().into_nns(); assert!(found.len() == 10); assert!(!found.contains(&(0, 0.0))) } @@ -133,3 +136,30 @@ fn search_by_item_returns_none_if_not_exists() { let found = reader.nns(10).by_item(&rtxn, 101).unwrap(); assert!(found.is_none()); } + +#[test] +fn search_cancellation_works() { + const DIM: usize = 768; + let mut rng = rng(); + + let DatabaseHandle { env, database, tempdir: _ } = + create_database_indices_with_items::(0..1, 100, &mut rng); + let rtxn = env.read_txn().unwrap(); + + let reader = crate::Reader::::open(&rtxn, 0, database).unwrap(); + + // use an item id that does not exist + let query: [f32; DIM] = std::array::from_fn(|_| rng.gen()); + + // by vector + let searched = reader.nns(10).by_vector_with_cancellation(&rtxn, &query, || false).unwrap(); + assert!(!searched.did_cancel()); + let searched = reader.nns(10).by_vector_with_cancellation(&rtxn, &query, || true).unwrap(); + assert!(searched.did_cancel()); + + // by item + let searched = reader.nns(10).by_item_with_cancellation(&rtxn, 0, || false).unwrap().unwrap(); + assert!(!searched.did_cancel()); + let searched = reader.nns(10).by_item_with_cancellation(&rtxn, 0, || true).unwrap().unwrap(); + assert!(searched.did_cancel()); +} diff --git a/src/tests/writer.rs b/src/tests/writer.rs index 0be526f..3153684 100644 --- a/src/tests/writer.rs +++ b/src/tests/writer.rs @@ -287,7 +287,7 @@ fn convert_from_arroy_to_hannoy() { let reader = Reader::open(&wtxn, index, database).unwrap(); let vec = reader.item_vector(&wtxn, item_id).unwrap(); assert_eq!(vec.as_deref(), Some(&vector[..])); - let mut found = reader.nns(1).by_vector(&wtxn, &vector).unwrap(); + let mut found = reader.nns(1).by_vector(&wtxn, &vector).unwrap().into_nns(); dbg!(&found); let (found_item_id, found_distance) = found.pop().unwrap(); assert_eq!(found_item_id, item_id); @@ -363,7 +363,7 @@ fn convert_from_arroy_to_hannoy_binary_quantized() { let reader = Reader::open(&wtxn, index, database).unwrap(); let vec = reader.item_vector(&wtxn, item_id).unwrap(); assert_eq!(vec.as_deref(), Some(&vector[..])); - let mut found = reader.nns(1).by_vector(&wtxn, &vector).unwrap(); + let mut found = reader.nns(1).by_vector(&wtxn, &vector).unwrap().into_nns(); dbg!(&found); let (found_item_id, found_distance) = found.pop().unwrap(); assert_eq!(found_item_id, item_id); @@ -516,7 +516,12 @@ fn delete_document_in_an_empty_index_74() { let reader = Reader::open(&wtxn, 1, handle.database).unwrap(); let ret = reader.nns(10).by_vector(&wtxn, &[0., 0.]).unwrap(); - insta::assert_debug_snapshot!(ret, @"[]"); + insta::assert_debug_snapshot!(ret, @r" + Searched { + nns: [], + did_cancel: false, + } + "); wtxn.commit().unwrap(); @@ -534,7 +539,12 @@ fn delete_document_in_an_empty_index_74() { let rtxn = handle.env.read_txn().unwrap(); let reader = Reader::open(&rtxn, 1, handle.database).unwrap(); let ret = reader.nns(10).by_vector(&rtxn, &[0., 0.]).unwrap(); - insta::assert_debug_snapshot!(ret, @"[]"); + insta::assert_debug_snapshot!(ret, @r" + Searched { + nns: [], + did_cancel: false, + } + "); } #[test]