@@ -6,14 +6,14 @@ use std::fs::File;
6
6
use std:: io:: { BufReader , BufWriter } ;
7
7
use std:: iter:: FromIterator ;
8
8
9
+ use aligned_vec:: AVec ;
9
10
use instant_distance:: Point ;
10
11
use pyo3:: conversion:: IntoPy ;
11
- use pyo3:: exceptions:: { PyTypeError , PyValueError } ;
12
+ use pyo3:: exceptions:: PyValueError ;
12
13
use pyo3:: types:: { PyList , PyModule , PyString } ;
13
14
use pyo3:: { pyclass, pymethods, pymodule} ;
14
15
use pyo3:: { Py , PyAny , PyErr , PyObject , PyRef , PyRefMut , PyResult , Python } ;
15
16
use serde:: { Deserialize , Serialize } ;
16
- use serde_big_array:: BigArray ;
17
17
18
18
#[ pymodule]
19
19
#[ pyo3( name = "instant_distance" ) ]
@@ -87,8 +87,7 @@ impl HnswMap {
87
87
88
88
/// An instance of hierarchical navigable small worlds
89
89
///
90
- /// For now, this is specialized to only support 300-element (32-bit) float vectors
91
- /// with a squared Euclidean distance metric.
90
+ /// For now, this uses a squared Euclidean distance metric.
92
91
#[ pyclass]
93
92
struct Hnsw {
94
93
inner : instant_distance:: Hnsw < FloatArray > ,
@@ -346,35 +345,32 @@ impl Neighbor {
346
345
}
347
346
}
348
347
349
- #[ repr( align( 32 ) ) ]
350
348
/// Point type stored in the `Hnsw` index: a newtype wrapping an `AVec<f32>`
/// (previously a fixed-size `[f32; DIMENSIONS]` array).
///
/// NOTE(review): the `TryFrom<&PyAny>` conversion requests an alignment of 32
/// for the buffer, presumably to satisfy the aligned SIMD loads used by
/// `Point::distance` on x86_64 — confirm that values produced through
/// `Deserialize` end up with a compatible alignment as well.
#[ derive( Clone , Deserialize , Serialize ) ]
351
- struct FloatArray ( # [ serde ( with = "BigArray" ) ] [ f32 ; DIMENSIONS ] ) ;
349
+ struct FloatArray ( AVec < f32 > ) ;
352
350
353
351
impl TryFrom < & PyAny > for FloatArray {
354
352
type Error = PyErr ;
355
353
356
354
fn try_from ( value : & PyAny ) -> Result < Self , Self :: Error > {
357
- let mut new = FloatArray ( [ 0.0 ; DIMENSIONS ] ) ;
358
- for ( i, val) in value. iter ( ) ?. enumerate ( ) {
359
- match i >= DIMENSIONS {
360
- true => return Err ( PyTypeError :: new_err ( "point array too long" ) ) ,
361
- false => new. 0 [ i] = val?. extract :: < f32 > ( ) ?,
362
- }
355
+ let mut new = FloatArray ( AVec :: with_capacity ( 32 , value. len ( ) ?) ) ;
356
+ for val in value. iter ( ) ? {
357
+ new. 0 . push ( val?. extract ( ) ?) ;
363
358
}
364
359
Ok ( new)
365
360
}
366
361
}
367
362
368
363
impl Point for FloatArray {
369
364
fn distance ( & self , rhs : & Self ) -> f32 {
365
+ debug_assert_eq ! ( self . 0 . len( ) , rhs. 0 . len( ) ) ;
366
+
370
367
#[ cfg( target_arch = "x86_64" ) ]
371
368
{
372
369
use std:: arch:: x86_64:: {
373
370
_mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
374
371
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
375
372
_mm_fmadd_ps, _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
376
373
} ;
377
- debug_assert_eq ! ( self . 0 . len( ) % 8 , 4 ) ;
378
374
379
375
unsafe {
380
376
let mut acc_8x = _mm256_setzero_ps ( ) ;
@@ -389,16 +385,36 @@ impl Point for FloatArray {
389
385
let right = _mm256_castps256_ps128 ( acc_8x) ; // lower half
390
386
acc_4x = _mm_add_ps ( acc_4x, right) ; // sum halves
391
387
392
- let lh_4x = _mm_load_ps ( self . 0 [ DIMENSIONS - 4 ..] . as_ptr ( ) ) ;
393
- let rh_4x = _mm_load_ps ( rhs. 0 [ DIMENSIONS - 4 ..] . as_ptr ( ) ) ;
394
- let diff = _mm_sub_ps ( lh_4x, rh_4x) ;
395
- acc_4x = _mm_fmadd_ps ( diff, diff, acc_4x) ;
388
+ // count of already processed dimensions
389
+ let mut processed_count = self . 0 . len ( ) - self . 0 . len ( ) % 8 ;
390
+
391
+ if self . 0 . len ( ) % 8 >= 4 {
392
+ // there are 4+ dimensions to process
393
+ // let's process another 4 in a batch
394
+ let lh_4x = _mm_load_ps ( self . 0 [ processed_count..] . as_ptr ( ) ) ;
395
+ let rh_4x = _mm_load_ps ( rhs. 0 [ processed_count..] . as_ptr ( ) ) ;
396
+ let diff = _mm_sub_ps ( lh_4x, rh_4x) ;
397
+ acc_4x = _mm_fmadd_ps ( diff, diff, acc_4x) ;
398
+ processed_count += 4 ;
399
+ }
396
400
401
+ // sum up the registers
397
402
let lower = _mm_movehl_ps ( acc_4x, acc_4x) ;
398
403
acc_4x = _mm_add_ps ( acc_4x, lower) ;
399
404
let upper = _mm_shuffle_ps ( acc_4x, acc_4x, 0x1 ) ;
400
405
acc_4x = _mm_add_ss ( acc_4x, upper) ;
401
- _mm_cvtss_f32 ( acc_4x)
406
+ let mut distance = _mm_cvtss_f32 ( acc_4x) ;
407
+
408
+ // process the leftover dimensions (if any are left)
409
+ if processed_count < self . 0 . len ( ) {
410
+ distance += self . 0 [ processed_count..]
411
+ . iter ( )
412
+ . zip ( rhs. 0 [ processed_count..] . iter ( ) )
413
+ . map ( |( & a, & b) | ( a - b) . powi ( 2 ) )
414
+ . sum :: < f32 > ( )
415
+ }
416
+
417
+ distance
402
418
}
403
419
}
404
420
#[ cfg( not( target_arch = "x86_64" ) ) ]
@@ -430,5 +446,3 @@ impl IntoPy<Py<PyAny>> for &'_ MapValue {
430
446
}
431
447
}
432
448
}
433
-
434
- const DIMENSIONS : usize = 300 ;
0 commit comments