4
4
use std:: convert:: TryFrom ;
5
5
use std:: fs:: File ;
6
6
use std:: io:: { BufReader , BufWriter } ;
7
- use std:: iter:: FromIterator ;
7
+ use std:: iter:: { repeat, FromIterator } ;
8
+ use std:: ops:: Index ;
8
9
9
- use instant_distance:: Metric ;
10
+ use aligned_vec:: { AVec , ConstAlign } ;
11
+ use instant_distance:: { Len , Metric , PointId } ;
10
12
use pyo3:: conversion:: IntoPy ;
11
- use pyo3:: exceptions:: { PyTypeError , PyValueError } ;
13
+ use pyo3:: exceptions:: PyValueError ;
12
14
use pyo3:: types:: { PyList , PyModule , PyString } ;
13
15
use pyo3:: { pyclass, pymethods, pymodule} ;
14
16
use pyo3:: { Py , PyAny , PyErr , PyObject , PyRef , PyRefMut , PyResult , Python } ;
15
17
use serde:: { Deserialize , Serialize } ;
16
- use serde_big_array:: BigArray ;
17
18
18
19
#[ pymodule]
19
20
#[ pyo3( name = "instant_distance" ) ]
@@ -29,7 +30,7 @@ fn instant_distance_py(_: Python, m: &PyModule) -> PyResult<()> {
29
30
30
31
#[ pyclass]
31
32
struct HnswMap {
32
- inner : instant_distance:: HnswMap < FloatArray , EuclidMetric , MapValue , Vec < FloatArray > > ,
33
+ inner : instant_distance:: HnswMap < [ f32 ] , EuclidMetric , MapValue , PointStorage > ,
33
34
}
34
35
35
36
#[ pymethods]
@@ -39,24 +40,34 @@ impl HnswMap {
39
40
fn build ( points : & PyList , values : & PyList , config : & Config ) -> PyResult < Self > {
40
41
let points = points
41
42
. into_iter ( )
42
- . map ( FloatArray :: try_from)
43
+ . map ( |v| {
44
+ v. iter ( ) ?
45
+ . into_iter ( )
46
+ . map ( |x| x?. extract ( ) )
47
+ . collect :: < Result < Vec < _ > , PyErr > > ( )
48
+ } )
43
49
. collect :: < Result < Vec < _ > , PyErr > > ( ) ?;
44
50
45
51
let values = values
46
52
. into_iter ( )
47
53
. map ( MapValue :: try_from)
48
54
. collect :: < Result < Vec < _ > , PyErr > > ( ) ?;
49
55
50
- let hsnw_map = instant_distance:: Builder :: from ( config) . build ( points, values) ;
56
+ let hsnw_map = instant_distance:: Builder :: from ( config)
57
+ . build :: < Vec < _ > , [ f32 ] , EuclidMetric , MapValue , PointStorage > ( points, values) ;
51
58
Ok ( Self { inner : hsnw_map } )
52
59
}
53
60
54
61
/// Load an index from the given file name
55
62
#[ staticmethod]
56
63
fn load ( fname : & str ) -> PyResult < Self > {
57
- let hnsw_map = bincode:: deserialize_from :: < _ , instant_distance:: HnswMap < _ , _ , _ , _ > > (
58
- BufReader :: with_capacity ( 32 * 1024 * 1024 , File :: open ( fname) ?) ,
59
- )
64
+ let hnsw_map = bincode:: deserialize_from :: <
65
+ _ ,
66
+ instant_distance:: HnswMap < [ f32 ] , EuclidMetric , MapValue , PointStorage > ,
67
+ > ( BufReader :: with_capacity (
68
+ 32 * 1024 * 1024 ,
69
+ File :: open ( fname) ?,
70
+ ) )
60
71
. map_err ( |e| PyValueError :: new_err ( format ! ( "deserialization error: {e:?}" ) ) ) ?;
61
72
Ok ( Self { inner : hnsw_map } )
62
73
}
@@ -77,7 +88,7 @@ impl HnswMap {
77
88
///
78
89
/// For best performance, reusing `Search` objects is recommended.
79
90
fn search ( slf : Py < Self > , point : & PyAny , search : & mut Search , py : Python < ' _ > ) -> PyResult < ( ) > {
80
- let point = FloatArray :: try_from ( point) ?;
91
+ let point = try_avec_from_py ( point) ?;
81
92
let _ = slf. try_borrow ( py) ?. inner . search ( & point, & mut search. inner ) ;
82
93
search. cur = Some ( ( HnswType :: Map ( slf. clone_ref ( py) ) , 0 ) ) ;
83
94
Ok ( ( ) )
@@ -90,7 +101,7 @@ impl HnswMap {
90
101
/// with a squared Euclidean distance metric.
91
102
#[ pyclass]
92
103
struct Hnsw {
93
- inner : instant_distance:: Hnsw < FloatArray , EuclidMetric , Vec < FloatArray > > ,
104
+ inner : instant_distance:: Hnsw < [ f32 ] , EuclidMetric , PointStorage > ,
94
105
}
95
106
96
107
#[ pymethods]
@@ -100,20 +111,30 @@ impl Hnsw {
100
111
fn build ( input : & PyList , config : & Config ) -> PyResult < ( Self , Vec < u32 > ) > {
101
112
let points = input
102
113
. into_iter ( )
103
- . map ( FloatArray :: try_from)
114
+ . map ( |v| {
115
+ v. iter ( ) ?
116
+ . into_iter ( )
117
+ . map ( |x| x?. extract ( ) )
118
+ . collect :: < Result < Vec < _ > , PyErr > > ( )
119
+ } )
104
120
. collect :: < Result < Vec < _ > , PyErr > > ( ) ?;
105
121
106
- let ( inner, ids) = instant_distance:: Builder :: from ( config) . build_hnsw ( points) ;
122
+ let ( inner, ids) = instant_distance:: Builder :: from ( config)
123
+ . build_hnsw :: < Vec < f32 > , [ f32 ] , EuclidMetric , PointStorage > ( points) ;
107
124
let ids = Vec :: from_iter ( ids. into_iter ( ) . map ( |pid| pid. into_inner ( ) ) ) ;
108
125
Ok ( ( Self { inner } , ids) )
109
126
}
110
127
111
128
/// Load an index from the given file name
112
129
#[ staticmethod]
113
130
fn load ( fname : & str ) -> PyResult < Self > {
114
- let hnsw = bincode:: deserialize_from :: < _ , instant_distance:: Hnsw < _ , _ , _ > > (
115
- BufReader :: with_capacity ( 32 * 1024 * 1024 , File :: open ( fname) ?) ,
116
- )
131
+ let hnsw = bincode:: deserialize_from :: <
132
+ _ ,
133
+ instant_distance:: Hnsw < [ f32 ] , EuclidMetric , PointStorage > ,
134
+ > ( BufReader :: with_capacity (
135
+ 32 * 1024 * 1024 ,
136
+ File :: open ( fname) ?,
137
+ ) )
117
138
. map_err ( |e| PyValueError :: new_err ( format ! ( "deserialization error: {e:?}" ) ) ) ?;
118
139
Ok ( Self { inner : hnsw } )
119
140
}
@@ -134,7 +155,7 @@ impl Hnsw {
134
155
///
135
156
/// For best performance, reusing `Search` objects is recommended.
136
157
fn search ( slf : Py < Self > , point : & PyAny , search : & mut Search , py : Python < ' _ > ) -> PyResult < ( ) > {
137
- let point = FloatArray :: try_from ( point) ?;
158
+ let point = try_avec_from_py ( point) ?;
138
159
let _ = slf. try_borrow ( py) ?. inner . search ( & point, & mut search. inner ) ;
139
160
search. cur = Some ( ( HnswType :: Hnsw ( slf. clone_ref ( py) ) , 0 ) ) ;
140
161
Ok ( ( ) )
@@ -144,7 +165,7 @@ impl Hnsw {
144
165
/// Search buffer and result set
145
166
#[ pyclass]
146
167
struct Search {
147
- inner : instant_distance:: Search < FloatArray , EuclidMetric > ,
168
+ inner : instant_distance:: Search < [ f32 ] , EuclidMetric > ,
148
169
cur : Option < ( HnswType , usize ) > ,
149
170
}
150
171
@@ -345,42 +366,36 @@ impl Neighbor {
345
366
}
346
367
}
347
368
348
- #[ repr( align( 32 ) ) ]
349
- #[ derive( Clone , Deserialize , Serialize ) ]
350
- pub struct FloatArray ( #[ serde( with = "BigArray" ) ] pub [ f32 ; DIMENSIONS ] ) ;
351
-
352
- impl TryFrom < & PyAny > for FloatArray {
353
- type Error = PyErr ;
354
-
355
- fn try_from ( value : & PyAny ) -> Result < Self , Self :: Error > {
356
- let mut new = FloatArray ( [ 0.0 ; DIMENSIONS ] ) ;
357
- for ( i, val) in value. iter ( ) ?. enumerate ( ) {
358
- match i >= DIMENSIONS {
359
- true => return Err ( PyTypeError :: new_err ( "point array too long" ) ) ,
360
- false => new. 0 [ i] = val?. extract :: < f32 > ( ) ?,
361
- }
362
- }
363
- Ok ( new)
369
+ fn try_avec_from_py ( value : & PyAny ) -> Result < AVec < f32 , ConstAlign < ALIGNMENT > > , PyErr > {
370
+ let mut new = AVec :: new ( ALIGNMENT ) ;
371
+ for val in value. iter ( ) ? {
372
+ new. push ( val?. extract :: < f32 > ( ) ?) ;
364
373
}
374
+ for _ in 0 ..PointStorage :: padding ( new. len ( ) ) {
375
+ new. push ( 0.0 ) ;
376
+ }
377
+ Ok ( new)
365
378
}
366
379
367
380
#[ derive( Clone , Copy , Deserialize , Serialize ) ]
368
381
pub struct EuclidMetric ;
369
382
370
- impl Metric < FloatArray > for EuclidMetric {
371
- fn distance ( lhs : & FloatArray , rhs : & FloatArray ) -> f32 {
383
+ impl Metric < [ f32 ] > for EuclidMetric {
384
+ fn distance ( lhs : & [ f32 ] , rhs : & [ f32 ] ) -> f32 {
385
+ debug_assert_eq ! ( lhs. len( ) , rhs. len( ) ) ;
386
+
372
387
#[ cfg( target_arch = "x86_64" ) ]
373
388
{
374
389
use std:: arch:: x86_64:: {
375
390
_mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
376
391
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
377
- _mm_fmadd_ps , _mm_load_ps , _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps ,
392
+ _mm_movehl_ps, _mm_shuffle_ps,
378
393
} ;
379
- debug_assert_eq ! ( lhs. 0 . len( ) % 8 , 4 ) ;
394
+ debug_assert_eq ! ( lhs. len( ) % 8 , 0 ) ;
380
395
381
396
unsafe {
382
397
let mut acc_8x = _mm256_setzero_ps ( ) ;
383
- for ( lh_slice, rh_slice) in lhs. 0 . chunks_exact ( 8 ) . zip ( rhs. 0 . chunks_exact ( 8 ) ) {
398
+ for ( lh_slice, rh_slice) in lhs. chunks_exact ( 8 ) . zip ( rhs. chunks_exact ( 8 ) ) {
384
399
let lh_8x = _mm256_load_ps ( lh_slice. as_ptr ( ) ) ;
385
400
let rh_8x = _mm256_load_ps ( rh_slice. as_ptr ( ) ) ;
386
401
let diff = _mm256_sub_ps ( lh_8x, rh_8x) ;
@@ -391,11 +406,6 @@ impl Metric<FloatArray> for EuclidMetric {
391
406
let right = _mm256_castps256_ps128 ( acc_8x) ; // lower half
392
407
acc_4x = _mm_add_ps ( acc_4x, right) ; // sum halves
393
408
394
- let lh_4x = _mm_load_ps ( lhs. 0 [ DIMENSIONS - 4 ..] . as_ptr ( ) ) ;
395
- let rh_4x = _mm_load_ps ( rhs. 0 [ DIMENSIONS - 4 ..] . as_ptr ( ) ) ;
396
- let diff = _mm_sub_ps ( lh_4x, rh_4x) ;
397
- acc_4x = _mm_fmadd_ps ( diff, diff, acc_4x) ;
398
-
399
409
let lower = _mm_movehl_ps ( acc_4x, acc_4x) ;
400
410
acc_4x = _mm_add_ps ( acc_4x, lower) ;
401
411
let upper = _mm_shuffle_ps ( acc_4x, acc_4x, 0x1 ) ;
@@ -412,6 +422,114 @@ impl Metric<FloatArray> for EuclidMetric {
412
422
}
413
423
}
414
424
425
+ #[ derive( Debug , Deserialize , Serialize ) ]
426
+ pub struct PointStorage {
427
+ point_len : usize ,
428
+ points_data : AVec < f32 > ,
429
+ }
430
+
431
+ impl PointStorage {
432
+ const fn padding ( len : usize ) -> usize {
433
+ let floats_per_alignment = ALIGNMENT / std:: mem:: size_of :: < f32 > ( ) ;
434
+ match len % floats_per_alignment {
435
+ 0 => 0 ,
436
+ floats_over_alignment => floats_per_alignment - floats_over_alignment,
437
+ }
438
+ }
439
+
440
+ pub fn iter ( & self ) -> impl Iterator < Item = & [ f32 ] > {
441
+ self . points_data . chunks_exact ( self . point_len )
442
+ }
443
+ }
444
+
445
+ impl Default for PointStorage {
446
+ fn default ( ) -> Self {
447
+ Self {
448
+ point_len : 1 ,
449
+ points_data : AVec :: new ( ALIGNMENT ) ,
450
+ }
451
+ }
452
+ }
453
+
454
+ impl Index < usize > for PointStorage {
455
+ type Output = [ f32 ] ;
456
+
457
+ fn index ( & self , index : usize ) -> & Self :: Output {
458
+ let raw_idx = index * self . point_len ;
459
+ & self . points_data [ raw_idx..( raw_idx + self . point_len ) ]
460
+ }
461
+ }
462
+
463
+ impl Index < PointId > for PointStorage {
464
+ type Output = [ f32 ] ;
465
+
466
+ fn index ( & self , index : PointId ) -> & Self :: Output {
467
+ self . index ( index. into_inner ( ) as usize )
468
+ }
469
+ }
470
+
471
+ impl From < Vec < Vec < f32 > > > for PointStorage {
472
+ fn from ( value : Vec < Vec < f32 > > ) -> Self {
473
+ if let Some ( point) = value. first ( ) {
474
+ let point_len = point. len ( ) ;
475
+ let padding = PointStorage :: padding ( point_len) ;
476
+ let mut points_data =
477
+ AVec :: with_capacity ( ALIGNMENT , value. len ( ) * ( point_len + padding) ) ;
478
+ for point in value {
479
+ // all points should have the same length
480
+ debug_assert_eq ! ( point. len( ) , point_len) ;
481
+ for v in point. into_iter ( ) . chain ( repeat ( 0.0 ) . take ( padding) ) {
482
+ points_data. push ( v) ;
483
+ }
484
+ }
485
+ Self {
486
+ point_len : point_len + padding,
487
+ points_data,
488
+ }
489
+ } else {
490
+ Default :: default ( )
491
+ }
492
+ }
493
+ }
494
+
495
+ impl Len for PointStorage {
496
+ fn len ( & self ) -> usize {
497
+ self . points_data . len ( ) / self . point_len
498
+ }
499
+ }
500
+
501
+ impl < ' a > IntoIterator for & ' a PointStorage {
502
+ type Item = & ' a [ f32 ] ;
503
+
504
+ type IntoIter = PointStorageIterator < ' a > ;
505
+
506
+ fn into_iter ( self ) -> Self :: IntoIter {
507
+ PointStorageIterator {
508
+ storage : self ,
509
+ next_idx : 0 ,
510
+ }
511
+ }
512
+ }
513
+
514
+ pub struct PointStorageIterator < ' a > {
515
+ storage : & ' a PointStorage ,
516
+ next_idx : usize ,
517
+ }
518
+
519
+ impl < ' a > Iterator for PointStorageIterator < ' a > {
520
+ type Item = & ' a [ f32 ] ;
521
+
522
+ fn next ( & mut self ) -> Option < Self :: Item > {
523
+ if self . next_idx < self . storage . len ( ) {
524
+ let result = & self . storage [ self . next_idx ] ;
525
+ self . next_idx += 1 ;
526
+ Some ( result)
527
+ } else {
528
+ None
529
+ }
530
+ }
531
+ }
532
+
415
533
#[ derive( Clone , Deserialize , Serialize ) ]
416
534
enum MapValue {
417
535
String ( String ) ,
@@ -433,4 +551,4 @@ impl IntoPy<Py<PyAny>> for &'_ MapValue {
433
551
}
434
552
}
435
553
436
- const DIMENSIONS : usize = 300 ;
554
+ const ALIGNMENT : usize = 32 ;
0 commit comments