-
-
Notifications
You must be signed in to change notification settings - Fork 30
Implement distance metric selection #35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,14 @@ | ||
test-python: | ||
cargo build --release | ||
cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so | ||
instant-distance-py/test/instant_distance.so: instant-distance-py/src/lib.rs | ||
RUSTFLAGS="-C target-cpu=native" cargo build --release | ||
([ -f target/release/libinstant_distance.dylib ] && cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so) || \ | ||
([ -f target/release/libinstant_distance.so ] && cp target/release/libinstant_distance.so instant-distance-py/test/instant_distance.so) | ||
|
||
test-python: instant-distance-py/test/instant_distance.so | ||
PYTHONPATH=instant-distance-py/test/ python3 -m test | ||
|
||
bench-python: instant-distance-py/test/instant_distance.so | ||
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config()' 'instant_distance.Hnsw.build(points, config)' | ||
|
||
clean: | ||
cargo clean | ||
rm -f instant-distance-py/test/instant_distance.so |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,41 +24,79 @@ fn instant_distance_py(_: Python, m: &PyModule) -> PyResult<()> { | |
m.add_class::<Search>()?; | ||
m.add_class::<Hnsw>()?; | ||
m.add_class::<HnswMap>()?; | ||
m.add_class::<DistanceMetric>()?; | ||
Ok(()) | ||
} | ||
|
||
#[pyclass] | ||
#[derive(Copy, Clone)] | ||
enum DistanceMetric { | ||
Euclid, | ||
Cosine, | ||
} | ||
|
||
impl Default for DistanceMetric { | ||
fn default() -> Self { | ||
Self::Euclid | ||
} | ||
} | ||
|
||
// Helper macro for dispatching to inner implementation | ||
macro_rules! impl_for_each_hnsw_with_metric { | ||
($type:ident, $instance:expr, $inner:ident, $($tokens:tt)+) => { | ||
match $instance { | ||
$type::Euclid($inner) => { | ||
$($tokens)+ | ||
} | ||
$type::Cosine($inner) => { | ||
$($tokens)+ | ||
} | ||
} | ||
}; | ||
} | ||
Comment on lines
+46
to
+58
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Style nit: even though we have to define the macro here before using it, I'd like to move the |
||
|
||
#[pyclass] | ||
struct HnswMap { | ||
inner: instant_distance::HnswMap<FloatArray, MapValue>, | ||
inner: HnswMapWithMetric, | ||
} | ||
|
||
#[derive(Deserialize, Serialize)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Style: in order to keep the |
||
enum HnswMapWithMetric { | ||
Euclid(instant_distance::HnswMap<FloatArray, MapValue>), | ||
Cosine(instant_distance::HnswMap<FloatArray, MapValue>), | ||
} | ||
|
||
#[pymethods] | ||
impl HnswMap { | ||
/// Build the index | ||
#[staticmethod] | ||
fn build(points: &PyList, values: &PyList, config: &Config) -> PyResult<Self> { | ||
let points = points | ||
.into_iter() | ||
.map(FloatArray::try_from) | ||
.collect::<Result<Vec<_>, PyErr>>()?; | ||
|
||
let values = values | ||
.into_iter() | ||
.map(MapValue::try_from) | ||
.collect::<Result<Vec<_>, PyErr>>()?; | ||
|
||
let hsnw_map = instant_distance::Builder::from(config).build(points, values); | ||
Ok(Self { inner: hsnw_map }) | ||
let builder = instant_distance::Builder::from(config); | ||
let inner = match config.distance_metric { | ||
DistanceMetric::Euclid => { | ||
let points = FloatArray::try_from_pylist(points)?; | ||
djc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
HnswMapWithMetric::Euclid(builder.build(points, values)) | ||
} | ||
DistanceMetric::Cosine => { | ||
let points = FloatArray::try_from_pylist(points)?; | ||
HnswMapWithMetric::Cosine(builder.build(points, values)) | ||
} | ||
}; | ||
Ok(Self { inner }) | ||
} | ||
|
||
/// Load an index from the given file name | ||
#[staticmethod] | ||
fn load(fname: &str) -> PyResult<Self> { | ||
let hnsw_map = | ||
bincode::deserialize_from::<_, instant_distance::HnswMap<FloatArray, MapValue>>( | ||
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?), | ||
) | ||
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?; | ||
let hnsw_map = bincode::deserialize_from::<_, HnswMapWithMetric>(BufReader::with_capacity( | ||
32 * 1024 * 1024, | ||
File::open(fname)?, | ||
)) | ||
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?; | ||
Ok(Self { inner: hnsw_map }) | ||
} | ||
|
||
|
@@ -78,8 +116,10 @@ impl HnswMap { | |
/// | ||
/// For best performance, reusing `Search` objects is recommended. | ||
fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> { | ||
let point = FloatArray::try_from(point)?; | ||
let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner); | ||
impl_for_each_hnsw_with_metric!(HnswMapWithMetric, &slf.try_borrow(py)?.inner, hnsw, { | ||
let point = FloatArray::try_from(point)?; | ||
let _ = hnsw.search(&point, &mut search.inner); | ||
}); | ||
search.cur = Some((HnswType::Map(slf.clone_ref(py)), 0)); | ||
Ok(()) | ||
} | ||
|
@@ -91,30 +131,44 @@ impl HnswMap { | |
/// with a squared Euclidean distance metric. | ||
#[pyclass] | ||
struct Hnsw { | ||
inner: instant_distance::Hnsw<FloatArray>, | ||
inner: HnswWithMetric, | ||
} | ||
|
||
#[derive(Deserialize, Serialize)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here. |
||
enum HnswWithMetric { | ||
Euclid(instant_distance::Hnsw<FloatArray>), | ||
Cosine(instant_distance::Hnsw<FloatArray>), | ||
} | ||
|
||
#[pymethods] | ||
impl Hnsw { | ||
/// Build the index | ||
#[staticmethod] | ||
fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec<u32>)> { | ||
let points = input | ||
.into_iter() | ||
.map(FloatArray::try_from) | ||
.collect::<Result<Vec<_>, PyErr>>()?; | ||
|
||
let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points); | ||
let builder = instant_distance::Builder::from(config); | ||
let (inner, ids) = match config.distance_metric { | ||
DistanceMetric::Euclid => { | ||
let points = FloatArray::try_from_pylist(input)?; | ||
let (hnsw, ids) = builder.build_hnsw(points); | ||
(HnswWithMetric::Euclid(hnsw), ids) | ||
} | ||
DistanceMetric::Cosine => { | ||
let points = FloatArray::try_from_pylist(input)?; | ||
let (hnsw, ids) = builder.build_hnsw(points); | ||
(HnswWithMetric::Cosine(hnsw), ids) | ||
} | ||
}; | ||
let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner())); | ||
Ok((Self { inner }, ids)) | ||
} | ||
|
||
/// Load an index from the given file name | ||
#[staticmethod] | ||
fn load(fname: &str) -> PyResult<Self> { | ||
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<FloatArray>>( | ||
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?), | ||
) | ||
let hnsw = bincode::deserialize_from::<_, HnswWithMetric>(BufReader::with_capacity( | ||
32 * 1024 * 1024, | ||
File::open(fname)?, | ||
)) | ||
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?; | ||
Ok(Self { inner: hnsw }) | ||
} | ||
|
@@ -135,8 +189,10 @@ impl Hnsw { | |
/// | ||
/// For best performance, reusing `Search` objects is recommended. | ||
fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> { | ||
let point = FloatArray::try_from(point)?; | ||
let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner); | ||
impl_for_each_hnsw_with_metric!(HnswWithMetric, &slf.try_borrow(py)?.inner, hnsw, { | ||
let point = FloatArray::try_from(point)?; | ||
let _ = hnsw.search(&point, &mut search.inner); | ||
}); | ||
search.cur = Some((HnswType::Hnsw(slf.clone_ref(py)), 0)); | ||
Ok(()) | ||
} | ||
|
@@ -175,20 +231,24 @@ impl Search { | |
let neighbor = match &index { | ||
HnswType::Hnsw(hnsw) => { | ||
let hnsw = hnsw.as_ref(py).borrow(); | ||
let item = hnsw.inner.get(idx, &slf.inner); | ||
item.map(|item| Neighbor { | ||
distance: item.distance, | ||
pid: item.pid.into_inner(), | ||
value: py.None(), | ||
impl_for_each_hnsw_with_metric!(HnswWithMetric, &hnsw.inner, hnsw, { | ||
let item = hnsw.get(idx, &slf.inner); | ||
item.map(|item| Neighbor { | ||
distance: item.distance, | ||
pid: item.pid.into_inner(), | ||
value: py.None(), | ||
}) | ||
}) | ||
} | ||
HnswType::Map(map) => { | ||
let map = map.as_ref(py).borrow(); | ||
let item = map.inner.get(idx, &slf.inner); | ||
item.map(|item| Neighbor { | ||
distance: item.distance, | ||
pid: item.pid.into_inner(), | ||
value: item.value.into_py(py), | ||
impl_for_each_hnsw_with_metric!(HnswMapWithMetric, &map.inner, map, { | ||
let item = map.get(idx, &slf.inner); | ||
item.map(|item| Neighbor { | ||
distance: item.distance, | ||
pid: item.pid.into_inner(), | ||
value: item.value.into_py(py), | ||
}) | ||
}) | ||
} | ||
}; | ||
|
@@ -226,6 +286,11 @@ struct Config { | |
/// in order to get better results on clustered data points. | ||
#[pyo3(get, set)] | ||
heuristic: Option<Heuristic>, | ||
/// Distance metric to use | ||
/// | ||
/// Defaults to Euclidean distance | ||
#[pyo3(get, set)] | ||
distance_metric: DistanceMetric, | ||
} | ||
|
||
#[pymethods] | ||
|
@@ -235,12 +300,14 @@ impl Config { | |
let builder = instant_distance::Builder::default(); | ||
let (ef_search, ef_construction, ml, seed) = builder.into_parts(); | ||
let heuristic = Some(Heuristic::default()); | ||
let distance_metric = DistanceMetric::default(); | ||
Self { | ||
ef_search, | ||
ef_construction, | ||
ml, | ||
seed, | ||
heuristic, | ||
distance_metric, | ||
} | ||
} | ||
} | ||
|
@@ -253,6 +320,7 @@ impl From<&Config> for instant_distance::Builder { | |
ml, | ||
seed, | ||
heuristic, | ||
distance_metric: _, | ||
} = *py; | ||
Self::default() | ||
.ef_search(ef_search) | ||
|
@@ -350,6 +418,12 @@ impl Neighbor { | |
#[derive(Clone, Deserialize, Serialize)] | ||
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]); | ||
|
||
impl FloatArray { | ||
fn try_from_pylist(list: &PyList) -> Result<Vec<Self>, PyErr> { | ||
list.into_iter().map(FloatArray::try_from).collect() | ||
} | ||
} | ||
|
||
impl TryFrom<&PyAny> for FloatArray { | ||
type Error = PyErr; | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given that we don't actually have a
Cosine
implementation at this point, we should not introduce this variant here until the commit where we add the backing implementation.