Skip to content

Commit c0dc4de

Browse files
committed
Introduce PointStorage and support variable length
1 parent a048691 commit c0dc4de

File tree

3 files changed

+174
-55
lines changed

3 files changed

+174
-55
lines changed

instant-distance-py/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ crate-type = ["cdylib", "lib"]
1818
name = "instant_distance"
1919

2020
[dependencies]
21+
aligned-vec = { version = "0.5.0", features = ["serde"] }
2122
bincode = "1.3.1"
2223
instant-distance = { version = "0.6", path = "../instant-distance", features = ["with-serde"] }
2324
pyo3 = { version = "0.18.0", features = ["extension-module"] }
2425
serde = { version = "1", features = ["derive"] }
25-
serde-big-array = "0.5.0"
2626

2727
[dev-dependencies]
2828
bencher = "0.1.5"

instant-distance-py/benches/all.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1+
use aligned_vec::avec;
12
use bencher::{benchmark_group, benchmark_main, Bencher};
23

34
use instant_distance::{Builder, Metric, Search};
4-
use instant_distance_py::{EuclidMetric, FloatArray};
5+
use instant_distance_py::{EuclidMetric, PointStorage};
56
use rand::{rngs::StdRng, Rng, SeedableRng};
67

78
benchmark_main!(benches);
89
benchmark_group!(benches, distance, build, query);
910

1011
fn distance(bench: &mut Bencher) {
1112
let mut rng = StdRng::seed_from_u64(SEED);
12-
let point_a = FloatArray([rng.gen(); 300]);
13-
let point_b = FloatArray([rng.gen(); 300]);
13+
let point_a = avec![rng.gen(); 304];
14+
let point_b = avec![rng.gen(); 304];
1415

1516
bench.iter(|| EuclidMetric::distance(&point_a, &point_b));
1617
}
@@ -19,26 +20,26 @@ fn build(bench: &mut Bencher) {
1920
let mut rng = StdRng::seed_from_u64(SEED);
2021
let points = (0..1024)
2122
.into_iter()
22-
.map(|_| FloatArray([rng.gen(); 300]))
23+
.map(|_| vec![rng.gen(); 304])
2324
.collect::<Vec<_>>();
2425

2526
bench.iter(|| {
2627
Builder::default()
2728
.seed(SEED)
28-
.build_hnsw::<_, _, EuclidMetric, Vec<FloatArray>>(points.clone())
29+
.build_hnsw::<Vec<f32>, [f32], EuclidMetric, PointStorage>(points.clone())
2930
});
3031
}
3132

3233
fn query(bench: &mut Bencher) {
3334
let mut rng = StdRng::seed_from_u64(SEED);
3435
let points = (0..1024)
3536
.into_iter()
36-
.map(|_| FloatArray([rng.gen(); 300]))
37+
.map(|_| vec![rng.gen(); 304])
3738
.collect::<Vec<_>>();
3839
let (hnsw, _) = Builder::default()
3940
.seed(SEED)
40-
.build_hnsw::<_, _, EuclidMetric, Vec<FloatArray>>(points);
41-
let point = FloatArray([rng.gen(); 300]);
41+
.build_hnsw::<Vec<f32>, [f32], EuclidMetric, PointStorage>(points);
42+
let point = avec![rng.gen(); 304];
4243

4344
bench.iter(|| {
4445
let mut search = Search::default();

instant-distance-py/src/lib.rs

Lines changed: 164 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,17 @@
44
use std::convert::TryFrom;
55
use std::fs::File;
66
use std::io::{BufReader, BufWriter};
7-
use std::iter::FromIterator;
7+
use std::iter::{repeat, FromIterator};
8+
use std::ops::Index;
89

9-
use instant_distance::Metric;
10+
use aligned_vec::{AVec, ConstAlign};
11+
use instant_distance::{Len, Metric, PointId};
1012
use pyo3::conversion::IntoPy;
11-
use pyo3::exceptions::{PyTypeError, PyValueError};
13+
use pyo3::exceptions::PyValueError;
1214
use pyo3::types::{PyList, PyModule, PyString};
1315
use pyo3::{pyclass, pymethods, pymodule};
1416
use pyo3::{Py, PyAny, PyErr, PyObject, PyRef, PyRefMut, PyResult, Python};
1517
use serde::{Deserialize, Serialize};
16-
use serde_big_array::BigArray;
1718

1819
#[pymodule]
1920
#[pyo3(name = "instant_distance")]
@@ -29,7 +30,7 @@ fn instant_distance_py(_: Python, m: &PyModule) -> PyResult<()> {
2930

3031
#[pyclass]
3132
struct HnswMap {
32-
inner: instant_distance::HnswMap<FloatArray, EuclidMetric, MapValue, Vec<FloatArray>>,
33+
inner: instant_distance::HnswMap<[f32], EuclidMetric, MapValue, PointStorage>,
3334
}
3435

3536
#[pymethods]
@@ -39,24 +40,34 @@ impl HnswMap {
3940
fn build(points: &PyList, values: &PyList, config: &Config) -> PyResult<Self> {
4041
let points = points
4142
.into_iter()
42-
.map(FloatArray::try_from)
43+
.map(|v| {
44+
v.iter()?
45+
.into_iter()
46+
.map(|x| x?.extract())
47+
.collect::<Result<Vec<_>, PyErr>>()
48+
})
4349
.collect::<Result<Vec<_>, PyErr>>()?;
4450

4551
let values = values
4652
.into_iter()
4753
.map(MapValue::try_from)
4854
.collect::<Result<Vec<_>, PyErr>>()?;
4955

50-
let hsnw_map = instant_distance::Builder::from(config).build(points, values);
56+
let hsnw_map = instant_distance::Builder::from(config)
57+
.build::<Vec<_>, [f32], EuclidMetric, MapValue, PointStorage>(points, values);
5158
Ok(Self { inner: hsnw_map })
5259
}
5360

5461
/// Load an index from the given file name
5562
#[staticmethod]
5663
fn load(fname: &str) -> PyResult<Self> {
57-
let hnsw_map = bincode::deserialize_from::<_, instant_distance::HnswMap<_, _, _, _>>(
58-
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
59-
)
64+
let hnsw_map = bincode::deserialize_from::<
65+
_,
66+
instant_distance::HnswMap<[f32], EuclidMetric, MapValue, PointStorage>,
67+
>(BufReader::with_capacity(
68+
32 * 1024 * 1024,
69+
File::open(fname)?,
70+
))
6071
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
6172
Ok(Self { inner: hnsw_map })
6273
}
@@ -77,7 +88,7 @@ impl HnswMap {
7788
///
7889
/// For best performance, reusing `Search` objects is recommended.
7990
fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> {
80-
let point = FloatArray::try_from(point)?;
91+
let point = try_avec_from_py(point)?;
8192
let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner);
8293
search.cur = Some((HnswType::Map(slf.clone_ref(py)), 0));
8394
Ok(())
@@ -90,7 +101,7 @@ impl HnswMap {
90101
/// with a squared Euclidean distance metric.
91102
#[pyclass]
92103
struct Hnsw {
93-
inner: instant_distance::Hnsw<FloatArray, EuclidMetric, Vec<FloatArray>>,
104+
inner: instant_distance::Hnsw<[f32], EuclidMetric, PointStorage>,
94105
}
95106

96107
#[pymethods]
@@ -100,20 +111,30 @@ impl Hnsw {
100111
fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec<u32>)> {
101112
let points = input
102113
.into_iter()
103-
.map(FloatArray::try_from)
114+
.map(|v| {
115+
v.iter()?
116+
.into_iter()
117+
.map(|x| x?.extract())
118+
.collect::<Result<Vec<_>, PyErr>>()
119+
})
104120
.collect::<Result<Vec<_>, PyErr>>()?;
105121

106-
let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points);
122+
let (inner, ids) = instant_distance::Builder::from(config)
123+
.build_hnsw::<Vec<f32>, [f32], EuclidMetric, PointStorage>(points);
107124
let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner()));
108125
Ok((Self { inner }, ids))
109126
}
110127

111128
/// Load an index from the given file name
112129
#[staticmethod]
113130
fn load(fname: &str) -> PyResult<Self> {
114-
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<_, _, _>>(
115-
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
116-
)
131+
let hnsw = bincode::deserialize_from::<
132+
_,
133+
instant_distance::Hnsw<[f32], EuclidMetric, PointStorage>,
134+
>(BufReader::with_capacity(
135+
32 * 1024 * 1024,
136+
File::open(fname)?,
137+
))
117138
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
118139
Ok(Self { inner: hnsw })
119140
}
@@ -134,7 +155,7 @@ impl Hnsw {
134155
///
135156
/// For best performance, reusing `Search` objects is recommended.
136157
fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> {
137-
let point = FloatArray::try_from(point)?;
158+
let point = try_avec_from_py(point)?;
138159
let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner);
139160
search.cur = Some((HnswType::Hnsw(slf.clone_ref(py)), 0));
140161
Ok(())
@@ -144,7 +165,7 @@ impl Hnsw {
144165
/// Search buffer and result set
145166
#[pyclass]
146167
struct Search {
147-
inner: instant_distance::Search<FloatArray, EuclidMetric>,
168+
inner: instant_distance::Search<[f32], EuclidMetric>,
148169
cur: Option<(HnswType, usize)>,
149170
}
150171

@@ -345,42 +366,36 @@ impl Neighbor {
345366
}
346367
}
347368

348-
#[repr(align(32))]
349-
#[derive(Clone, Deserialize, Serialize)]
350-
pub struct FloatArray(#[serde(with = "BigArray")] pub [f32; DIMENSIONS]);
351-
352-
impl TryFrom<&PyAny> for FloatArray {
353-
type Error = PyErr;
354-
355-
fn try_from(value: &PyAny) -> Result<Self, Self::Error> {
356-
let mut new = FloatArray([0.0; DIMENSIONS]);
357-
for (i, val) in value.iter()?.enumerate() {
358-
match i >= DIMENSIONS {
359-
true => return Err(PyTypeError::new_err("point array too long")),
360-
false => new.0[i] = val?.extract::<f32>()?,
361-
}
362-
}
363-
Ok(new)
369+
fn try_avec_from_py(value: &PyAny) -> Result<AVec<f32, ConstAlign<ALIGNMENT>>, PyErr> {
370+
let mut new = AVec::new(ALIGNMENT);
371+
for val in value.iter()? {
372+
new.push(val?.extract::<f32>()?);
364373
}
374+
for _ in 0..PointStorage::padding(new.len()) {
375+
new.push(0.0);
376+
}
377+
Ok(new)
365378
}
366379

367380
#[derive(Clone, Copy, Deserialize, Serialize)]
368381
pub struct EuclidMetric;
369382

370-
impl Metric<FloatArray> for EuclidMetric {
371-
fn distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
383+
impl Metric<[f32]> for EuclidMetric {
384+
fn distance(lhs: &[f32], rhs: &[f32]) -> f32 {
385+
debug_assert_eq!(lhs.len(), rhs.len());
386+
372387
#[cfg(target_arch = "x86_64")]
373388
{
374389
use std::arch::x86_64::{
375390
_mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
376391
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
377-
_mm_fmadd_ps, _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
392+
_mm_movehl_ps, _mm_shuffle_ps,
378393
};
379-
debug_assert_eq!(lhs.0.len() % 8, 4);
394+
debug_assert_eq!(lhs.len() % 8, 0);
380395

381396
unsafe {
382397
let mut acc_8x = _mm256_setzero_ps();
383-
for (lh_slice, rh_slice) in lhs.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
398+
for (lh_slice, rh_slice) in lhs.chunks_exact(8).zip(rhs.chunks_exact(8)) {
384399
let lh_8x = _mm256_load_ps(lh_slice.as_ptr());
385400
let rh_8x = _mm256_load_ps(rh_slice.as_ptr());
386401
let diff = _mm256_sub_ps(lh_8x, rh_8x);
@@ -391,11 +406,6 @@ impl Metric<FloatArray> for EuclidMetric {
391406
let right = _mm256_castps256_ps128(acc_8x); // lower half
392407
acc_4x = _mm_add_ps(acc_4x, right); // sum halves
393408

394-
let lh_4x = _mm_load_ps(lhs.0[DIMENSIONS - 4..].as_ptr());
395-
let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr());
396-
let diff = _mm_sub_ps(lh_4x, rh_4x);
397-
acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);
398-
399409
let lower = _mm_movehl_ps(acc_4x, acc_4x);
400410
acc_4x = _mm_add_ps(acc_4x, lower);
401411
let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1);
@@ -412,6 +422,114 @@ impl Metric<FloatArray> for EuclidMetric {
412422
}
413423
}
414424

425+
#[derive(Debug, Deserialize, Serialize)]
426+
pub struct PointStorage {
427+
point_len: usize,
428+
points_data: AVec<f32>,
429+
}
430+
431+
impl PointStorage {
432+
const fn padding(len: usize) -> usize {
433+
let floats_per_alignment = ALIGNMENT / std::mem::size_of::<f32>();
434+
match len % floats_per_alignment {
435+
0 => 0,
436+
floats_over_alignment => floats_per_alignment - floats_over_alignment,
437+
}
438+
}
439+
440+
pub fn iter(&self) -> impl Iterator<Item = &[f32]> {
441+
self.points_data.chunks_exact(self.point_len)
442+
}
443+
}
444+
445+
impl Default for PointStorage {
446+
fn default() -> Self {
447+
Self {
448+
point_len: 1,
449+
points_data: AVec::new(ALIGNMENT),
450+
}
451+
}
452+
}
453+
454+
impl Index<usize> for PointStorage {
455+
type Output = [f32];
456+
457+
fn index(&self, index: usize) -> &Self::Output {
458+
let raw_idx = index * self.point_len;
459+
&self.points_data[raw_idx..(raw_idx + self.point_len)]
460+
}
461+
}
462+
463+
impl Index<PointId> for PointStorage {
464+
type Output = [f32];
465+
466+
fn index(&self, index: PointId) -> &Self::Output {
467+
self.index(index.into_inner() as usize)
468+
}
469+
}
470+
471+
impl From<Vec<Vec<f32>>> for PointStorage {
472+
fn from(value: Vec<Vec<f32>>) -> Self {
473+
if let Some(point) = value.first() {
474+
let point_len = point.len();
475+
let padding = PointStorage::padding(point_len);
476+
let mut points_data =
477+
AVec::with_capacity(ALIGNMENT, value.len() * (point_len + padding));
478+
for point in value {
479+
// all points should have the same length
480+
debug_assert_eq!(point.len(), point_len);
481+
for v in point.into_iter().chain(repeat(0.0).take(padding)) {
482+
points_data.push(v);
483+
}
484+
}
485+
Self {
486+
point_len: point_len + padding,
487+
points_data,
488+
}
489+
} else {
490+
Default::default()
491+
}
492+
}
493+
}
494+
495+
impl Len for PointStorage {
496+
fn len(&self) -> usize {
497+
self.points_data.len() / self.point_len
498+
}
499+
}
500+
501+
impl<'a> IntoIterator for &'a PointStorage {
502+
type Item = &'a [f32];
503+
504+
type IntoIter = PointStorageIterator<'a>;
505+
506+
fn into_iter(self) -> Self::IntoIter {
507+
PointStorageIterator {
508+
storage: self,
509+
next_idx: 0,
510+
}
511+
}
512+
}
513+
514+
pub struct PointStorageIterator<'a> {
515+
storage: &'a PointStorage,
516+
next_idx: usize,
517+
}
518+
519+
impl<'a> Iterator for PointStorageIterator<'a> {
520+
type Item = &'a [f32];
521+
522+
fn next(&mut self) -> Option<Self::Item> {
523+
if self.next_idx < self.storage.len() {
524+
let result = &self.storage[self.next_idx];
525+
self.next_idx += 1;
526+
Some(result)
527+
} else {
528+
None
529+
}
530+
}
531+
}
532+
415533
#[derive(Clone, Deserialize, Serialize)]
416534
enum MapValue {
417535
String(String),
@@ -433,4 +551,4 @@ impl IntoPy<Py<PyAny>> for &'_ MapValue {
433551
}
434552
}
435553

436-
const DIMENSIONS: usize = 300;
554+
const ALIGNMENT: usize = 32;

0 commit comments

Comments
 (0)