Skip to content

Commit a048691

Browse files
committed
Introduce Metric trait
1 parent 52077ec commit a048691

File tree

7 files changed

+104
-84
lines changed

7 files changed

+104
-84
lines changed

instant-distance-py/benches/all.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use bencher::{benchmark_group, benchmark_main, Bencher};
22

3-
use instant_distance::{Builder, Point, Search};
4-
use instant_distance_py::FloatArray;
3+
use instant_distance::{Builder, Metric, Search};
4+
use instant_distance_py::{EuclidMetric, FloatArray};
55
use rand::{rngs::StdRng, Rng, SeedableRng};
66

77
benchmark_main!(benches);
@@ -12,7 +12,7 @@ fn distance(bench: &mut Bencher) {
1212
let point_a = FloatArray([rng.gen(); 300]);
1313
let point_b = FloatArray([rng.gen(); 300]);
1414

15-
bench.iter(|| point_a.distance(&point_b));
15+
bench.iter(|| EuclidMetric::distance(&point_a, &point_b));
1616
}
1717

1818
fn build(bench: &mut Bencher) {
@@ -25,7 +25,7 @@ fn build(bench: &mut Bencher) {
2525
bench.iter(|| {
2626
Builder::default()
2727
.seed(SEED)
28-
.build_hnsw::<_, _, Vec<FloatArray>>(points.clone())
28+
.build_hnsw::<_, _, EuclidMetric, Vec<FloatArray>>(points.clone())
2929
});
3030
}
3131

@@ -37,7 +37,7 @@ fn query(bench: &mut Bencher) {
3737
.collect::<Vec<_>>();
3838
let (hnsw, _) = Builder::default()
3939
.seed(SEED)
40-
.build_hnsw::<_, _, Vec<FloatArray>>(points);
40+
.build_hnsw::<_, _, EuclidMetric, Vec<FloatArray>>(points);
4141
let point = FloatArray([rng.gen(); 300]);
4242

4343
bench.iter(|| {

instant-distance-py/src/lib.rs

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::fs::File;
66
use std::io::{BufReader, BufWriter};
77
use std::iter::FromIterator;
88

9-
use instant_distance::Point;
9+
use instant_distance::Metric;
1010
use pyo3::conversion::IntoPy;
1111
use pyo3::exceptions::{PyTypeError, PyValueError};
1212
use pyo3::types::{PyList, PyModule, PyString};
@@ -29,7 +29,7 @@ fn instant_distance_py(_: Python, m: &PyModule) -> PyResult<()> {
2929

3030
#[pyclass]
3131
struct HnswMap {
32-
inner: instant_distance::HnswMap<FloatArray, MapValue, Vec<FloatArray>>,
32+
inner: instant_distance::HnswMap<FloatArray, EuclidMetric, MapValue, Vec<FloatArray>>,
3333
}
3434

3535
#[pymethods]
@@ -54,7 +54,7 @@ impl HnswMap {
5454
/// Load an index from the given file name
5555
#[staticmethod]
5656
fn load(fname: &str) -> PyResult<Self> {
57-
let hnsw_map = bincode::deserialize_from::<_, instant_distance::HnswMap<_, _, _>>(
57+
let hnsw_map = bincode::deserialize_from::<_, instant_distance::HnswMap<_, _, _, _>>(
5858
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
5959
)
6060
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
@@ -90,7 +90,7 @@ impl HnswMap {
9090
/// with a squared Euclidean distance metric.
9191
#[pyclass]
9292
struct Hnsw {
93-
inner: instant_distance::Hnsw<FloatArray, Vec<FloatArray>>,
93+
inner: instant_distance::Hnsw<FloatArray, EuclidMetric, Vec<FloatArray>>,
9494
}
9595

9696
#[pymethods]
@@ -111,7 +111,7 @@ impl Hnsw {
111111
/// Load an index from the given file name
112112
#[staticmethod]
113113
fn load(fname: &str) -> PyResult<Self> {
114-
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<_, _>>(
114+
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<_, _, _>>(
115115
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
116116
)
117117
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
@@ -144,7 +144,7 @@ impl Hnsw {
144144
/// Search buffer and result set
145145
#[pyclass]
146146
struct Search {
147-
inner: instant_distance::Search<FloatArray>,
147+
inner: instant_distance::Search<FloatArray, EuclidMetric>,
148148
cur: Option<(HnswType, usize)>,
149149
}
150150

@@ -364,20 +364,23 @@ impl TryFrom<&PyAny> for FloatArray {
364364
}
365365
}
366366

367-
impl Point for FloatArray {
368-
fn distance(&self, rhs: &Self) -> f32 {
367+
#[derive(Clone, Copy, Deserialize, Serialize)]
368+
pub struct EuclidMetric;
369+
370+
impl Metric<FloatArray> for EuclidMetric {
371+
fn distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
369372
#[cfg(target_arch = "x86_64")]
370373
{
371374
use std::arch::x86_64::{
372375
_mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
373376
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
374377
_mm_fmadd_ps, _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
375378
};
376-
debug_assert_eq!(self.0.len() % 8, 4);
379+
debug_assert_eq!(lhs.0.len() % 8, 4);
377380

378381
unsafe {
379382
let mut acc_8x = _mm256_setzero_ps();
380-
for (lh_slice, rh_slice) in self.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
383+
for (lh_slice, rh_slice) in lhs.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
381384
let lh_8x = _mm256_load_ps(lh_slice.as_ptr());
382385
let rh_8x = _mm256_load_ps(rh_slice.as_ptr());
383386
let diff = _mm256_sub_ps(lh_8x, rh_8x);
@@ -388,7 +391,7 @@ impl Point for FloatArray {
388391
let right = _mm256_castps256_ps128(acc_8x); // lower half
389392
acc_4x = _mm_add_ps(acc_4x, right); // sum halves
390393

391-
let lh_4x = _mm_load_ps(self.0[DIMENSIONS - 4..].as_ptr());
394+
let lh_4x = _mm_load_ps(lhs.0[DIMENSIONS - 4..].as_ptr());
392395
let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr());
393396
let diff = _mm_sub_ps(lh_4x, rh_4x);
394397
acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);
@@ -401,7 +404,7 @@ impl Point for FloatArray {
401404
}
402405
}
403406
#[cfg(not(target_arch = "x86_64"))]
404-
self.0
407+
lhs.0
405408
.iter()
406409
.zip(rhs.0.iter())
407410
.map(|(&a, &b)| (a - b).powi(2))

instant-distance/benches/all.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use bencher::{benchmark_group, benchmark_main, Bencher};
22
use rand::rngs::StdRng;
33
use rand::{Rng, SeedableRng};
44

5-
use instant_distance::Builder;
5+
use instant_distance::{Builder, Metric};
66

77
benchmark_main!(benches);
88
benchmark_group!(benches, build_heuristic);
@@ -11,13 +11,13 @@ fn build_heuristic(bench: &mut Bencher) {
1111
let mut rng = StdRng::seed_from_u64(SEED);
1212
let points = (0..1024)
1313
.into_iter()
14-
.map(|_| Point(rng.gen(), rng.gen()))
14+
.map(|_| [rng.gen(), rng.gen()])
1515
.collect::<Vec<_>>();
1616

1717
bench.iter(|| {
1818
Builder::default()
1919
.seed(SEED)
20-
.build_hnsw::<Point, Point, Vec<Point>>(points.clone())
20+
.build_hnsw::<[f32; 2], [f32; 2], EuclidMetric, Vec<[f32; 2]>>(points.clone())
2121
})
2222
}
2323

@@ -51,12 +51,15 @@ fn randomized(builder: Builder) -> (u64, usize) {
5151
}
5252
*/
5353

54-
#[derive(Clone, Copy, Debug)]
55-
struct Point(f32, f32);
54+
struct EuclidMetric;
5655

57-
impl instant_distance::Point for Point {
58-
fn distance(&self, other: &Self) -> f32 {
56+
impl Metric<[f32; 2]> for EuclidMetric {
57+
fn distance(a: &[f32; 2], b: &[f32; 2]) -> f32 {
5958
// Euclidean distance metric
60-
((self.0 - other.0).powi(2) + (self.1 - other.1).powi(2)).sqrt()
59+
a.iter()
60+
.zip(b.iter())
61+
.map(|(&a, &b)| (a - b).powi(2))
62+
.sum::<f32>()
63+
.sqrt()
6164
}
6265
}

instant-distance/examples/colors.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
use instant_distance::{Builder, Search};
1+
use instant_distance::{Builder, Metric, Search};
22

33
fn main() {
44
let points = vec![Point(255, 0, 0), Point(0, 255, 0), Point(0, 0, 255)];
55
let values = vec!["red", "green", "blue"];
66

7-
let map = Builder::default().build::<Point, Point, &str, Vec<Point>>(points, values);
7+
let map =
8+
Builder::default().build::<Point, Point, EuclidMetric, &str, Vec<Point>>(points, values);
89
let mut search = Search::default();
910

1011
let burnt_orange = Point(204, 85, 0);
@@ -17,10 +18,11 @@ fn main() {
1718
#[derive(Clone, Copy, Debug)]
1819
struct Point(isize, isize, isize);
1920

20-
impl instant_distance::Point for Point {
21-
fn distance(&self, other: &Self) -> f32 {
21+
struct EuclidMetric;
22+
23+
impl Metric<Point> for EuclidMetric {
24+
fn distance(a: &Point, b: &Point) -> f32 {
2225
// Euclidean distance metric
23-
(((self.0 - other.0).pow(2) + (self.1 - other.1).pow(2) + (self.2 - other.2).pow(2)) as f32)
24-
.sqrt()
26+
(((a.0 - b.0).pow(2) + (a.1 - b.1).pow(2) + (a.2 - b.2).pow(2)) as f32).sqrt()
2527
}
2628
}

0 commit comments

Comments
 (0)