Skip to content

Commit 11b5626

Browse files
committed
Add custom storage support
1 parent 53fbb45 commit 11b5626

File tree

7 files changed

+150
-66
lines changed

7 files changed

+150
-66
lines changed

instant-distance-py/benches/all.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,21 @@ fn build(bench: &mut Bencher) {
2121
.map(|_| FloatArray([rng.gen(); 300]))
2222
.collect::<Vec<_>>();
2323

24-
bench.iter(|| Builder::default().seed(SEED).build_hnsw(points.clone()));
24+
bench.iter(|| {
25+
Builder::default()
26+
.seed(SEED)
27+
.build_hnsw::<_, _, Vec<FloatArray>>(points.clone())
28+
});
2529
}
2630

2731
fn query(bench: &mut Bencher) {
2832
let mut rng = StdRng::seed_from_u64(SEED);
2933
let points = (0..1024)
3034
.map(|_| FloatArray([rng.gen(); 300]))
3135
.collect::<Vec<_>>();
32-
let (hnsw, _) = Builder::default().seed(SEED).build_hnsw(points);
36+
let (hnsw, _) = Builder::default()
37+
.seed(SEED)
38+
.build_hnsw::<_, _, Vec<FloatArray>>(points);
3339
let point = FloatArray([rng.gen(); 300]);
3440

3541
bench.iter(|| {

instant-distance-py/src/lib.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ fn instant_distance_py(_: Python, m: &PyModule) -> PyResult<()> {
2929

3030
#[pyclass]
3131
struct HnswMap {
32-
inner: instant_distance::HnswMap<FloatArray, MapValue>,
32+
inner: instant_distance::HnswMap<FloatArray, MapValue, Vec<FloatArray>>,
3333
}
3434

3535
#[pymethods]
@@ -54,11 +54,10 @@ impl HnswMap {
5454
/// Load an index from the given file name
5555
#[staticmethod]
5656
fn load(fname: &str) -> PyResult<Self> {
57-
let hnsw_map =
58-
bincode::deserialize_from::<_, instant_distance::HnswMap<FloatArray, MapValue>>(
59-
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
60-
)
61-
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
57+
let hnsw_map = bincode::deserialize_from::<_, instant_distance::HnswMap<_, _, _>>(
58+
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
59+
)
60+
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
6261
Ok(Self { inner: hnsw_map })
6362
}
6463

@@ -91,7 +90,7 @@ impl HnswMap {
9190
/// with a squared Euclidean distance metric.
9291
#[pyclass]
9392
struct Hnsw {
94-
inner: instant_distance::Hnsw<FloatArray>,
93+
inner: instant_distance::Hnsw<FloatArray, Vec<FloatArray>>,
9594
}
9695

9796
#[pymethods]
@@ -112,7 +111,7 @@ impl Hnsw {
112111
/// Load an index from the given file name
113112
#[staticmethod]
114113
fn load(fname: &str) -> PyResult<Self> {
115-
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<FloatArray>>(
114+
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<_, _>>(
116115
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
117116
)
118117
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
@@ -145,7 +144,7 @@ impl Hnsw {
145144
/// Search buffer and result set
146145
#[pyclass]
147146
struct Search {
148-
inner: instant_distance::Search,
147+
inner: instant_distance::Search<FloatArray>,
149148
cur: Option<(HnswType, usize)>,
150149
}
151150

instant-distance/benches/all.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@ fn build_heuristic(bench: &mut Bencher) {
1414
.map(|_| Point(rng.gen(), rng.gen()))
1515
.collect::<Vec<_>>();
1616

17-
bench.iter(|| Builder::default().seed(SEED).build_hnsw(points.clone()))
17+
bench.iter(|| {
18+
Builder::default()
19+
.seed(SEED)
20+
.build_hnsw::<Point, Point, Vec<Point>>(points.clone())
21+
})
1822
}
1923

2024
const SEED: u64 = 123456789;

instant-distance/examples/colors.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ fn main() {
44
let points = vec![Point(255, 0, 0), Point(0, 255, 0), Point(0, 0, 255)];
55
let values = vec!["red", "green", "blue"];
66

7-
let map = Builder::default().build(points, values);
7+
let map = Builder::default().build::<Point, Point, &str, Vec<Point>>(points, values);
88
let mut search = Search::default();
99

1010
let burnt_orange = Point(204, 85, 0);

0 commit comments

Comments
 (0)