Skip to content

Commit 3b4fde8

Browse files
committed
Generalize SIMD distance implementation to n-length vectors
1 parent f1cb9ee commit 3b4fde8

File tree

2 files changed

+35
-21
lines changed

2 files changed

+35
-21
lines changed

instant-distance-py/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ name = "instant_distance"
1515
crate-type = ["cdylib"]
1616

1717
[dependencies]
18+
aligned-vec = { version = "0.5.0", features = ["serde"] }
1819
bincode = "1.3.1"
1920
instant-distance = { version = "0.6", path = "../instant-distance", features = ["with-serde"] }
2021
pyo3 = { version = "0.18.0", features = ["extension-module"] }
2122
serde = { version = "1", features = ["derive"] }
22-
serde-big-array = "0.4.1"

instant-distance-py/src/lib.rs

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ use std::fs::File;
66
use std::io::{BufReader, BufWriter};
77
use std::iter::FromIterator;
88

9+
use aligned_vec::AVec;
910
use instant_distance::Point;
1011
use pyo3::conversion::IntoPy;
11-
use pyo3::exceptions::{PyTypeError, PyValueError};
12+
use pyo3::exceptions::PyValueError;
1213
use pyo3::types::{PyList, PyModule, PyString};
1314
use pyo3::{pyclass, pymethods, pymodule};
1415
use pyo3::{Py, PyAny, PyErr, PyObject, PyRef, PyRefMut, PyResult, Python};
1516
use serde::{Deserialize, Serialize};
16-
use serde_big_array::BigArray;
1717

1818
#[pymodule]
1919
#[pyo3(name = "instant_distance")]
@@ -87,8 +87,7 @@ impl HnswMap {
8787

8888
/// An instance of hierarchical navigable small worlds
8989
///
90-
/// For now, this is specialized to only support 300-element (32-bit) float vectors
91-
/// with a squared Euclidean distance metric.
90+
/// For now, this uses a squared Euclidean distance metric.
9291
#[pyclass]
9392
struct Hnsw {
9493
inner: instant_distance::Hnsw<FloatArray>,
@@ -346,35 +345,32 @@ impl Neighbor {
346345
}
347346
}
348347

349-
#[repr(align(32))]
350348
#[derive(Clone, Deserialize, Serialize)]
351-
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);
349+
struct FloatArray(AVec<f32>);
352350

353351
impl TryFrom<&PyAny> for FloatArray {
354352
type Error = PyErr;
355353

356354
fn try_from(value: &PyAny) -> Result<Self, Self::Error> {
357-
let mut new = FloatArray([0.0; DIMENSIONS]);
358-
for (i, val) in value.iter()?.enumerate() {
359-
match i >= DIMENSIONS {
360-
true => return Err(PyTypeError::new_err("point array too long")),
361-
false => new.0[i] = val?.extract::<f32>()?,
362-
}
355+
let mut new = FloatArray(AVec::with_capacity(32, value.len()?));
356+
for val in value.iter()? {
357+
new.0.push(val?.extract()?);
363358
}
364359
Ok(new)
365360
}
366361
}
367362

368363
impl Point for FloatArray {
369364
fn distance(&self, rhs: &Self) -> f32 {
365+
debug_assert_eq!(self.0.len(), rhs.0.len());
366+
370367
#[cfg(target_arch = "x86_64")]
371368
{
372369
use std::arch::x86_64::{
373370
_mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
374371
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
375372
_mm_fmadd_ps, _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
376373
};
377-
debug_assert_eq!(self.0.len() % 8, 4);
378374

379375
unsafe {
380376
let mut acc_8x = _mm256_setzero_ps();
@@ -389,16 +385,36 @@ impl Point for FloatArray {
389385
let right = _mm256_castps256_ps128(acc_8x); // lower half
390386
acc_4x = _mm_add_ps(acc_4x, right); // sum halves
391387

392-
let lh_4x = _mm_load_ps(self.0[DIMENSIONS - 4..].as_ptr());
393-
let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr());
394-
let diff = _mm_sub_ps(lh_4x, rh_4x);
395-
acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);
388+
// count of already processed dimensions
389+
let mut processed_count = self.0.len() - self.0.len() % 8;
390+
391+
if self.0.len() % 8 >= 4 {
392+
// there are 4+ dimensions to process
393+
// let's process another 4 in a batch
394+
let lh_4x = _mm_load_ps(self.0[processed_count..].as_ptr());
395+
let rh_4x = _mm_load_ps(rhs.0[processed_count..].as_ptr());
396+
let diff = _mm_sub_ps(lh_4x, rh_4x);
397+
acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);
398+
processed_count += 4;
399+
}
396400

401+
// sum up the registers
397402
let lower = _mm_movehl_ps(acc_4x, acc_4x);
398403
acc_4x = _mm_add_ps(acc_4x, lower);
399404
let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1);
400405
acc_4x = _mm_add_ss(acc_4x, upper);
401-
_mm_cvtss_f32(acc_4x)
406+
let mut distance = _mm_cvtss_f32(acc_4x);
407+
408+
// process the leftover dimensions (if any are left)
409+
if processed_count < self.0.len() {
410+
distance += self.0[processed_count..]
411+
.iter()
412+
.zip(rhs.0[processed_count..].iter())
413+
.map(|(&a, &b)| (a - b).powi(2))
414+
.sum::<f32>()
415+
}
416+
417+
distance
402418
}
403419
}
404420
#[cfg(not(target_arch = "x86_64"))]
@@ -430,5 +446,3 @@ impl IntoPy<Py<PyAny>> for &'_ MapValue {
430446
}
431447
}
432448
}
433-
434-
const DIMENSIONS: usize = 300;

0 commit comments

Comments
 (0)