Skip to content

Commit bca31ad

Browse files
committed
Refactor to prepare for multiple distance metrics
1 parent f1cb9ee commit bca31ad

File tree

3 files changed

+127
-43
lines changed

3 files changed

+127
-43
lines changed

Makefile

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1-
test-python:
2-
cargo build --release
3-
cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so
1+
# Build the release library with native CPU tuning and copy it next to the
# Python tests as `instant_distance.so`. The copy tries the `.dylib` artifact
# first and falls back to the `.so` artifact; the rule fails if neither copy
# succeeds.
instant-distance-py/test/instant_distance.so: instant-distance-py/src/lib.rs
2+
RUSTFLAGS="-C target-cpu=native" cargo build --release
3+
([ -f target/release/libinstant_distance.dylib ] && cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so) || \
4+
([ -f target/release/libinstant_distance.so ] && cp target/release/libinstant_distance.so instant-distance-py/test/instant_distance.so)
5+
6+
test-python: instant-distance-py/test/instant_distance.so
47
PYTHONPATH=instant-distance-py/test/ python3 -m test
58

9+
bench-python: instant-distance-py/test/instant_distance.so
10+
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config()' 'instant_distance.Hnsw.build(points, config)'
11+
612
clean:
713
cargo clean
814
rm -f instant-distance-py/test/instant_distance.so

instant-distance-py/src/lib.rs

Lines changed: 112 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -24,41 +24,79 @@ fn instant_distance_py(_: Python, m: &PyModule) -> PyResult<()> {
2424
m.add_class::<Search>()?;
2525
m.add_class::<Hnsw>()?;
2626
m.add_class::<HnswMap>()?;
27+
m.add_class::<DistanceMetric>()?;
2728
Ok(())
2829
}
2930

31+
/// Distance metric choice exposed to Python.
///
/// Stored in `Config::distance_metric` and matched on by the `build`
/// functions to pick which index wrapper variant to construct.
#[pyclass]
32+
#[derive(Copy, Clone)]
33+
enum DistanceMetric {
34+
/// Selects the `Euclid` variant of the index wrappers; this is the
/// value returned by `Default`.
Euclid,
35+
/// Selects the `Cosine` variant of the index wrappers.
Cosine,
36+
}
37+
38+
impl Default for DistanceMetric {
39+
fn default() -> Self {
40+
Self::Euclid
41+
}
42+
}
43+
44+
// Helper macro for dispatching to inner implementation
//
// Arguments:
//   $type     - name of the wrapper enum; it must have `Euclid` and `Cosine`
//               tuple variants
//   $instance - expression evaluating to the wrapper (or a reference to it)
//   $inner    - identifier the wrapped index is bound to inside the body
//   $tokens   - body to run; it is pasted verbatim into both match arms, so
//               it must type-check against the payload of either variant
//
// The expansion is a single `match` expression, so the macro can be used both
// as a statement and as a value-producing expression.
45+
macro_rules! impl_for_each_hnsw_with_metric {
46+
($type:ident, $instance:expr, $inner:ident, $($tokens:tt)+) => {
47+
match $instance {
48+
$type::Euclid($inner) => {
49+
$($tokens)+
50+
}
51+
$type::Cosine($inner) => {
52+
$($tokens)+
53+
}
54+
}
55+
};
56+
}
57+
3058
#[pyclass]
3159
struct HnswMap {
32-
inner: instant_distance::HnswMap<FloatArray, MapValue>,
60+
inner: HnswMapWithMetric,
61+
}
62+
63+
/// Map index tagged with the distance metric chosen at build time.
///
/// This enum (not the bare `instant_distance::HnswMap`) is the type that
/// `HnswMap::load` deserializes. Both variants wrap the same underlying index
/// type here.
#[derive(Deserialize, Serialize)]
64+
enum HnswMapWithMetric {
65+
/// Index built with `DistanceMetric::Euclid`.
Euclid(instant_distance::HnswMap<FloatArray, MapValue>),
66+
/// Index built with `DistanceMetric::Cosine`.
Cosine(instant_distance::HnswMap<FloatArray, MapValue>),
3367
}
3468

3569
#[pymethods]
3670
impl HnswMap {
3771
/// Build the index
3872
#[staticmethod]
3973
fn build(points: &PyList, values: &PyList, config: &Config) -> PyResult<Self> {
40-
let points = points
41-
.into_iter()
42-
.map(FloatArray::try_from)
43-
.collect::<Result<Vec<_>, PyErr>>()?;
44-
4574
let values = values
4675
.into_iter()
4776
.map(MapValue::try_from)
4877
.collect::<Result<Vec<_>, PyErr>>()?;
49-
50-
let hsnw_map = instant_distance::Builder::from(config).build(points, values);
51-
Ok(Self { inner: hsnw_map })
78+
let builder = instant_distance::Builder::from(config);
79+
let inner = match config.distance_metric {
80+
DistanceMetric::Euclid => {
81+
let points = FloatArray::try_from_pylist(points)?;
82+
HnswMapWithMetric::Euclid(builder.build(points, values))
83+
}
84+
DistanceMetric::Cosine => {
85+
let points = FloatArray::try_from_pylist(points)?;
86+
HnswMapWithMetric::Cosine(builder.build(points, values))
87+
}
88+
};
89+
Ok(Self { inner })
5290
}
5391

5492
/// Load an index from the given file name
5593
#[staticmethod]
5694
fn load(fname: &str) -> PyResult<Self> {
57-
let hnsw_map =
58-
bincode::deserialize_from::<_, instant_distance::HnswMap<FloatArray, MapValue>>(
59-
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
60-
)
61-
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
95+
let hnsw_map = bincode::deserialize_from::<_, HnswMapWithMetric>(BufReader::with_capacity(
96+
32 * 1024 * 1024,
97+
File::open(fname)?,
98+
))
99+
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
62100
Ok(Self { inner: hnsw_map })
63101
}
64102

@@ -78,8 +116,10 @@ impl HnswMap {
78116
///
79117
/// For best performance, reusing `Search` objects is recommended.
80118
fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> {
81-
let point = FloatArray::try_from(point)?;
82-
let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner);
119+
impl_for_each_hnsw_with_metric!(HnswMapWithMetric, &slf.try_borrow(py)?.inner, hnsw, {
120+
let point = FloatArray::try_from(point)?;
121+
let _ = hnsw.search(&point, &mut search.inner);
122+
});
83123
search.cur = Some((HnswType::Map(slf.clone_ref(py)), 0));
84124
Ok(())
85125
}
@@ -91,30 +131,44 @@ impl HnswMap {
91131
/// with a squared Euclidean distance metric.
92132
#[pyclass]
93133
struct Hnsw {
94-
inner: instant_distance::Hnsw<FloatArray>,
134+
inner: HnswWithMetric,
135+
}
136+
137+
/// Plain index (no associated values) tagged with the distance metric chosen
/// at build time.
///
/// This enum (not the bare `instant_distance::Hnsw`) is the type that
/// `Hnsw::load` deserializes. Both variants wrap the same underlying index
/// type here.
#[derive(Deserialize, Serialize)]
138+
enum HnswWithMetric {
139+
/// Index built with `DistanceMetric::Euclid`.
Euclid(instant_distance::Hnsw<FloatArray>),
140+
/// Index built with `DistanceMetric::Cosine`.
Cosine(instant_distance::Hnsw<FloatArray>),
95141
}
96142

97143
#[pymethods]
98144
impl Hnsw {
99145
/// Build the index
100146
#[staticmethod]
101147
fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec<u32>)> {
102-
let points = input
103-
.into_iter()
104-
.map(FloatArray::try_from)
105-
.collect::<Result<Vec<_>, PyErr>>()?;
106-
107-
let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points);
148+
let builder = instant_distance::Builder::from(config);
149+
let (inner, ids) = match config.distance_metric {
150+
DistanceMetric::Euclid => {
151+
let points = FloatArray::try_from_pylist(input)?;
152+
let (hnsw, ids) = builder.build_hnsw(points);
153+
(HnswWithMetric::Euclid(hnsw), ids)
154+
}
155+
DistanceMetric::Cosine => {
156+
let points = FloatArray::try_from_pylist(input)?;
157+
let (hnsw, ids) = builder.build_hnsw(points);
158+
(HnswWithMetric::Cosine(hnsw), ids)
159+
}
160+
};
108161
let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner()));
109162
Ok((Self { inner }, ids))
110163
}
111164

112165
/// Load an index from the given file name
113166
#[staticmethod]
114167
fn load(fname: &str) -> PyResult<Self> {
115-
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<FloatArray>>(
116-
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
117-
)
168+
let hnsw = bincode::deserialize_from::<_, HnswWithMetric>(BufReader::with_capacity(
169+
32 * 1024 * 1024,
170+
File::open(fname)?,
171+
))
118172
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
119173
Ok(Self { inner: hnsw })
120174
}
@@ -135,8 +189,10 @@ impl Hnsw {
135189
///
136190
/// For best performance, reusing `Search` objects is recommended.
137191
fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> {
138-
let point = FloatArray::try_from(point)?;
139-
let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner);
192+
impl_for_each_hnsw_with_metric!(HnswWithMetric, &slf.try_borrow(py)?.inner, hnsw, {
193+
let point = FloatArray::try_from(point)?;
194+
let _ = hnsw.search(&point, &mut search.inner);
195+
});
140196
search.cur = Some((HnswType::Hnsw(slf.clone_ref(py)), 0));
141197
Ok(())
142198
}
@@ -175,20 +231,24 @@ impl Search {
175231
let neighbor = match &index {
176232
HnswType::Hnsw(hnsw) => {
177233
let hnsw = hnsw.as_ref(py).borrow();
178-
let item = hnsw.inner.get(idx, &slf.inner);
179-
item.map(|item| Neighbor {
180-
distance: item.distance,
181-
pid: item.pid.into_inner(),
182-
value: py.None(),
234+
impl_for_each_hnsw_with_metric!(HnswWithMetric, &hnsw.inner, hnsw, {
235+
let item = hnsw.get(idx, &slf.inner);
236+
item.map(|item| Neighbor {
237+
distance: item.distance,
238+
pid: item.pid.into_inner(),
239+
value: py.None(),
240+
})
183241
})
184242
}
185243
HnswType::Map(map) => {
186244
let map = map.as_ref(py).borrow();
187-
let item = map.inner.get(idx, &slf.inner);
188-
item.map(|item| Neighbor {
189-
distance: item.distance,
190-
pid: item.pid.into_inner(),
191-
value: item.value.into_py(py),
245+
impl_for_each_hnsw_with_metric!(HnswMapWithMetric, &map.inner, map, {
246+
let item = map.get(idx, &slf.inner);
247+
item.map(|item| Neighbor {
248+
distance: item.distance,
249+
pid: item.pid.into_inner(),
250+
value: item.value.into_py(py),
251+
})
192252
})
193253
}
194254
};
@@ -226,6 +286,11 @@ struct Config {
226286
/// in order to get better results on clustered data points.
227287
#[pyo3(get, set)]
228288
heuristic: Option<Heuristic>,
289+
/// Distance metric to use
290+
///
291+
/// Defaults to Euclidean distance
292+
#[pyo3(get, set)]
293+
distance_metric: DistanceMetric,
229294
}
230295

231296
#[pymethods]
@@ -235,12 +300,14 @@ impl Config {
235300
let builder = instant_distance::Builder::default();
236301
let (ef_search, ef_construction, ml, seed) = builder.into_parts();
237302
let heuristic = Some(Heuristic::default());
303+
let distance_metric = DistanceMetric::default();
238304
Self {
239305
ef_search,
240306
ef_construction,
241307
ml,
242308
seed,
243309
heuristic,
310+
distance_metric,
244311
}
245312
}
246313
}
@@ -253,6 +320,7 @@ impl From<&Config> for instant_distance::Builder {
253320
ml,
254321
seed,
255322
heuristic,
323+
distance_metric: _,
256324
} = *py;
257325
Self::default()
258326
.ef_search(ef_search)
@@ -350,6 +418,12 @@ impl Neighbor {
350418
#[derive(Clone, Deserialize, Serialize)]
351419
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);
352420

421+
impl FloatArray {
422+
/// Convert every element of a Python list into a `FloatArray`.
///
/// Each element goes through `FloatArray::try_from`; collecting into a
/// `Result` stops at the first element that fails and returns its `PyErr`.
fn try_from_pylist(list: &PyList) -> Result<Vec<Self>, PyErr> {
423+
list.into_iter().map(FloatArray::try_from).collect()
424+
}
425+
}
426+
353427
impl TryFrom<&PyAny> for FloatArray {
354428
type Error = PyErr;
355429

instant-distance-py/test/test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import instant_distance, random
22

33

4-
def test_hsnw():
4+
def test_hsnw(distance_metric=instant_distance.DistanceMetric.Euclid):
55
points = [[random.random() for _ in range(300)] for _ in range(1024)]
66
config = instant_distance.Config()
7+
config.distance_metric = distance_metric
78
(hnsw, ids) = instant_distance.Hnsw.build(points, config)
89
p = [random.random() for _ in range(300)]
910
search = instant_distance.Search()
@@ -12,14 +13,15 @@ def test_hsnw():
1213
print(candidate)
1314

1415

15-
def test_hsnw_map():
16+
def test_hsnw_map(distance_metric=instant_distance.DistanceMetric.Euclid):
1617
the_chosen_one = 123
1718

1819
embeddings = [[random.random() for _ in range(300)] for _ in range(1024)]
1920
with open("/usr/share/dict/words", "r") as f: # *nix only
2021
values = f.read().splitlines()[1024:]
2122

2223
config = instant_distance.Config()
24+
config.distance_metric = distance_metric
2325
hnsw_map = instant_distance.HnswMap.build(embeddings, values, config)
2426

2527
search = instant_distance.Search()
@@ -38,3 +40,5 @@ def test_hsnw_map():
3840
if __name__ == "__main__":
3941
test_hsnw()
4042
test_hsnw_map()
43+
test_hsnw(instant_distance.DistanceMetric.Cosine)
44+
test_hsnw_map(instant_distance.DistanceMetric.Cosine)

0 commit comments

Comments
 (0)