diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 9edcdb8..bf86695 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -53,7 +53,7 @@ jobs: - name: Test Python bindings run: | sudo apt-get install -y wamerican - cp target/release/libinstant_distance.so instant-distance-py/test/instant_distance.so + cp target/release/libinstant_distance_py.so instant-distance-py/test/instant_distance.so PYTHONPATH=instant-distance-py/test/ python3 -m test lint: diff --git a/Makefile b/Makefile index ef56aec..6b0105e 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,9 @@ -test-python: - cargo build --release - cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so +instant-distance-py/test/instant_distance.so: instant-distance-py/src/lib.rs + RUSTFLAGS="-C target-cpu=native" cargo build --release + ([ -f target/release/libinstant_distance_py.dylib ] && cp target/release/libinstant_distance_py.dylib instant-distance-py/test/instant_distance.so) || \ + ([ -f target/release/libinstant_distance_py.so ] && cp target/release/libinstant_distance_py.so instant-distance-py/test/instant_distance.so) + +test-python: instant-distance-py/test/instant_distance.so PYTHONPATH=instant-distance-py/test/ python3 -m test clean: diff --git a/instant-distance-py/Cargo.toml b/instant-distance-py/Cargo.toml index e8e3357..6821783 100644 --- a/instant-distance-py/Cargo.toml +++ b/instant-distance-py/Cargo.toml @@ -11,12 +11,23 @@ repository = "https://github.com/InstantDomain/instant-distance" readme = "../README.md" [lib] +name = "instant_distance_py" +crate-type = ["cdylib", "lib"] + +[package.metadata.maturin] name = "instant_distance" -crate-type = ["cdylib"] [dependencies] +aligned-vec = { version = "0.5.0", features = ["serde"] } bincode = "1.3.1" instant-distance = { version = "0.6", path = "../instant-distance", features = ["with-serde"] } pyo3 = { version = "0.18.0", features = ["extension-module"] } serde = { version = "1", features = ["derive"] } -serde-big-array = "0.5.0" + +[dev-dependencies] +bencher = "0.1.5" +rand = { version = "0.8", features = ["small_rng"] } + +[[bench]] +name = "all" +harness = false diff --git a/instant-distance-py/benches/all.rs b/instant-distance-py/benches/all.rs new file mode 100644 index 0000000..1cf3dd7 --- /dev/null +++ b/instant-distance-py/benches/all.rs @@ -0,0 +1,48 @@ +use aligned_vec::avec; +use bencher::{benchmark_group, benchmark_main, Bencher}; + +use instant_distance::{Builder, Metric, Search}; +use instant_distance_py::{EuclidMetric, PointStorage}; +use rand::{rngs::StdRng, Rng, SeedableRng}; + +benchmark_main!(benches); +benchmark_group!(benches, distance, build, query); + +fn distance(bench: &mut Bencher) { + let mut rng = StdRng::seed_from_u64(SEED); + let point_a = avec![rng.gen(); 304]; + let point_b = avec![rng.gen(); 304]; + + bench.iter(|| EuclidMetric::distance(&point_a, &point_b)); +} + +fn build(bench: &mut Bencher) { + let mut rng = StdRng::seed_from_u64(SEED); + let points = (0..1024) + .map(|_| vec![rng.gen(); 304]) + .collect::>(); + + bench.iter(|| { + Builder::default() + .seed(SEED) + .build_hnsw::, [f32], EuclidMetric, PointStorage>(points.clone()) + }); +} + +fn query(bench: &mut Bencher) { + let mut rng = StdRng::seed_from_u64(SEED); + let points = (0..1024) + .map(|_| vec![rng.gen(); 304]) + .collect::>(); + let (hnsw, _) = Builder::default() + .seed(SEED) + .build_hnsw::, [f32], EuclidMetric, PointStorage>(points); + let point = avec![rng.gen(); 304]; + + bench.iter(|| { + let mut search = Search::default(); + let _ = hnsw.search(&point, &mut search); + }); +} + +const SEED: u64 = 123456789; diff --git a/instant-distance-py/pyproject.toml b/instant-distance-py/pyproject.toml index 409cabc..16ce725 100644 --- a/instant-distance-py/pyproject.toml +++ b/instant-distance-py/pyproject.toml @@ -4,3 +4,6 @@ name = "instant-distance" [build-system] requires = ["maturin >= 0.14, < 0.15"] build-backend = "maturin" + +[tool.maturin] +python-source = "python" diff --git a/instant-distance-py/python/instant_distance/__init__.py b/instant-distance-py/python/instant_distance/__init__.py new file mode 100644 index 0000000..0e28f6c --- /dev/null +++ b/instant-distance-py/python/instant_distance/__init__.py @@ -0,0 +1,5 @@ +from .instant_distance import * + +__doc__ = instant_distance.__doc__ +if hasattr(instant_distance, "__all__"): + __all__ = instant_distance.__all__ diff --git a/instant-distance-py/src/lib.rs b/instant-distance-py/src/lib.rs index bc35090..ccd6e9c 100644 --- a/instant-distance-py/src/lib.rs +++ b/instant-distance-py/src/lib.rs @@ -4,16 +4,17 @@ use std::convert::TryFrom; use std::fs::File; use std::io::{BufReader, BufWriter}; -use std::iter::FromIterator; +use std::iter::{repeat, FromIterator}; +use std::ops::Index; -use instant_distance::Point; +use aligned_vec::{AVec, ConstAlign}; +use instant_distance::{Len, Metric, PointId}; use pyo3::conversion::IntoPy; -use pyo3::exceptions::{PyTypeError, PyValueError}; +use pyo3::exceptions::PyValueError; use pyo3::types::{PyList, PyModule, PyString}; use pyo3::{pyclass, pymethods, pymodule}; use pyo3::{Py, PyAny, PyErr, PyObject, PyRef, PyRefMut, PyResult, Python}; use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; #[pymodule] #[pyo3(name = "instant_distance")] @@ -29,7 +30,7 @@ fn instant_distance_py(_: Python, m: &PyModule) -> PyResult<()> { #[pyclass] struct HnswMap { - inner: instant_distance::HnswMap, + inner: instant_distance::HnswMap<[f32], EuclidMetric, MapValue, PointStorage>, } #[pymethods] @@ -39,7 +40,12 @@ impl HnswMap { fn build(points: &PyList, values: &PyList, config: &Config) -> PyResult { let points = points .into_iter() - .map(FloatArray::try_from) + .map(|v| { + v.iter()? + .into_iter() + .map(|x| x?.extract()) + .collect::, PyErr>>() + }) .collect::, PyErr>>()?; let values = values @@ -47,18 +53,22 @@ impl HnswMap { .map(MapValue::try_from) .collect::, PyErr>>()?; - let hsnw_map = instant_distance::Builder::from(config).build(points, values); + let hsnw_map = instant_distance::Builder::from(config) + .build::, [f32], EuclidMetric, MapValue, PointStorage>(points, values); Ok(Self { inner: hsnw_map }) } /// Load an index from the given file name #[staticmethod] fn load(fname: &str) -> PyResult { - let hnsw_map = - bincode::deserialize_from::<_, instant_distance::HnswMap>( - BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?), - ) - .map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?; + let hnsw_map = bincode::deserialize_from::< + _, + instant_distance::HnswMap<[f32], EuclidMetric, MapValue, PointStorage>, + >(BufReader::with_capacity( + 32 * 1024 * 1024, + File::open(fname)?, + )) + .map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?; Ok(Self { inner: hnsw_map }) } @@ -78,7 +88,7 @@ impl HnswMap { /// /// For best performance, reusing `Search` objects is recommended. fn search(slf: Py, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> { - let point = FloatArray::try_from(point)?; + let point = try_avec_from_py(point)?; let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner); search.cur = Some((HnswType::Map(slf.clone_ref(py)), 0)); Ok(()) @@ -91,7 +101,7 @@ impl HnswMap { /// with a squared Euclidean distance metric. #[pyclass] struct Hnsw { - inner: instant_distance::Hnsw, + inner: instant_distance::Hnsw<[f32], EuclidMetric, PointStorage>, } #[pymethods] @@ -101,10 +111,16 @@ impl Hnsw { fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec)> { let points = input .into_iter() - .map(FloatArray::try_from) + .map(|v| { + v.iter()? + .into_iter() + .map(|x| x?.extract()) + .collect::, PyErr>>() + }) .collect::, PyErr>>()?; - let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points); + let (inner, ids) = instant_distance::Builder::from(config) + .build_hnsw::, [f32], EuclidMetric, PointStorage>(points); let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner())); Ok((Self { inner }, ids)) } @@ -112,9 +128,13 @@ impl Hnsw { /// Load an index from the given file name #[staticmethod] fn load(fname: &str) -> PyResult { - let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw>( - BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?), - ) + let hnsw = bincode::deserialize_from::< + _, + instant_distance::Hnsw<[f32], EuclidMetric, PointStorage>, + >(BufReader::with_capacity( + 32 * 1024 * 1024, + File::open(fname)?, + )) .map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?; Ok(Self { inner: hnsw }) } @@ -135,7 +155,7 @@ impl Hnsw { /// /// For best performance, reusing `Search` objects is recommended. fn search(slf: Py, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> { - let point = FloatArray::try_from(point)?; + let point = try_avec_from_py(point)?; let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner); search.cur = Some((HnswType::Hnsw(slf.clone_ref(py)), 0)); Ok(()) @@ -145,7 +165,7 @@ impl Hnsw { /// Search buffer and result set #[pyclass] struct Search { - inner: instant_distance::Search, + inner: instant_distance::Search<[f32], EuclidMetric>, cur: Option<(HnswType, usize)>, } @@ -346,39 +366,36 @@ impl Neighbor { } } -#[repr(align(32))] -#[derive(Clone, Deserialize, Serialize)] -struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]); - -impl TryFrom<&PyAny> for FloatArray { - type Error = PyErr; - - fn try_from(value: &PyAny) -> Result { - let mut new = FloatArray([0.0; DIMENSIONS]); - for (i, val) in value.iter()?.enumerate() { - match i >= DIMENSIONS { - true => return Err(PyTypeError::new_err("point array too long")), - false => new.0[i] = val?.extract::()?, - } - } - Ok(new) +fn try_avec_from_py(value: &PyAny) -> Result>, PyErr> { + let mut new = AVec::new(ALIGNMENT); + for val in value.iter()? { + new.push(val?.extract::()?); } + for _ in 0..PointStorage::padding(new.len()) { + new.push(0.0); + } + Ok(new) } -impl Point for FloatArray { - fn distance(&self, rhs: &Self) -> f32 { +#[derive(Clone, Copy, Deserialize, Serialize)] +pub struct EuclidMetric; + +impl Metric<[f32]> for EuclidMetric { + fn distance(lhs: &[f32], rhs: &[f32]) -> f32 { + debug_assert_eq!(lhs.len(), rhs.len()); + #[cfg(target_arch = "x86_64")] { use std::arch::x86_64::{ _mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps, _mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, - _mm_fmadd_ps, _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps, + _mm_movehl_ps, _mm_shuffle_ps, }; - debug_assert_eq!(self.0.len() % 8, 4); + debug_assert_eq!(lhs.len() % 8, 0); unsafe { let mut acc_8x = _mm256_setzero_ps(); - for (lh_slice, rh_slice) in self.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) { + for (lh_slice, rh_slice) in lhs.chunks_exact(8).zip(rhs.chunks_exact(8)) { let lh_8x = _mm256_load_ps(lh_slice.as_ptr()); let rh_8x = _mm256_load_ps(rh_slice.as_ptr()); let diff = _mm256_sub_ps(lh_8x, rh_8x); @@ -389,11 +406,6 @@ impl Point for FloatArray { let right = _mm256_castps256_ps128(acc_8x); // lower half acc_4x = _mm_add_ps(acc_4x, right); // sum halves - let lh_4x = _mm_load_ps(self.0[DIMENSIONS - 4..].as_ptr()); - let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr()); - let diff = _mm_sub_ps(lh_4x, rh_4x); - acc_4x = _mm_fmadd_ps(diff, diff, acc_4x); - let lower = _mm_movehl_ps(acc_4x, acc_4x); acc_4x = _mm_add_ps(acc_4x, lower); let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1); @@ -402,7 +414,7 @@ impl Point for FloatArray { } } #[cfg(not(target_arch = "x86_64"))] - self.0 + lhs.0 .iter() .zip(rhs.0.iter()) .map(|(&a, &b)| (a - b).powi(2)) @@ -410,6 +422,114 @@ impl Point for FloatArray { } } +#[derive(Debug, Deserialize, Serialize)] +pub struct PointStorage { + point_len: usize, + points_data: AVec, +} + +impl PointStorage { + const fn padding(len: usize) -> usize { + let floats_per_alignment = ALIGNMENT / std::mem::size_of::(); + match len % floats_per_alignment { + 0 => 0, + floats_over_alignment => floats_per_alignment - floats_over_alignment, + } + } + + pub fn iter(&self) -> impl Iterator { + self.points_data.chunks_exact(self.point_len) + } +} + +impl Default for PointStorage { + fn default() -> Self { + Self { + point_len: 1, + points_data: AVec::new(ALIGNMENT), + } + } +} + +impl Index for PointStorage { + type Output = [f32]; + + fn index(&self, index: usize) -> &Self::Output { + let raw_idx = index * self.point_len; + &self.points_data[raw_idx..(raw_idx + self.point_len)] + } +} + +impl Index for PointStorage { + type Output = [f32]; + + fn index(&self, index: PointId) -> &Self::Output { + self.index(index.into_inner() as usize) + } +} + +impl From>> for PointStorage { + fn from(value: Vec>) -> Self { + if let Some(point) = value.first() { + let point_len = point.len(); + let padding = PointStorage::padding(point_len); + let mut points_data = + AVec::with_capacity(ALIGNMENT, value.len() * (point_len + padding)); + for point in value { + // all points should have the same length + debug_assert_eq!(point.len(), point_len); + for v in point.into_iter().chain(repeat(0.0).take(padding)) { + points_data.push(v); + } + } + Self { + point_len: point_len + padding, + points_data, + } + } else { + Default::default() + } + } +} + +impl Len for PointStorage { + fn len(&self) -> usize { + self.points_data.len() / self.point_len + } +} + +impl<'a> IntoIterator for &'a PointStorage { + type Item = &'a [f32]; + + type IntoIter = PointStorageIterator<'a>; + + fn into_iter(self) -> Self::IntoIter { + PointStorageIterator { + storage: self, + next_idx: 0, + } + } +} + +pub struct PointStorageIterator<'a> { + storage: &'a PointStorage, + next_idx: usize, +} + +impl<'a> Iterator for PointStorageIterator<'a> { + type Item = &'a [f32]; + + fn next(&mut self) -> Option { + if self.next_idx < self.storage.len() { + let result = &self.storage[self.next_idx]; + self.next_idx += 1; + Some(result) + } else { + None + } + } +} + #[derive(Clone, Deserialize, Serialize)] enum MapValue { String(String), @@ -431,4 +551,4 @@ impl IntoPy> for &'_ MapValue { } } -const DIMENSIONS: usize = 300; +const ALIGNMENT: usize = 32; diff --git a/instant-distance/benches/all.rs b/instant-distance/benches/all.rs index 43695c9..54c5970 100644 --- a/instant-distance/benches/all.rs +++ b/instant-distance/benches/all.rs @@ -2,7 +2,7 @@ use bencher::{benchmark_group, benchmark_main, Bencher}; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; -use instant_distance::Builder; +use instant_distance::{Builder, Metric}; benchmark_main!(benches); benchmark_group!(benches, build_heuristic); @@ -10,10 +10,14 @@ benchmark_group!(benches, build_heuristic); fn build_heuristic(bench: &mut Bencher) { let mut rng = StdRng::seed_from_u64(SEED); let points = (0..1024) - .map(|_| Point(rng.gen(), rng.gen())) + .map(|_| [rng.gen(), rng.gen()]) .collect::>(); - bench.iter(|| Builder::default().seed(SEED).build_hnsw(points.clone())) + bench.iter(|| { + Builder::default() + .seed(SEED) + .build_hnsw::<[f32; 2], [f32; 2], EuclidMetric, Vec<[f32; 2]>>(points.clone()) + }) } const SEED: u64 = 123456789; @@ -46,12 +50,15 @@ fn randomized(builder: Builder) -> (u64, usize) { } */ -#[derive(Clone, Copy, Debug)] -struct Point(f32, f32); +struct EuclidMetric; -impl instant_distance::Point for Point { - fn distance(&self, other: &Self) -> f32 { +impl Metric<[f32; 2]> for EuclidMetric { + fn distance(a: &[f32; 2], b: &[f32; 2]) -> f32 { // Euclidean distance metric - ((self.0 - other.0).powi(2) + (self.1 - other.1).powi(2)).sqrt() + a.iter() + .zip(b.iter()) + .map(|(&a, &b)| (a - b).powi(2)) + .sum::() + .sqrt() } } diff --git a/instant-distance/examples/colors.rs b/instant-distance/examples/colors.rs index 776347a..5e7ad96 100644 --- a/instant-distance/examples/colors.rs +++ b/instant-distance/examples/colors.rs @@ -1,10 +1,11 @@ -use instant_distance::{Builder, Search}; +use instant_distance::{Builder, Metric, Search}; fn main() { let points = vec![Point(255, 0, 0), Point(0, 255, 0), Point(0, 0, 255)]; let values = vec!["red", "green", "blue"]; - let map = Builder::default().build(points, values); + let map = + Builder::default().build::>(points, values); let mut search = Search::default(); let burnt_orange = Point(204, 85, 0); @@ -17,10 +18,11 @@ fn main() { #[derive(Clone, Copy, Debug)] struct Point(isize, isize, isize); -impl instant_distance::Point for Point { - fn distance(&self, other: &Self) -> f32 { +struct EuclidMetric; + +impl Metric for EuclidMetric { + fn distance(a: &Point, b: &Point) -> f32 { // Euclidean distance metric - (((self.0 - other.0).pow(2) + (self.1 - other.1).pow(2) + (self.2 - other.2).pow(2)) as f32) - .sqrt() + (((a.0 - b.0).pow(2) + (a.1 - b.1).pow(2) + (a.2 - b.2).pow(2)) as f32).sqrt() } } diff --git a/instant-distance/src/lib.rs b/instant-distance/src/lib.rs index a9f06f6..eb455cf 100644 --- a/instant-distance/src/lib.rs +++ b/instant-distance/src/lib.rs @@ -1,6 +1,8 @@ use std::cmp::{max, Ordering, Reverse}; use std::collections::BinaryHeap; use std::collections::HashSet; +use std::marker::PhantomData; +use std::ops::Index; #[cfg(feature = "indicatif")] use std::sync::atomic::{self, AtomicUsize}; @@ -75,12 +77,25 @@ impl Builder { } /// Build an `HnswMap` with the given sets of points and values - pub fn build(self, points: Vec

, values: Vec) -> HnswMap { + pub fn build(self, points: Vec, values: Vec) -> HnswMap + where + P: ?Sized + Send + Sync, + M: Metric

, + V: Clone, + S: Default + From> + Len + Index + Sync, + for<'a> &'a S: IntoIterator, + { HnswMap::new(points, values, self) } /// Build the `Hnsw` with the given set of points - pub fn build_hnsw(self, points: Vec

) -> (Hnsw

, Vec) { + pub fn build_hnsw, S>(self, points: Vec) -> (Hnsw, Vec) + where + P: ?Sized + Send + Sync, + M: Metric

, + S: Default + From> + Len + Index + Sync, + for<'a> &'a S: IntoIterator, + { Hnsw::new(points, self) } @@ -128,17 +143,27 @@ impl Default for Heuristic { } #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] -pub struct HnswMap { - hnsw: Hnsw

, +pub struct HnswMap { + #[cfg_attr( + feature = "serde", + serde(bound(deserialize = "Hnsw: Deserialize<'de>")) + )] + hnsw: Hnsw, pub values: Vec, } -impl HnswMap +impl HnswMap where - P: Point, + P: ?Sized + Send + Sync, + M: Metric

, V: Clone, + S: Default + Len + Index + Sync, + for<'a> &'a S: IntoIterator, { - fn new(points: Vec

, values: Vec, builder: Builder) -> Self { + fn new(points: Vec, values: Vec, builder: Builder) -> Self + where + S: From>, + { let (hnsw, ids) = Hnsw::new(points, builder); let mut sorted = ids.into_iter().enumerate().collect::>(); @@ -154,7 +179,7 @@ where pub fn search<'a>( &'a self, point: &P, - search: &'a mut Search, + search: &'a mut Search, ) -> impl Iterator> + ExactSizeIterator + 'a { self.hnsw .search(point, search) @@ -167,20 +192,20 @@ where } #[doc(hidden)] - pub fn get(&self, i: usize, search: &Search) -> Option> { + pub fn get(&self, i: usize, search: &Search) -> Option> { Some(MapItem::from(self.hnsw.get(i, search)?, self)) } } -pub struct MapItem<'a, P, V> { +pub struct MapItem<'a, P: ?Sized, V> { pub distance: f32, pub pid: PointId, pub point: &'a P, pub value: &'a V, } -impl<'a, P, V> MapItem<'a, P, V> { - fn from(item: Item<'a, P>, map: &'a HnswMap) -> Self { +impl<'a, P: ?Sized, V> MapItem<'a, P, V> { + fn from(item: Item<'a, P>, map: &'a HnswMap) -> Self { MapItem { distance: item.distance, pid: item.pid, @@ -191,22 +216,29 @@ impl<'a, P, V> MapItem<'a, P, V> { } #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] -pub struct Hnsw

{ +pub struct Hnsw { ef_search: usize, - points: Vec

, + points: S, zero: Vec, layers: Vec>, + phantom: PhantomData<(Box

, M)>, } -impl

Hnsw

+impl Hnsw where - P: Point, + P: ?Sized + Send + Sync, + M: Metric

, + S: Default + Len + Index + Sync, + for<'a> &'a S: IntoIterator, { pub fn builder() -> Builder { Builder::default() } - fn new(points: Vec

, builder: Builder) -> (Self, Vec) { + fn new(points: Vec, builder: Builder) -> (Self, Vec) + where + S: From>, + { let ef_search = builder.ef_search; let ef_construction = builder.ef_construction; let ml = builder.ml; @@ -226,8 +258,9 @@ where Self { ef_search, zero: Vec::new(), - points: Vec::new(), + points: Default::default(), layers: Vec::new(), + phantom: PhantomData, }, Vec::new(), ); @@ -260,14 +293,21 @@ where shuffled.sort_unstable(); let mut out = vec![INVALID; points.len()]; + let mut points = points + .into_iter() + .map(|p| Some(p)) + .collect::>>(); let points = shuffled .into_iter() .enumerate() .map(|(i, (_, idx))| { out[idx] = PointId(i as u32); - points[idx].clone() + points[idx] + .take() + .expect("Point should be present as it's wrapped in Option above") }) .collect::>(); + let points: S = points.into(); // Figure out how many nodes will go on each layer. This helps us allocate memory capacity // for each layer in advance, and also helps enable batch insertion of points. @@ -284,17 +324,18 @@ where let mut layers = vec![vec![]; top.0]; let zero = points - .iter() + .into_iter() .map(|_| RwLock::new(ZeroNode::default())) .collect::>(); let state = Construction { zero: zero.as_slice(), - pool: SearchPool::new(points.len()), + pool: SearchPool::new(zero.len()), top, points: &points, heuristic, ef_construction, + phantom: PhantomData::, #[cfg(feature = "indicatif")] progress, #[cfg(feature = "indicatif")] @@ -339,6 +380,7 @@ where zero: zero.into_iter().map(|node| node.into_inner()).collect(), points, layers, + phantom: PhantomData, }, out, ) @@ -352,7 +394,7 @@ where pub fn search<'a, 'b: 'a>( &'b self, point: &P, - search: &'a mut Search, + search: &'a mut Search, ) -> impl Iterator> + ExactSizeIterator + 'a { search.reset(); let map = move |candidate| Item::new(candidate, self); @@ -385,25 +427,28 @@ where /// Iterate over the keys and values in this index pub fn iter(&self) -> impl Iterator { self.points - .iter() + .into_iter() .enumerate() .map(|(i, p)| (PointId(i as u32), p)) } #[doc(hidden)] - pub fn get(&self, i: usize, search: &Search) -> Option> { + pub fn get(&self, i: usize, search: &Search) -> Option> { Some(Item::new(search.nearest.get(i).copied()?, self)) } } -pub struct Item<'a, P> { +pub struct Item<'a, P: ?Sized> { pub distance: f32, pub pid: PointId, pub point: &'a P, } -impl<'a, P> Item<'a, P> { - fn new(candidate: Candidate, hnsw: &'a Hnsw

) -> Self { +impl<'a, P: ?Sized> Item<'a, P> { + fn new(candidate: Candidate, hnsw: &'a Hnsw) -> Self + where + S: Index, + { Self { distance: candidate.distance.into_inner(), pid: candidate.pid, @@ -412,20 +457,25 @@ impl<'a, P> Item<'a, P> { } } -struct Construction<'a, P: Point> { +struct Construction<'a, P: ?Sized, M: Metric

, S> { zero: &'a [RwLock], - pool: SearchPool, + pool: SearchPool, top: LayerId, - points: &'a [P], + points: &'a S, heuristic: Option, ef_construction: usize, + phantom: PhantomData, #[cfg(feature = "indicatif")] progress: Option, #[cfg(feature = "indicatif")] done: AtomicUsize, } -impl<'a, P: Point> Construction<'a, P> { +impl<'a, P: ?Sized, M, S> Construction<'a, P, M, S> +where + S: Index, + M: Metric

, +{ /// Insert new node in the zero layer /// /// * `new` is the `PointId` for the new node @@ -507,7 +557,7 @@ impl<'a, P: Point> Construction<'a, P> { _ => return Ordering::Greater, }; - distance.cmp(&old.distance(&self.points[third]).into()) + distance.cmp(&M::distance(old, &self.points[third]).into()) }) .unwrap_or_else(|e| e); @@ -528,12 +578,14 @@ impl<'a, P: Point> Construction<'a, P> { } } -struct SearchPool { - pool: Mutex>, +type SearchPoolItem = (Search, Search); + +struct SearchPool> { + pool: Mutex>>, len: usize, } -impl SearchPool { +impl> SearchPool { fn new(len: usize) -> Self { Self { pool: Mutex::new(Vec::new()), @@ -541,14 +593,14 @@ impl SearchPool { } } - fn pop(&self) -> (Search, Search) { + fn pop(&self) -> SearchPoolItem { match self.pool.lock().pop() { Some(res) => res, None => (Search::new(self.len), Search::new(self.len)), } } - fn push(&self, item: (Search, Search)) { + fn push(&self, item: SearchPoolItem) { self.pool.lock().push(item); } } @@ -557,7 +609,7 @@ impl SearchPool { /// /// In particular, this contains most of the state used in algorithm 2. The structure is /// initialized by using `push()` to add the initial enter points. -pub struct Search { +pub struct Search> { /// Nodes visited so far (`v` in the paper) visited: Visited, /// Candidates for further inspection (`C` in the paper) @@ -571,9 +623,11 @@ pub struct Search { discarded: Vec, /// Maximum number of nearest neighbors to retain (`ef` in the paper) ef: usize, + /// PhantomData to bind the Metric parameter + phantom: PhantomData<(Box

, M)>, } -impl Search { +impl> Search { fn new(capacity: usize) -> Self { Self { visited: Visited::with_capacity(capacity), @@ -595,7 +649,13 @@ impl Search { /// /// Invariants: `self.nearest` should be in sorted (nearest first) order, and should be /// truncated to `self.ef`. - fn search(&mut self, point: &P, layer: L, points: &[P], links: usize) { + fn search>( + &mut self, + point: &P, + layer: L, + points: &S, + links: usize, + ) { while let Some(Reverse(candidate)) = self.candidates.pop() { if let Some(furthest) = self.nearest.last() { if candidate.distance > furthest.distance { @@ -613,13 +673,13 @@ impl Search { } } - fn add_neighbor_heuristic( + fn add_neighbor_heuristic>( &mut self, new: PointId, current: impl Iterator, layer: L, point: &P, - points: &[P], + points: &S, params: Heuristic, ) -> &[Candidate] { self.reset(); @@ -633,11 +693,11 @@ impl Search { /// Heuristically sort and truncate neighbors in `self.nearest` /// /// Invariant: `self.nearest` must be in sorted (nearest first) order. - fn select_heuristic( + fn select_heuristic>( &mut self, point: &P, layer: L, - points: &[P], + points: &S, params: Heuristic, ) -> &[Candidate] { self.working.clear(); @@ -652,7 +712,7 @@ impl Search { } let other = &points[hop]; - let distance = OrderedFloat::from(point.distance(other)); + let distance = OrderedFloat::from(M::distance(point, other)); let new = Candidate { distance, pid: hop }; self.working.push(new); } @@ -674,7 +734,8 @@ impl Search { // are to the query point, to facilitate bridging between clustered points. let candidate_point = &points[candidate.pid]; let nearest = !self.nearest.iter().any(|result| { - let distance = OrderedFloat::from(candidate_point.distance(&points[result.pid])); + let distance = + OrderedFloat::from(M::distance(candidate_point, &points[result.pid])); distance < candidate.distance }); @@ -701,13 +762,13 @@ impl Search { /// /// Will immediately return if the node has been considered before. This implements /// the inner loop from the paper's algorithm 2. - fn push(&mut self, pid: PointId, point: &P, points: &[P]) { + fn push>(&mut self, pid: PointId, point: &P, points: &S) { if !self.visited.insert(pid) { return; } let other = &points[pid]; - let distance = OrderedFloat::from(point.distance(other)); + let distance = OrderedFloat::from(M::distance(point, other)); let new = Candidate { distance, pid }; let idx = match self.nearest.binary_search(&new) { Err(idx) if idx < self.ef => idx, @@ -745,6 +806,7 @@ impl Search { working, discarded, ef: _, + phantom: _, } = self; visited.clear(); @@ -764,7 +826,7 @@ impl Search { } } -impl Default for Search { +impl> Default for Search { fn default() -> Self { Self { visited: Visited::with_capacity(0), @@ -773,12 +835,35 @@ impl Default for Search { working: Vec::new(), discarded: Vec::new(), ef: 1, + phantom: PhantomData, } } } -pub trait Point: Clone + Sync { - fn distance(&self, other: &Self) -> f32; +pub trait Metric: Send + Sync { + fn distance(a: &P, b: &P) -> f32; +} + +pub trait Len { + fn len(&self) -> usize; + + fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl Len for Vec { + fn len(&self) -> usize { + Vec::len(self) + } +} + +impl

Index for Vec

{ + type Output = P; + + fn index(&self, index: PointId) -> &Self::Output { + Vec::index(self, index.0 as usize) + } } /// The parameter `M` from the paper diff --git a/instant-distance/src/types.rs b/instant-distance/src/types.rs index d6a03fe..67ab70b 100644 --- a/instant-distance/src/types.rs +++ b/instant-distance/src/types.rs @@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize}; #[cfg(feature = "serde-big-array")] use serde_big_array::BigArray; -use crate::{Hnsw, Point, M}; +use crate::{Hnsw, M}; pub(crate) struct Visited { store: Vec, @@ -266,19 +266,15 @@ impl Default for PointId { } } -impl

Index for Hnsw

{ - type Output = P; - - fn index(&self, index: PointId) -> &Self::Output { - &self.points[index.0 as usize] - } -} - -impl Index for [P] { +impl Index for Hnsw +where + P: ?Sized, + S: Index, +{ type Output = P; fn index(&self, index: PointId) -> &Self::Output { - &self[index.0 as usize] + &self.points[index] } } diff --git a/instant-distance/tests/all.rs b/instant-distance/tests/all.rs index b9fa973..51833dd 100644 --- a/instant-distance/tests/all.rs +++ b/instant-distance/tests/all.rs @@ -4,7 +4,7 @@ use ordered_float::OrderedFloat; use rand::rngs::{StdRng, ThreadRng}; use rand::{Rng, SeedableRng}; -use instant_distance::{Builder, Point as _, Search}; +use instant_distance::{Builder, Metric, Search}; #[test] #[allow(clippy::float_cmp, clippy::approx_constant)] @@ -16,7 +16,9 @@ fn map() { let seed = ThreadRng::default().gen::(); println!("map (seed = {seed})"); - let map = Builder::default().seed(seed).build(points, values); + let map = Builder::default() + .seed(seed) + .build::>(points, values); let mut search = Search::default(); for (i, item) in map.search(&Point(2.0, 2.0), &mut search).enumerate() { @@ -62,14 +64,16 @@ fn randomized(builder: Builder) -> (u64, usize) { let query = Point(rng.gen(), rng.gen()); let mut nearest = Vec::with_capacity(256); for (i, p) in points.iter().enumerate() { - nearest.push((OrderedFloat::from(query.distance(p)), i)); + nearest.push((OrderedFloat::from(EuclidMetric::distance(&query, p)), i)); if nearest.len() >= 200 { nearest.sort_unstable(); nearest.truncate(100); } } - let (hnsw, pids) = builder.seed(seed).build_hnsw(points); + let (hnsw, pids) = builder + .seed(seed) + .build_hnsw::<_, _, EuclidMetric, Vec>(points); let mut search = Search::default(); let results = hnsw.search(&query, &mut search); assert!(results.len() >= 100); @@ -90,9 +94,11 @@ fn randomized(builder: Builder) -> (u64, usize) { #[derive(Clone, Copy, Debug)] struct Point(f32, f32); -impl instant_distance::Point for Point { - fn distance(&self, other: &Self) -> f32 { +struct EuclidMetric; + +impl Metric for EuclidMetric { + fn distance(a: &Point, b: &Point) -> f32 { // Euclidean distance metric - ((self.0 - other.0).powi(2) + (self.1 - other.1).powi(2)).sqrt() + ((a.0 - b.0).powi(2) + (a.1 - b.1).powi(2)).sqrt() } }