From 61e6706448b11201eb09fff031f665a1d504bb28 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 11 Apr 2023 01:10:08 -0400 Subject: [PATCH 1/2] implements insert on Hnsw and HnswMap and test --- instant-distance/src/lib.rs | 55 +++++++++++++++++++++++ instant-distance/tests/all.rs | 84 ++++++++++++++++++++++++++++++++++- 2 files changed, 138 insertions(+), 1 deletion(-) diff --git a/instant-distance/src/lib.rs b/instant-distance/src/lib.rs index a9f06f6..5e81514 100644 --- a/instant-distance/src/lib.rs +++ b/instant-distance/src/lib.rs @@ -170,6 +170,12 @@ where pub fn get(&self, i: usize, search: &Search) -> Option> { Some(MapItem::from(self.hnsw.get(i, search)?, self)) } + + pub fn insert(&mut self, point: P, value: V) -> Result> { + let point_id = self.hnsw.insert(point, 100, Some(Heuristic::default())); + self.values.push(value); + Ok(point_id) + } } pub struct MapItem<'a, P, V> { @@ -394,6 +400,55 @@ where pub fn get(&self, i: usize, search: &Search) -> Option> { Some(Item::new(search.nearest.get(i).copied()?, self)) } + + pub fn insert( + &mut self, + point: P, + ef_construction: usize, + heuristic: Option, + ) -> PointId { + let new_pid = self.points.len(); + let new_point_id = PointId(new_pid as u32); + + self.points.push(point); + self.zero.push(ZeroNode::default()); + + let zeros = self + .zero + .iter() + .map(|z| RwLock::new(z.clone())) + .collect::>(); + + let top = if self.layers.is_empty() { + LayerId(0) + } else { + LayerId(self.layers.len()) + }; + + let construction = Construction { + zero: zeros.as_slice(), + pool: SearchPool::new(self.points.len()), + top, + points: self.points.as_slice(), + heuristic, + ef_construction, + #[cfg(feature = "indicatif")] + progress: None, + #[cfg(feature = "indicatif")] + done: AtomicUsize::new(0), + }; + + let new_layer = construction.top; + construction.insert(new_point_id, new_layer, &self.layers); + + self.zero = construction + .zero + .iter() + .map(|node| node.read().clone()) + .collect(); + + new_point_id + } } pub struct Item<'a, P> { diff --git a/instant-distance/tests/all.rs b/instant-distance/tests/all.rs index b9fa973..295e9e1 100644 --- a/instant-distance/tests/all.rs +++ b/instant-distance/tests/all.rs @@ -92,7 +92,89 @@ struct Point(f32, f32); impl instant_distance::Point for Point { fn distance(&self, other: &Self) -> f32 { - // Euclidean distance metric + // Euclidean distance metricØ ((self.0 - other.0).powi(2) + (self.1 - other.1).powi(2)).sqrt() } } + +#[test] +#[allow(clippy::float_cmp, clippy::approx_constant)] +fn incremental_insert() { + let points = (0..4) + .map(|i| Point(i as f32, i as f32)) + .collect::>(); + let values = vec!["zero", "one", "two", "three"]; + let seed = ThreadRng::default().gen::(); + let builder = Builder::default().seed(seed); + + let mut map = builder.build(points, values); + + map.insert(Point(4.0, 4.0), "four").expect("Should insert"); + + let mut search = Search::default(); + + for (i, item) in map.search(&Point(4.0, 4.0), &mut search).enumerate() { + match i { + 0 => { + assert_eq!(item.distance, 0.0); + assert_eq!(item.value, &"four"); + } + 1 => { + assert_eq!(item.distance, 1.4142135); + assert!(item.value == &"three"); + } + 2 => { + assert_eq!(item.distance, 2.828427); + assert!(item.value == &"two"); + } + 3 => { + assert_eq!(item.distance, 4.2426405); + assert!(item.value == &"one"); + } + 4 => { + assert_eq!(item.distance, 5.656854); + assert!(item.value == &"zero"); + } + _ => unreachable!(), + } + } + + // Note + // This has the same expected results as incremental_insert but builds + // the whole map in one go. Only here for comparison. + { + let points = (0..5) + .map(|i| Point(i as f32, i as f32)) + .collect::>(); + let values = vec!["zero", "one", "two", "three", "four"]; + let seed = ThreadRng::default().gen::(); + let builder = Builder::default().seed(seed); + let map = builder.build(points, values); + let mut search = Search::default(); + for (i, item) in map.search(&Point(4.0, 4.0), &mut search).enumerate() { + match i { + 0 => { + assert_eq!(item.distance, 0.0); + assert_eq!(item.value, &"four"); + } + 1 => { + assert_eq!(item.distance, 1.4142135); + assert!(item.value == &"three"); + } + 2 => { + assert_eq!(item.distance, 2.828427); + assert!(item.value == &"two"); + } + 3 => { + assert_eq!(item.distance, 4.2426405); + assert!(item.value == &"one"); + } + 4 => { + assert_eq!(item.distance, 5.656854); + assert!(item.value == &"zero"); + } + _ => unreachable!(), + } + } + } +} From b5013d15405f5e81412240e6878db2aadfcaabb6 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 11 Apr 2023 01:11:14 -0400 Subject: [PATCH 2/2] fix stray keypress --- instant-distance/tests/all.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/instant-distance/tests/all.rs b/instant-distance/tests/all.rs index 295e9e1..c895c5a 100644 --- a/instant-distance/tests/all.rs +++ b/instant-distance/tests/all.rs @@ -92,7 +92,7 @@ struct Point(f32, f32); impl instant_distance::Point for Point { fn distance(&self, other: &Self) -> f32 { - // Euclidean distance metricØ + // Euclidean distance metric ((self.0 - other.0).powi(2) + (self.1 - other.1).powi(2)).sqrt() } }