Skip to content

Commit 52077ec

Browse files
committed
Add custom storage support
1 parent acf86a8 commit 52077ec

File tree

7 files changed

+150
-66
lines changed

7 files changed

+150
-66
lines changed

instant-distance-py/benches/all.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@ fn build(bench: &mut Bencher) {
2222
.map(|_| FloatArray([rng.gen(); 300]))
2323
.collect::<Vec<_>>();
2424

25-
bench.iter(|| Builder::default().seed(SEED).build_hnsw(points.clone()));
25+
bench.iter(|| {
26+
Builder::default()
27+
.seed(SEED)
28+
.build_hnsw::<_, _, Vec<FloatArray>>(points.clone())
29+
});
2630
}
2731

2832
fn query(bench: &mut Bencher) {
@@ -31,7 +35,9 @@ fn query(bench: &mut Bencher) {
3135
.into_iter()
3236
.map(|_| FloatArray([rng.gen(); 300]))
3337
.collect::<Vec<_>>();
34-
let (hnsw, _) = Builder::default().seed(SEED).build_hnsw(points);
38+
let (hnsw, _) = Builder::default()
39+
.seed(SEED)
40+
.build_hnsw::<_, _, Vec<FloatArray>>(points);
3541
let point = FloatArray([rng.gen(); 300]);
3642

3743
bench.iter(|| {

instant-distance-py/src/lib.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ fn instant_distance_py(_: Python, m: &PyModule) -> PyResult<()> {
2929

3030
#[pyclass]
3131
struct HnswMap {
32-
inner: instant_distance::HnswMap<FloatArray, MapValue>,
32+
inner: instant_distance::HnswMap<FloatArray, MapValue, Vec<FloatArray>>,
3333
}
3434

3535
#[pymethods]
@@ -54,11 +54,10 @@ impl HnswMap {
5454
/// Load an index from the given file name
5555
#[staticmethod]
5656
fn load(fname: &str) -> PyResult<Self> {
57-
let hnsw_map =
58-
bincode::deserialize_from::<_, instant_distance::HnswMap<FloatArray, MapValue>>(
59-
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
60-
)
61-
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
57+
let hnsw_map = bincode::deserialize_from::<_, instant_distance::HnswMap<_, _, _>>(
58+
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
59+
)
60+
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
6261
Ok(Self { inner: hnsw_map })
6362
}
6463

@@ -91,7 +90,7 @@ impl HnswMap {
9190
/// with a squared Euclidean distance metric.
9291
#[pyclass]
9392
struct Hnsw {
94-
inner: instant_distance::Hnsw<FloatArray>,
93+
inner: instant_distance::Hnsw<FloatArray, Vec<FloatArray>>,
9594
}
9695

9796
#[pymethods]
@@ -112,7 +111,7 @@ impl Hnsw {
112111
/// Load an index from the given file name
113112
#[staticmethod]
114113
fn load(fname: &str) -> PyResult<Self> {
115-
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<FloatArray>>(
114+
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<_, _>>(
116115
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
117116
)
118117
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
@@ -145,7 +144,7 @@ impl Hnsw {
145144
/// Search buffer and result set
146145
#[pyclass]
147146
struct Search {
148-
inner: instant_distance::Search,
147+
inner: instant_distance::Search<FloatArray>,
149148
cur: Option<(HnswType, usize)>,
150149
}
151150

instant-distance/benches/all.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@ fn build_heuristic(bench: &mut Bencher) {
1414
.map(|_| Point(rng.gen(), rng.gen()))
1515
.collect::<Vec<_>>();
1616

17-
bench.iter(|| Builder::default().seed(SEED).build_hnsw(points.clone()))
17+
bench.iter(|| {
18+
Builder::default()
19+
.seed(SEED)
20+
.build_hnsw::<Point, Point, Vec<Point>>(points.clone())
21+
})
1822
}
1923

2024
const SEED: u64 = 123456789;

instant-distance/examples/colors.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ fn main() {
44
let points = vec![Point(255, 0, 0), Point(0, 255, 0), Point(0, 0, 255)];
55
let values = vec!["red", "green", "blue"];
66

7-
let map = Builder::default().build(points, values);
7+
let map = Builder::default().build::<Point, Point, &str, Vec<Point>>(points, values);
88
let mut search = Search::default();
99

1010
let burnt_orange = Point(204, 85, 0);

0 commit comments

Comments
 (0)