@@ -24,41 +24,79 @@ fn instant_distance_py(_: Python, m: &PyModule) -> PyResult<()> {
24
24
m. add_class :: < Search > ( ) ?;
25
25
m. add_class :: < Hnsw > ( ) ?;
26
26
m. add_class :: < HnswMap > ( ) ?;
27
+ m. add_class :: < DistanceMetric > ( ) ?;
27
28
Ok ( ( ) )
28
29
}
29
30
31
+ #[ pyclass]
32
+ #[ derive( Copy , Clone ) ]
33
+ enum DistanceMetric {
34
+ Euclid ,
35
+ Cosine ,
36
+ }
37
+
38
+ impl Default for DistanceMetric {
39
+ fn default ( ) -> Self {
40
+ Self :: Euclid
41
+ }
42
+ }
43
+
44
// Helper macro that runs the same code against whichever metric variant an
// enum value currently holds.
//
// Arguments, in order:
//   * `$wrapper` - name of an enum with tuple variants `Euclid(..)` and `Cosine(..)`
//   * `$value`   - expression to match on
//   * `$bound`   - identifier that names the variant's payload inside the body
//   * the rest   - body tokens, pasted verbatim into each match arm
//
// The body is expanded once per variant inside the caller's function, so a `?`
// in the body returns from that function, and the value of the body is the
// value of the whole macro invocation.
macro_rules! impl_for_each_hnsw_with_metric {
    ($wrapper:ident, $value:expr, $bound:ident, $($body:tt)+) => {
        match $value {
            $wrapper::Euclid($bound) => {
                $($body)+
            }
            $wrapper::Cosine($bound) => {
                $($body)+
            }
        }
    };
}
57
+
30
58
#[ pyclass]
31
59
struct HnswMap {
32
- inner : instant_distance:: HnswMap < FloatArray , MapValue > ,
60
+ inner : HnswMapWithMetric ,
61
+ }
62
+
63
+ #[ derive( Deserialize , Serialize ) ]
64
+ enum HnswMapWithMetric {
65
+ Euclid ( instant_distance:: HnswMap < FloatArray , MapValue > ) ,
66
+ Cosine ( instant_distance:: HnswMap < FloatArray , MapValue > ) ,
33
67
}
34
68
35
69
#[ pymethods]
36
70
impl HnswMap {
37
71
/// Build the index
38
72
#[ staticmethod]
39
73
fn build ( points : & PyList , values : & PyList , config : & Config ) -> PyResult < Self > {
40
- let points = points
41
- . into_iter ( )
42
- . map ( FloatArray :: try_from)
43
- . collect :: < Result < Vec < _ > , PyErr > > ( ) ?;
44
-
45
74
let values = values
46
75
. into_iter ( )
47
76
. map ( MapValue :: try_from)
48
77
. collect :: < Result < Vec < _ > , PyErr > > ( ) ?;
49
-
50
- let hsnw_map = instant_distance:: Builder :: from ( config) . build ( points, values) ;
51
- Ok ( Self { inner : hsnw_map } )
78
+ let builder = instant_distance:: Builder :: from ( config) ;
79
+ let inner = match config. distance_metric {
80
+ DistanceMetric :: Euclid => {
81
+ let points = FloatArray :: try_from_pylist ( points) ?;
82
+ HnswMapWithMetric :: Euclid ( builder. build ( points, values) )
83
+ }
84
+ DistanceMetric :: Cosine => {
85
+ let points = FloatArray :: try_from_pylist ( points) ?;
86
+ HnswMapWithMetric :: Cosine ( builder. build ( points, values) )
87
+ }
88
+ } ;
89
+ Ok ( Self { inner } )
52
90
}
53
91
54
92
/// Load an index from the given file name
55
93
#[ staticmethod]
56
94
fn load ( fname : & str ) -> PyResult < Self > {
57
- let hnsw_map =
58
- bincode :: deserialize_from :: < _ , instant_distance :: HnswMap < FloatArray , MapValue > > (
59
- BufReader :: with_capacity ( 32 * 1024 * 1024 , File :: open ( fname) ?) ,
60
- )
61
- . map_err ( |e| PyValueError :: new_err ( format ! ( "deserialization error: {e:?}" ) ) ) ?;
95
+ let hnsw_map = bincode :: deserialize_from :: < _ , HnswMapWithMetric > ( BufReader :: with_capacity (
96
+ 32 * 1024 * 1024 ,
97
+ File :: open ( fname) ?,
98
+ ) )
99
+ . map_err ( |e| PyValueError :: new_err ( format ! ( "deserialization error: {e:?}" ) ) ) ?;
62
100
Ok ( Self { inner : hnsw_map } )
63
101
}
64
102
@@ -78,8 +116,10 @@ impl HnswMap {
78
116
///
79
117
/// For best performance, reusing `Search` objects is recommended.
80
118
fn search ( slf : Py < Self > , point : & PyAny , search : & mut Search , py : Python < ' _ > ) -> PyResult < ( ) > {
81
- let point = FloatArray :: try_from ( point) ?;
82
- let _ = slf. try_borrow ( py) ?. inner . search ( & point, & mut search. inner ) ;
119
+ impl_for_each_hnsw_with_metric ! ( HnswMapWithMetric , & slf. try_borrow( py) ?. inner, hnsw, {
120
+ let point = FloatArray :: try_from( point) ?;
121
+ let _ = hnsw. search( & point, & mut search. inner) ;
122
+ } ) ;
83
123
search. cur = Some ( ( HnswType :: Map ( slf. clone_ref ( py) ) , 0 ) ) ;
84
124
Ok ( ( ) )
85
125
}
@@ -91,30 +131,44 @@ impl HnswMap {
91
131
/// with a squared Euclidean distance metric.
92
132
#[ pyclass]
93
133
struct Hnsw {
94
- inner : instant_distance:: Hnsw < FloatArray > ,
134
+ inner : HnswWithMetric ,
135
+ }
136
+
137
+ #[ derive( Deserialize , Serialize ) ]
138
+ enum HnswWithMetric {
139
+ Euclid ( instant_distance:: Hnsw < FloatArray > ) ,
140
+ Cosine ( instant_distance:: Hnsw < FloatArray > ) ,
95
141
}
96
142
97
143
#[ pymethods]
98
144
impl Hnsw {
99
145
/// Build the index
100
146
#[ staticmethod]
101
147
fn build ( input : & PyList , config : & Config ) -> PyResult < ( Self , Vec < u32 > ) > {
102
- let points = input
103
- . into_iter ( )
104
- . map ( FloatArray :: try_from)
105
- . collect :: < Result < Vec < _ > , PyErr > > ( ) ?;
106
-
107
- let ( inner, ids) = instant_distance:: Builder :: from ( config) . build_hnsw ( points) ;
148
+ let builder = instant_distance:: Builder :: from ( config) ;
149
+ let ( inner, ids) = match config. distance_metric {
150
+ DistanceMetric :: Euclid => {
151
+ let points = FloatArray :: try_from_pylist ( input) ?;
152
+ let ( hnsw, ids) = builder. build_hnsw ( points) ;
153
+ ( HnswWithMetric :: Euclid ( hnsw) , ids)
154
+ }
155
+ DistanceMetric :: Cosine => {
156
+ let points = FloatArray :: try_from_pylist ( input) ?;
157
+ let ( hnsw, ids) = builder. build_hnsw ( points) ;
158
+ ( HnswWithMetric :: Cosine ( hnsw) , ids)
159
+ }
160
+ } ;
108
161
let ids = Vec :: from_iter ( ids. into_iter ( ) . map ( |pid| pid. into_inner ( ) ) ) ;
109
162
Ok ( ( Self { inner } , ids) )
110
163
}
111
164
112
165
/// Load an index from the given file name
113
166
#[ staticmethod]
114
167
fn load ( fname : & str ) -> PyResult < Self > {
115
- let hnsw = bincode:: deserialize_from :: < _ , instant_distance:: Hnsw < FloatArray > > (
116
- BufReader :: with_capacity ( 32 * 1024 * 1024 , File :: open ( fname) ?) ,
117
- )
168
+ let hnsw = bincode:: deserialize_from :: < _ , HnswWithMetric > ( BufReader :: with_capacity (
169
+ 32 * 1024 * 1024 ,
170
+ File :: open ( fname) ?,
171
+ ) )
118
172
. map_err ( |e| PyValueError :: new_err ( format ! ( "deserialization error: {e:?}" ) ) ) ?;
119
173
Ok ( Self { inner : hnsw } )
120
174
}
@@ -135,8 +189,10 @@ impl Hnsw {
135
189
///
136
190
/// For best performance, reusing `Search` objects is recommended.
137
191
fn search ( slf : Py < Self > , point : & PyAny , search : & mut Search , py : Python < ' _ > ) -> PyResult < ( ) > {
138
- let point = FloatArray :: try_from ( point) ?;
139
- let _ = slf. try_borrow ( py) ?. inner . search ( & point, & mut search. inner ) ;
192
+ impl_for_each_hnsw_with_metric ! ( HnswWithMetric , & slf. try_borrow( py) ?. inner, hnsw, {
193
+ let point = FloatArray :: try_from( point) ?;
194
+ let _ = hnsw. search( & point, & mut search. inner) ;
195
+ } ) ;
140
196
search. cur = Some ( ( HnswType :: Hnsw ( slf. clone_ref ( py) ) , 0 ) ) ;
141
197
Ok ( ( ) )
142
198
}
@@ -175,20 +231,24 @@ impl Search {
175
231
let neighbor = match & index {
176
232
HnswType :: Hnsw ( hnsw) => {
177
233
let hnsw = hnsw. as_ref ( py) . borrow ( ) ;
178
- let item = hnsw. inner . get ( idx, & slf. inner ) ;
179
- item. map ( |item| Neighbor {
180
- distance : item. distance ,
181
- pid : item. pid . into_inner ( ) ,
182
- value : py. None ( ) ,
234
+ impl_for_each_hnsw_with_metric ! ( HnswWithMetric , & hnsw. inner, hnsw, {
235
+ let item = hnsw. get( idx, & slf. inner) ;
236
+ item. map( |item| Neighbor {
237
+ distance: item. distance,
238
+ pid: item. pid. into_inner( ) ,
239
+ value: py. None ( ) ,
240
+ } )
183
241
} )
184
242
}
185
243
HnswType :: Map ( map) => {
186
244
let map = map. as_ref ( py) . borrow ( ) ;
187
- let item = map. inner . get ( idx, & slf. inner ) ;
188
- item. map ( |item| Neighbor {
189
- distance : item. distance ,
190
- pid : item. pid . into_inner ( ) ,
191
- value : item. value . into_py ( py) ,
245
+ impl_for_each_hnsw_with_metric ! ( HnswMapWithMetric , & map. inner, map, {
246
+ let item = map. get( idx, & slf. inner) ;
247
+ item. map( |item| Neighbor {
248
+ distance: item. distance,
249
+ pid: item. pid. into_inner( ) ,
250
+ value: item. value. into_py( py) ,
251
+ } )
192
252
} )
193
253
}
194
254
} ;
@@ -226,6 +286,11 @@ struct Config {
226
286
/// in order to get better results on clustered data points.
227
287
#[ pyo3( get, set) ]
228
288
heuristic : Option < Heuristic > ,
289
+ /// Distance metric to use
290
+ ///
291
+ /// Defaults to Euclidean distance
292
+ #[ pyo3( get, set) ]
293
+ distance_metric : DistanceMetric ,
229
294
}
230
295
231
296
#[ pymethods]
@@ -235,12 +300,14 @@ impl Config {
235
300
let builder = instant_distance:: Builder :: default ( ) ;
236
301
let ( ef_search, ef_construction, ml, seed) = builder. into_parts ( ) ;
237
302
let heuristic = Some ( Heuristic :: default ( ) ) ;
303
+ let distance_metric = DistanceMetric :: default ( ) ;
238
304
Self {
239
305
ef_search,
240
306
ef_construction,
241
307
ml,
242
308
seed,
243
309
heuristic,
310
+ distance_metric,
244
311
}
245
312
}
246
313
}
@@ -253,6 +320,7 @@ impl From<&Config> for instant_distance::Builder {
253
320
ml,
254
321
seed,
255
322
heuristic,
323
+ distance_metric : _,
256
324
} = * py;
257
325
Self :: default ( )
258
326
. ef_search ( ef_search)
@@ -350,6 +418,12 @@ impl Neighbor {
350
418
#[ derive( Clone , Deserialize , Serialize ) ]
351
419
struct FloatArray ( #[ serde( with = "BigArray" ) ] [ f32 ; DIMENSIONS ] ) ;
352
420
421
+ impl FloatArray {
422
+ fn try_from_pylist ( list : & PyList ) -> Result < Vec < Self > , PyErr > {
423
+ list. into_iter ( ) . map ( FloatArray :: try_from) . collect ( )
424
+ }
425
+ }
426
+
353
427
impl TryFrom < & PyAny > for FloatArray {
354
428
type Error = PyErr ;
355
429
0 commit comments