Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
test-python:
cargo build --release
cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so
# Build the release library with `-C target-cpu=native` and copy it into the
# Python test directory as instant_distance.so. The cargo output is looked up
# under both the .dylib and the .so file name; the first one found is copied.
# Rebuilt whenever instant-distance-py/src/lib.rs changes.
instant-distance-py/test/instant_distance.so: instant-distance-py/src/lib.rs
RUSTFLAGS="-C target-cpu=native" cargo build --release
([ -f target/release/libinstant_distance.dylib ] && cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so) || \
([ -f target/release/libinstant_distance.so ] && cp target/release/libinstant_distance.so instant-distance-py/test/instant_distance.so)

# Run the Python test module with the copied extension module on PYTHONPATH.
test-python: instant-distance-py/test/instant_distance.so
PYTHONPATH=instant-distance-py/test/ python3 -m test

# Benchmark Hnsw.build with timeit: the setup creates 1024 random points of
# 300 floats each plus a default Config, and the build statement is executed
# 10 times per timing run (-n 10).
bench-python: instant-distance-py/test/instant_distance.so
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config()' 'instant_distance.Hnsw.build(points, config)'

# Remove cargo's build output and the extension module copied into the test directory.
clean:
cargo clean
rm -f instant-distance-py/test/instant_distance.so
150 changes: 112 additions & 38 deletions instant-distance-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,41 +24,79 @@ fn instant_distance_py(_: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Search>()?;
m.add_class::<Hnsw>()?;
m.add_class::<HnswMap>()?;
m.add_class::<DistanceMetric>()?;
Ok(())
}

#[pyclass]
#[derive(Copy, Clone)]
enum DistanceMetric {
Euclid,
Cosine,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that we don't actually have a Cosine implementation at this point, we should not introduce this variant here until the commit where we add the backing implementation.

}

impl Default for DistanceMetric {
fn default() -> Self {
Self::Euclid
}
}

/// Runs one block of code against whichever metric variant a wrapper enum holds.
///
/// Arguments, in order:
/// * the wrapper enum's name (it must have `Euclid` and `Cosine` tuple variants),
/// * the expression to match on,
/// * the identifier the variant's payload is bound to inside the body,
/// * the body tokens.
///
/// The body is expanded once per match arm, so it is type-checked separately
/// for each variant's payload type.
macro_rules! impl_for_each_hnsw_with_metric {
    ($wrapper:ident, $subject:expr, $bound:ident, $($body:tt)+) => {
        match $subject {
            $wrapper::Euclid($bound) => { $($body)+ }
            $wrapper::Cosine($bound) => { $($body)+ }
        }
    };
}
Comment on lines +46 to +58
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style nit: even though we have to define the macro here before using it, I'd like to move the DistanceMetric closer to the bottom of the module (probably below MapValue?), to fit it in with the top-down order.


#[pyclass]
struct HnswMap {
inner: instant_distance::HnswMap<FloatArray, MapValue>,
inner: HnswMapWithMetric,
}

#[derive(Deserialize, Serialize)]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style: in order to keep the struct HnswMap close to the impl below, I'd prefer to move the HnswMapWithMetric below the HnswMap impl block (above the Hnsw).

/// A map-style index tagged with the distance metric it was built for.
///
/// NOTE(review): both variants currently wrap the same
/// `instant_distance::HnswMap<FloatArray, MapValue>` type.
enum HnswMapWithMetric {
/// Index built when the config selects `DistanceMetric::Euclid`.
Euclid(instant_distance::HnswMap<FloatArray, MapValue>),
/// Index built when the config selects `DistanceMetric::Cosine`.
Cosine(instant_distance::HnswMap<FloatArray, MapValue>),
}

#[pymethods]
impl HnswMap {
/// Build the index
#[staticmethod]
fn build(points: &PyList, values: &PyList, config: &Config) -> PyResult<Self> {
let points = points
.into_iter()
.map(FloatArray::try_from)
.collect::<Result<Vec<_>, PyErr>>()?;

let values = values
.into_iter()
.map(MapValue::try_from)
.collect::<Result<Vec<_>, PyErr>>()?;

let hsnw_map = instant_distance::Builder::from(config).build(points, values);
Ok(Self { inner: hsnw_map })
let builder = instant_distance::Builder::from(config);
let inner = match config.distance_metric {
DistanceMetric::Euclid => {
let points = FloatArray::try_from_pylist(points)?;
HnswMapWithMetric::Euclid(builder.build(points, values))
}
DistanceMetric::Cosine => {
let points = FloatArray::try_from_pylist(points)?;
HnswMapWithMetric::Cosine(builder.build(points, values))
}
};
Ok(Self { inner })
}

/// Load an index from the given file name
#[staticmethod]
fn load(fname: &str) -> PyResult<Self> {
let hnsw_map =
bincode::deserialize_from::<_, instant_distance::HnswMap<FloatArray, MapValue>>(
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
)
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
let hnsw_map = bincode::deserialize_from::<_, HnswMapWithMetric>(BufReader::with_capacity(
32 * 1024 * 1024,
File::open(fname)?,
))
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
Ok(Self { inner: hnsw_map })
}

Expand All @@ -78,8 +116,10 @@ impl HnswMap {
///
/// For best performance, reusing `Search` objects is recommended.
fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> {
let point = FloatArray::try_from(point)?;
let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner);
impl_for_each_hnsw_with_metric!(HnswMapWithMetric, &slf.try_borrow(py)?.inner, hnsw, {
let point = FloatArray::try_from(point)?;
let _ = hnsw.search(&point, &mut search.inner);
});
search.cur = Some((HnswType::Map(slf.clone_ref(py)), 0));
Ok(())
}
Expand All @@ -91,30 +131,44 @@ impl HnswMap {
/// with a squared Euclidean distance metric.
#[pyclass]
struct Hnsw {
inner: instant_distance::Hnsw<FloatArray>,
inner: HnswWithMetric,
}

#[derive(Deserialize, Serialize)]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

/// A plain (value-less) index tagged with the distance metric it was built for.
///
/// NOTE(review): both variants currently wrap the same
/// `instant_distance::Hnsw<FloatArray>` type.
enum HnswWithMetric {
/// Index built when the config selects `DistanceMetric::Euclid`.
Euclid(instant_distance::Hnsw<FloatArray>),
/// Index built when the config selects `DistanceMetric::Cosine`.
Cosine(instant_distance::Hnsw<FloatArray>),
}

#[pymethods]
impl Hnsw {
/// Build the index
#[staticmethod]
fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec<u32>)> {
let points = input
.into_iter()
.map(FloatArray::try_from)
.collect::<Result<Vec<_>, PyErr>>()?;

let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points);
let builder = instant_distance::Builder::from(config);
let (inner, ids) = match config.distance_metric {
DistanceMetric::Euclid => {
let points = FloatArray::try_from_pylist(input)?;
let (hnsw, ids) = builder.build_hnsw(points);
(HnswWithMetric::Euclid(hnsw), ids)
}
DistanceMetric::Cosine => {
let points = FloatArray::try_from_pylist(input)?;
let (hnsw, ids) = builder.build_hnsw(points);
(HnswWithMetric::Cosine(hnsw), ids)
}
};
let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner()));
Ok((Self { inner }, ids))
}

/// Load an index from the given file name
#[staticmethod]
fn load(fname: &str) -> PyResult<Self> {
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<FloatArray>>(
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
)
let hnsw = bincode::deserialize_from::<_, HnswWithMetric>(BufReader::with_capacity(
32 * 1024 * 1024,
File::open(fname)?,
))
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
Ok(Self { inner: hnsw })
}
Expand All @@ -135,8 +189,10 @@ impl Hnsw {
///
/// For best performance, reusing `Search` objects is recommended.
fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> {
let point = FloatArray::try_from(point)?;
let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner);
impl_for_each_hnsw_with_metric!(HnswWithMetric, &slf.try_borrow(py)?.inner, hnsw, {
let point = FloatArray::try_from(point)?;
let _ = hnsw.search(&point, &mut search.inner);
});
search.cur = Some((HnswType::Hnsw(slf.clone_ref(py)), 0));
Ok(())
}
Expand Down Expand Up @@ -175,20 +231,24 @@ impl Search {
let neighbor = match &index {
HnswType::Hnsw(hnsw) => {
let hnsw = hnsw.as_ref(py).borrow();
let item = hnsw.inner.get(idx, &slf.inner);
item.map(|item| Neighbor {
distance: item.distance,
pid: item.pid.into_inner(),
value: py.None(),
impl_for_each_hnsw_with_metric!(HnswWithMetric, &hnsw.inner, hnsw, {
let item = hnsw.get(idx, &slf.inner);
item.map(|item| Neighbor {
distance: item.distance,
pid: item.pid.into_inner(),
value: py.None(),
})
})
}
HnswType::Map(map) => {
let map = map.as_ref(py).borrow();
let item = map.inner.get(idx, &slf.inner);
item.map(|item| Neighbor {
distance: item.distance,
pid: item.pid.into_inner(),
value: item.value.into_py(py),
impl_for_each_hnsw_with_metric!(HnswMapWithMetric, &map.inner, map, {
let item = map.get(idx, &slf.inner);
item.map(|item| Neighbor {
distance: item.distance,
pid: item.pid.into_inner(),
value: item.value.into_py(py),
})
})
}
};
Expand Down Expand Up @@ -226,6 +286,11 @@ struct Config {
/// in order to get better results on clustered data points.
#[pyo3(get, set)]
heuristic: Option<Heuristic>,
/// Distance metric to use
///
/// Defaults to Euclidean distance
#[pyo3(get, set)]
distance_metric: DistanceMetric,
}

#[pymethods]
Expand All @@ -235,12 +300,14 @@ impl Config {
let builder = instant_distance::Builder::default();
let (ef_search, ef_construction, ml, seed) = builder.into_parts();
let heuristic = Some(Heuristic::default());
let distance_metric = DistanceMetric::default();
Self {
ef_search,
ef_construction,
ml,
seed,
heuristic,
distance_metric,
}
}
}
Expand All @@ -253,6 +320,7 @@ impl From<&Config> for instant_distance::Builder {
ml,
seed,
heuristic,
distance_metric: _,
} = *py;
Self::default()
.ef_search(ef_search)
Expand Down Expand Up @@ -350,6 +418,12 @@ impl Neighbor {
#[derive(Clone, Deserialize, Serialize)]
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);

impl FloatArray {
fn try_from_pylist(list: &PyList) -> Result<Vec<Self>, PyErr> {
list.into_iter().map(FloatArray::try_from).collect()
}
}

impl TryFrom<&PyAny> for FloatArray {
type Error = PyErr;

Expand Down
8 changes: 6 additions & 2 deletions instant-distance-py/test/test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import instant_distance, random


def test_hsnw():
def test_hsnw(distance_metric=instant_distance.DistanceMetric.Euclid):
points = [[random.random() for _ in range(300)] for _ in range(1024)]
config = instant_distance.Config()
config.distance_metric = distance_metric
(hnsw, ids) = instant_distance.Hnsw.build(points, config)
p = [random.random() for _ in range(300)]
search = instant_distance.Search()
Expand All @@ -12,14 +13,15 @@ def test_hsnw():
print(candidate)


def test_hsnw_map():
def test_hsnw_map(distance_metric=instant_distance.DistanceMetric.Euclid):
the_chosen_one = 123

embeddings = [[random.random() for _ in range(300)] for _ in range(1024)]
with open("/usr/share/dict/words", "r") as f: # *nix only
values = f.read().splitlines()[1024:]

config = instant_distance.Config()
config.distance_metric = distance_metric
hnsw_map = instant_distance.HnswMap.build(embeddings, values, config)

search = instant_distance.Search()
Expand All @@ -38,3 +40,5 @@ def test_hsnw_map():
if __name__ == "__main__":
test_hsnw()
test_hsnw_map()
test_hsnw(instant_distance.DistanceMetric.Cosine)
test_hsnw_map(instant_distance.DistanceMetric.Cosine)