@@ -6,7 +6,7 @@ use std::fs::File;
6
6
use std:: io:: { BufReader , BufWriter } ;
7
7
use std:: iter:: FromIterator ;
8
8
9
- use instant_distance:: Point ;
9
+ use instant_distance:: Metric ;
10
10
use pyo3:: conversion:: IntoPy ;
11
11
use pyo3:: exceptions:: { PyTypeError , PyValueError } ;
12
12
use pyo3:: types:: { PyList , PyModule , PyString } ;
@@ -29,7 +29,7 @@ fn instant_distance_py(_: Python, m: &PyModule) -> PyResult<()> {
29
29
30
30
#[ pyclass]
31
31
struct HnswMap {
32
- inner : instant_distance:: HnswMap < FloatArray , MapValue , Vec < FloatArray > > ,
32
+ inner : instant_distance:: HnswMap < FloatArray , EuclidMetric , MapValue , Vec < FloatArray > > ,
33
33
}
34
34
35
35
#[ pymethods]
@@ -54,7 +54,7 @@ impl HnswMap {
54
54
/// Load an index from the given file name
55
55
#[ staticmethod]
56
56
fn load ( fname : & str ) -> PyResult < Self > {
57
- let hnsw_map = bincode:: deserialize_from :: < _ , instant_distance:: HnswMap < _ , _ , _ > > (
57
+ let hnsw_map = bincode:: deserialize_from :: < _ , instant_distance:: HnswMap < _ , _ , _ , _ > > (
58
58
BufReader :: with_capacity ( 32 * 1024 * 1024 , File :: open ( fname) ?) ,
59
59
)
60
60
. map_err ( |e| PyValueError :: new_err ( format ! ( "deserialization error: {e:?}" ) ) ) ?;
@@ -90,7 +90,7 @@ impl HnswMap {
90
90
/// with a squared Euclidean distance metric.
91
91
#[ pyclass]
92
92
struct Hnsw {
93
- inner : instant_distance:: Hnsw < FloatArray , Vec < FloatArray > > ,
93
+ inner : instant_distance:: Hnsw < FloatArray , EuclidMetric , Vec < FloatArray > > ,
94
94
}
95
95
96
96
#[ pymethods]
@@ -111,7 +111,7 @@ impl Hnsw {
111
111
/// Load an index from the given file name
112
112
#[ staticmethod]
113
113
fn load ( fname : & str ) -> PyResult < Self > {
114
- let hnsw = bincode:: deserialize_from :: < _ , instant_distance:: Hnsw < _ , _ > > (
114
+ let hnsw = bincode:: deserialize_from :: < _ , instant_distance:: Hnsw < _ , _ , _ > > (
115
115
BufReader :: with_capacity ( 32 * 1024 * 1024 , File :: open ( fname) ?) ,
116
116
)
117
117
. map_err ( |e| PyValueError :: new_err ( format ! ( "deserialization error: {e:?}" ) ) ) ?;
@@ -144,7 +144,7 @@ impl Hnsw {
144
144
/// Search buffer and result set
145
145
#[ pyclass]
146
146
struct Search {
147
- inner : instant_distance:: Search < FloatArray > ,
147
+ inner : instant_distance:: Search < FloatArray , EuclidMetric > ,
148
148
cur : Option < ( HnswType , usize ) > ,
149
149
}
150
150
@@ -364,20 +364,23 @@ impl TryFrom<&PyAny> for FloatArray {
364
364
}
365
365
}
366
366
367
- impl Point for FloatArray {
368
- fn distance ( & self , rhs : & Self ) -> f32 {
367
+ #[ derive( Clone , Copy , Deserialize , Serialize ) ]
368
+ pub struct EuclidMetric ;
369
+
370
+ impl Metric < FloatArray > for EuclidMetric {
371
+ fn distance ( lhs : & FloatArray , rhs : & FloatArray ) -> f32 {
369
372
#[ cfg( target_arch = "x86_64" ) ]
370
373
{
371
374
use std:: arch:: x86_64:: {
372
375
_mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
373
376
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
374
377
_mm_fmadd_ps, _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
375
378
} ;
376
- debug_assert_eq ! ( self . 0 . len( ) % 8 , 4 ) ;
379
+ debug_assert_eq ! ( lhs . 0 . len( ) % 8 , 4 ) ;
377
380
378
381
unsafe {
379
382
let mut acc_8x = _mm256_setzero_ps ( ) ;
380
- for ( lh_slice, rh_slice) in self . 0 . chunks_exact ( 8 ) . zip ( rhs. 0 . chunks_exact ( 8 ) ) {
383
+ for ( lh_slice, rh_slice) in lhs . 0 . chunks_exact ( 8 ) . zip ( rhs. 0 . chunks_exact ( 8 ) ) {
381
384
let lh_8x = _mm256_load_ps ( lh_slice. as_ptr ( ) ) ;
382
385
let rh_8x = _mm256_load_ps ( rh_slice. as_ptr ( ) ) ;
383
386
let diff = _mm256_sub_ps ( lh_8x, rh_8x) ;
@@ -388,7 +391,7 @@ impl Point for FloatArray {
388
391
let right = _mm256_castps256_ps128 ( acc_8x) ; // lower half
389
392
acc_4x = _mm_add_ps ( acc_4x, right) ; // sum halves
390
393
391
- let lh_4x = _mm_load_ps ( self . 0 [ DIMENSIONS - 4 ..] . as_ptr ( ) ) ;
394
+ let lh_4x = _mm_load_ps ( lhs . 0 [ DIMENSIONS - 4 ..] . as_ptr ( ) ) ;
392
395
let rh_4x = _mm_load_ps ( rhs. 0 [ DIMENSIONS - 4 ..] . as_ptr ( ) ) ;
393
396
let diff = _mm_sub_ps ( lh_4x, rh_4x) ;
394
397
acc_4x = _mm_fmadd_ps ( diff, diff, acc_4x) ;
@@ -401,7 +404,7 @@ impl Point for FloatArray {
401
404
}
402
405
}
403
406
#[ cfg( not( target_arch = "x86_64" ) ) ]
404
- self . 0
407
+ lhs . 0
405
408
. iter ( )
406
409
. zip ( rhs. 0 . iter ( ) )
407
410
. map ( |( & a, & b) | ( a - b) . powi ( 2 ) )
0 commit comments