Skip to content

Commit 9d8c83f

Browse files
committed
Add Cosine distance metric
1 parent ac42574 commit 9d8c83f

File tree

7 files changed

+470
-8
lines changed

7 files changed

+470
-8
lines changed

Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ test-python: instant-distance-py/test/instant_distance.so
88

99
bench-python: instant-distance-py/test/instant_distance.so
1010
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config()' 'instant_distance.Hnsw.build(points, config)'
11+
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config(); config.distance_metric = instant_distance.DistanceMetric.Cosine' 'instant_distance.Hnsw.build(points, config)'
1112

1213
clean:
1314
cargo clean

distance-metrics/benches/all.rs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
use bencher::{benchmark_group, benchmark_main, Bencher};
22

3-
use distance_metrics::{EuclidMetric, Metric};
3+
use distance_metrics::{
4+
Metric, {CosineMetric, EuclidMetric},
5+
};
46
use rand::{rngs::StdRng, Rng, SeedableRng};
57

68
benchmark_main!(benches);
7-
benchmark_group!(benches, legacy, non_simd, metric::<EuclidMetric>,);
9+
benchmark_group!(
10+
benches,
11+
legacy,
12+
non_simd,
13+
metric::<EuclidMetric>,
14+
metric::<CosineMetric>
15+
);
816

917
fn legacy(bench: &mut Bencher) {
1018
let mut rng = StdRng::seed_from_u64(SEED);
@@ -24,8 +32,10 @@ fn non_simd(bench: &mut Bencher) {
2432

2533
fn metric<M: Metric>(bench: &mut Bencher) {
2634
let mut rng = StdRng::seed_from_u64(SEED);
27-
let point_a = [rng.gen(); 300];
28-
let point_b = [rng.gen(); 300];
35+
let mut point_a = [rng.gen(); 300];
36+
let mut point_b = [rng.gen(); 300];
37+
M::preprocess(&mut point_a);
38+
M::preprocess(&mut point_b);
2939

3040
bench.iter(|| M::distance(&point_a, &point_b))
3141
}

distance-metrics/src/lib.rs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ pub mod simd_neon;
1111
pub trait Metric {
1212
/// Greater the value - more distant the vectors
1313
fn distance(v1: &[f32], v2: &[f32]) -> f32;
14+
15+
/// Necessary vector transformations performed before adding it to the collection (like normalization)
16+
fn preprocess(vector: &mut [f32]);
1417
}
1518

1619
#[cfg(target_arch = "x86_64")]
@@ -54,6 +57,72 @@ impl Metric for EuclidMetric {
5457

5558
euclid_distance(v1, v2)
5659
}
60+
61+
fn preprocess(_vector: &mut [f32]) {
62+
// no-op
63+
}
64+
}
65+
66+
#[derive(Clone, Copy)]
67+
pub struct CosineMetric {}
68+
69+
impl Metric for CosineMetric {
70+
fn distance(v1: &[f32], v2: &[f32]) -> f32 {
71+
#[cfg(target_arch = "x86_64")]
72+
{
73+
if is_x86_feature_detected!("avx")
74+
&& is_x86_feature_detected!("fma")
75+
&& v1.len() >= MIN_DIM_SIZE_AVX
76+
{
77+
return 1.0 - unsafe { simd_avx::dot_similarity_avx(v1, v2) };
78+
}
79+
}
80+
81+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
82+
{
83+
if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD {
84+
return 1.0 - unsafe { simd_sse::dot_similarity_sse(v1, v2) };
85+
}
86+
}
87+
88+
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
89+
{
90+
if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD {
91+
return 1.0 - unsafe { simd_neon::dot_similarity_neon(v1, v2) };
92+
}
93+
}
94+
95+
1.0 - dot_similarity(v1, v2)
96+
}
97+
98+
fn preprocess(vector: &mut [f32]) {
99+
#[cfg(target_arch = "x86_64")]
100+
{
101+
if is_x86_feature_detected!("avx")
102+
&& is_x86_feature_detected!("fma")
103+
&& vector.len() >= MIN_DIM_SIZE_AVX
104+
{
105+
return unsafe { simd_avx::cosine_preprocess_avx(vector) };
106+
}
107+
}
108+
109+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
110+
{
111+
if is_x86_feature_detected!("sse") && vector.len() >= MIN_DIM_SIZE_SIMD {
112+
return unsafe { simd_sse::cosine_preprocess_sse(vector) };
113+
}
114+
}
115+
116+
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
117+
{
118+
if std::arch::is_aarch64_feature_detected!("neon") && vector.len() >= MIN_DIM_SIZE_SIMD
119+
{
120+
return unsafe { simd_neon::cosine_preprocess_neon(vector) };
121+
}
122+
}
123+
124+
cosine_preprocess(vector);
125+
}
57126
}
58127

59128
pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 {
@@ -66,6 +135,21 @@ pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 {
66135
s.abs().sqrt()
67136
}
68137

138+
pub fn cosine_preprocess(vector: &mut [f32]) {
139+
let mut length: f32 = vector.iter().map(|x| x * x).sum();
140+
if length < f32::EPSILON {
141+
return;
142+
}
143+
length = length.sqrt();
144+
for x in vector.iter_mut() {
145+
*x /= length;
146+
}
147+
}
148+
149+
pub fn dot_similarity(v1: &[f32], v2: &[f32]) -> f32 {
150+
v1.iter().zip(v2).map(|(a, b)| a * b).sum()
151+
}
152+
69153
pub fn legacy_distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
70154
#[cfg(target_arch = "x86_64")]
71155
{

distance-metrics/src/simd_avx.rs

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,135 @@ pub(crate) unsafe fn euclid_distance_avx(v1: &[f32], v2: &[f32]) -> f32 {
5252
}
5353
result.abs().sqrt()
5454
}
55+
56+
#[target_feature(enable = "avx")]
57+
#[target_feature(enable = "fma")]
58+
pub(crate) unsafe fn cosine_preprocess_avx(vector: &mut [f32]) {
59+
let n = vector.len();
60+
let m = n - (n % 32);
61+
let mut ptr: *const f32 = vector.as_ptr();
62+
let mut sum256_1: __m256 = _mm256_setzero_ps();
63+
let mut sum256_2: __m256 = _mm256_setzero_ps();
64+
let mut sum256_3: __m256 = _mm256_setzero_ps();
65+
let mut sum256_4: __m256 = _mm256_setzero_ps();
66+
let mut i: usize = 0;
67+
while i < m {
68+
let m256_1 = _mm256_loadu_ps(ptr);
69+
sum256_1 = _mm256_fmadd_ps(m256_1, m256_1, sum256_1);
70+
71+
let m256_2 = _mm256_loadu_ps(ptr.add(8));
72+
sum256_2 = _mm256_fmadd_ps(m256_2, m256_2, sum256_2);
73+
74+
let m256_3 = _mm256_loadu_ps(ptr.add(16));
75+
sum256_3 = _mm256_fmadd_ps(m256_3, m256_3, sum256_3);
76+
77+
let m256_4 = _mm256_loadu_ps(ptr.add(24));
78+
sum256_4 = _mm256_fmadd_ps(m256_4, m256_4, sum256_4);
79+
80+
ptr = ptr.add(32);
81+
i += 32;
82+
}
83+
84+
let mut length = hsum256_ps_avx(sum256_1)
85+
+ hsum256_ps_avx(sum256_2)
86+
+ hsum256_ps_avx(sum256_3)
87+
+ hsum256_ps_avx(sum256_4);
88+
for i in 0..n - m {
89+
length += (*ptr.add(i)).powi(2);
90+
}
91+
if length < f32::EPSILON {
92+
return;
93+
}
94+
length = length.sqrt();
95+
for x in vector.iter_mut() {
96+
*x /= length;
97+
}
98+
}
99+
100+
#[target_feature(enable = "avx")]
101+
#[target_feature(enable = "fma")]
102+
pub(crate) unsafe fn dot_similarity_avx(v1: &[f32], v2: &[f32]) -> f32 {
103+
let n = v1.len();
104+
let m = n - (n % 32);
105+
let mut ptr1: *const f32 = v1.as_ptr();
106+
let mut ptr2: *const f32 = v2.as_ptr();
107+
let mut sum256_1: __m256 = _mm256_setzero_ps();
108+
let mut sum256_2: __m256 = _mm256_setzero_ps();
109+
let mut sum256_3: __m256 = _mm256_setzero_ps();
110+
let mut sum256_4: __m256 = _mm256_setzero_ps();
111+
let mut i: usize = 0;
112+
while i < m {
113+
sum256_1 = _mm256_fmadd_ps(_mm256_loadu_ps(ptr1), _mm256_loadu_ps(ptr2), sum256_1);
114+
sum256_2 = _mm256_fmadd_ps(
115+
_mm256_loadu_ps(ptr1.add(8)),
116+
_mm256_loadu_ps(ptr2.add(8)),
117+
sum256_2,
118+
);
119+
sum256_3 = _mm256_fmadd_ps(
120+
_mm256_loadu_ps(ptr1.add(16)),
121+
_mm256_loadu_ps(ptr2.add(16)),
122+
sum256_3,
123+
);
124+
sum256_4 = _mm256_fmadd_ps(
125+
_mm256_loadu_ps(ptr1.add(24)),
126+
_mm256_loadu_ps(ptr2.add(24)),
127+
sum256_4,
128+
);
129+
130+
ptr1 = ptr1.add(32);
131+
ptr2 = ptr2.add(32);
132+
i += 32;
133+
}
134+
135+
let mut result = hsum256_ps_avx(sum256_1)
136+
+ hsum256_ps_avx(sum256_2)
137+
+ hsum256_ps_avx(sum256_3)
138+
+ hsum256_ps_avx(sum256_4);
139+
140+
for i in 0..n - m {
141+
result += (*ptr1.add(i)) * (*ptr2.add(i));
142+
}
143+
result
144+
}
145+
146+
#[cfg(test)]
147+
mod tests {
148+
#[test]
149+
fn test_spaces_avx() {
150+
use super::*;
151+
use crate::*;
152+
153+
if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
154+
let v1: Vec<f32> = vec![
155+
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
156+
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
157+
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
158+
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
159+
26., 27., 28., 29., 30., 31.,
160+
];
161+
let v2: Vec<f32> = vec![
162+
40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55.,
163+
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
164+
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
165+
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
166+
56., 57., 58., 59., 60., 61.,
167+
];
168+
169+
let euclid_simd = unsafe { euclid_distance_avx(&v1, &v2) };
170+
let euclid = euclid_distance(&v1, &v2);
171+
assert_eq!(euclid_simd, euclid);
172+
173+
let dot_simd = unsafe { dot_similarity_avx(&v1, &v2) };
174+
let dot = dot_similarity(&v1, &v2);
175+
assert_eq!(dot_simd, dot);
176+
177+
let mut v1 = v1;
178+
let mut v1_copy = v1.clone();
179+
unsafe { cosine_preprocess_avx(&mut v1) };
180+
cosine_preprocess(&mut v1_copy);
181+
assert_eq!(v1, v1_copy);
182+
} else {
183+
println!("avx test skipped");
184+
}
185+
}
186+
}

distance-metrics/src/simd_neon.rs

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,108 @@ pub(crate) unsafe fn euclid_distance_neon(v1: &[f32], v2: &[f32]) -> f32 {
3636
}
3737
result.abs().sqrt()
3838
}
39+
40+
#[cfg(target_feature = "neon")]
41+
pub(crate) unsafe fn cosine_preprocess_neon(vector: &mut [f32]) {
42+
let n = vector.len();
43+
let m = n - (n % 16);
44+
let mut ptr: *const f32 = vector.as_ptr();
45+
let mut sum1 = vdupq_n_f32(0.);
46+
let mut sum2 = vdupq_n_f32(0.);
47+
let mut sum3 = vdupq_n_f32(0.);
48+
let mut sum4 = vdupq_n_f32(0.);
49+
50+
let mut i: usize = 0;
51+
while i < m {
52+
let d1 = vld1q_f32(ptr);
53+
sum1 = vfmaq_f32(sum1, d1, d1);
54+
55+
let d2 = vld1q_f32(ptr.add(4));
56+
sum2 = vfmaq_f32(sum2, d2, d2);
57+
58+
let d3 = vld1q_f32(ptr.add(8));
59+
sum3 = vfmaq_f32(sum3, d3, d3);
60+
61+
let d4 = vld1q_f32(ptr.add(12));
62+
sum4 = vfmaq_f32(sum4, d4, d4);
63+
64+
ptr = ptr.add(16);
65+
i += 16;
66+
}
67+
let mut length = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4);
68+
for v in vector.iter().take(n).skip(m) {
69+
length += v.powi(2);
70+
}
71+
if length < f32::EPSILON {
72+
return;
73+
}
74+
let length = length.sqrt();
75+
for x in vector.iter_mut() {
76+
*x /= length;
77+
}
78+
}
79+
80+
#[cfg(target_feature = "neon")]
81+
pub(crate) unsafe fn dot_similarity_neon(v1: &[f32], v2: &[f32]) -> f32 {
82+
let n = v1.len();
83+
let m = n - (n % 16);
84+
let mut ptr1: *const f32 = v1.as_ptr();
85+
let mut ptr2: *const f32 = v2.as_ptr();
86+
let mut sum1 = vdupq_n_f32(0.);
87+
let mut sum2 = vdupq_n_f32(0.);
88+
let mut sum3 = vdupq_n_f32(0.);
89+
let mut sum4 = vdupq_n_f32(0.);
90+
91+
let mut i: usize = 0;
92+
while i < m {
93+
sum1 = vfmaq_f32(sum1, vld1q_f32(ptr1), vld1q_f32(ptr2));
94+
sum2 = vfmaq_f32(sum2, vld1q_f32(ptr1.add(4)), vld1q_f32(ptr2.add(4)));
95+
sum3 = vfmaq_f32(sum3, vld1q_f32(ptr1.add(8)), vld1q_f32(ptr2.add(8)));
96+
sum4 = vfmaq_f32(sum4, vld1q_f32(ptr1.add(12)), vld1q_f32(ptr2.add(12)));
97+
ptr1 = ptr1.add(16);
98+
ptr2 = ptr2.add(16);
99+
i += 16;
100+
}
101+
let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4);
102+
for i in 0..n - m {
103+
result += (*ptr1.add(i)) * (*ptr2.add(i));
104+
}
105+
result
106+
}
107+
108+
#[cfg(test)]
109+
mod tests {
110+
#[cfg(target_feature = "neon")]
111+
#[test]
112+
fn test_spaces_neon() {
113+
use super::*;
114+
use crate::*;
115+
116+
if std::arch::is_aarch64_feature_detected!("neon") {
117+
let v1: Vec<f32> = vec![
118+
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
119+
26., 27., 28., 29., 30., 31.,
120+
];
121+
let v2: Vec<f32> = vec![
122+
40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55.,
123+
56., 57., 58., 59., 60., 61.,
124+
];
125+
126+
let euclid_simd = unsafe { euclid_distance_neon(&v1, &v2) };
127+
let euclid = euclid_distance(&v1, &v2);
128+
assert_eq!(euclid_simd, euclid);
129+
130+
let dot_simd = unsafe { dot_similarity_neon(&v1, &v2) };
131+
let dot = dot_similarity(&v1, &v2);
132+
assert_eq!(dot_simd, dot);
133+
134+
let mut v1 = v1;
135+
let mut v1_copy = v1.clone();
136+
unsafe { cosine_preprocess_neon(&mut v1) };
137+
cosine_preprocess(&mut v1_copy);
138+
assert_eq!(v1, v1_copy);
139+
} else {
140+
println!("neon test skipped");
141+
}
142+
}
143+
}

0 commit comments

Comments
 (0)