Skip to content

Commit ac42574

Browse files
committed
Replace Euclid metric implementation
1 parent bca31ad commit ac42574

File tree

9 files changed

+343
-66
lines changed

9 files changed

+343
-66
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[workspace]
2-
members = ["instant-distance", "instant-distance-py"]
2+
members = ["distance-metrics", "instant-distance", "instant-distance-py"]
33

44
[profile.bench]
55
debug = true

distance-metrics/Cargo.toml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
[package]
2+
name = "distance-metrics"
3+
version = "0.6.0"
4+
license = "MIT OR Apache-2.0"
5+
edition = "2021"
6+
rust-version = "1.58"
7+
homepage = "https://github.com/InstantDomain/instant-distance"
8+
repository = "https://github.com/InstantDomain/instant-distance"
9+
documentation = "https://docs.rs/instant-distance"
10+
workspace = ".."
11+
readme = "../README.md"
12+
13+
[dependencies]
14+
15+
[dev-dependencies]
16+
bencher = "0.1.5"
17+
rand = { version = "0.8", features = ["small_rng"] }
18+
19+
[[bench]]
20+
name = "all"
21+
harness = false

distance-metrics/benches/all.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
use bencher::{benchmark_group, benchmark_main, Bencher};
2+
3+
use distance_metrics::{EuclidMetric, Metric};
4+
use rand::{rngs::StdRng, Rng, SeedableRng};
5+
6+
benchmark_main!(benches);
7+
benchmark_group!(benches, legacy, non_simd, metric::<EuclidMetric>,);
8+
9+
/// Benchmarks `distance_metrics::legacy_distance` on two seeded 300-element vectors.
fn legacy(bench: &mut Bencher) {
    let mut rng = StdRng::seed_from_u64(SEED);
    // An array repeat expression (`[rng.gen(); 300]`) evaluates its operand once
    // and copies it, which would produce a constant vector. Fill element by
    // element so every component is drawn independently.
    let mut point_a = distance_metrics::FloatArray([0.0; 300]);
    let mut point_b = distance_metrics::FloatArray([0.0; 300]);
    point_a.0.iter_mut().for_each(|x| *x = rng.gen());
    point_b.0.iter_mut().for_each(|x| *x = rng.gen());

    bench.iter(|| distance_metrics::legacy_distance(&point_a, &point_b))
}
16+
17+
fn non_simd(bench: &mut Bencher) {
18+
let mut rng = StdRng::seed_from_u64(SEED);
19+
let point_a = [rng.gen(); 300];
20+
let point_b = [rng.gen(); 300];
21+
22+
bench.iter(|| distance_metrics::euclid_distance(&point_a, &point_b))
23+
}
24+
25+
/// Benchmarks `M::distance` for any [`Metric`] on two seeded 300-element vectors.
fn metric<M: Metric>(bench: &mut Bencher) {
    let mut rng = StdRng::seed_from_u64(SEED);
    // An array repeat expression (`[rng.gen(); 300]`) evaluates its operand once
    // and copies it, which would produce a constant vector. Fill element by
    // element so every component is drawn independently.
    let mut point_a = [0.0_f32; 300];
    let mut point_b = [0.0_f32; 300];
    point_a.iter_mut().for_each(|x| *x = rng.gen());
    point_b.iter_mut().for_each(|x| *x = rng.gen());

    bench.iter(|| M::distance(&point_a, &point_b))
}
32+
33+
const SEED: u64 = 123456789;

distance-metrics/src/lib.rs

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2+
pub mod simd_sse;
3+
4+
#[cfg(target_arch = "x86_64")]
5+
pub mod simd_avx;
6+
7+
#[cfg(target_arch = "aarch64")]
8+
pub mod simd_neon;
9+
10+
/// A way of comparing two `f32` vectors.
pub trait Metric {
    /// Returns the distance between `v1` and `v2`.
    ///
    /// The greater the returned value, the more distant the vectors are.
    fn distance(v1: &[f32], v2: &[f32]) -> f32;
}
15+
16+
/// Shortest vector length for which `EuclidMetric` selects the 128-bit SIMD
/// kernels (SSE on x86/x86_64, NEON on aarch64).
#[cfg(any(
    target_arch = "x86",
    target_arch = "x86_64",
    all(target_arch = "aarch64", target_feature = "neon")
))]
const MIN_DIM_SIZE_SIMD: usize = 16;

/// Shortest vector length for which `EuclidMetric` selects the AVX kernel.
#[cfg(target_arch = "x86_64")]
const MIN_DIM_SIZE_AVX: usize = 32;
25+
26+
/// Marker type whose `Metric` implementation computes the Euclidean distance.
#[derive(Copy, Clone)]
pub struct EuclidMetric {}
28+
29+
impl Metric for EuclidMetric {
30+
fn distance(v1: &[f32], v2: &[f32]) -> f32 {
31+
#[cfg(target_arch = "x86_64")]
32+
{
33+
if is_x86_feature_detected!("avx")
34+
&& is_x86_feature_detected!("fma")
35+
&& v1.len() >= MIN_DIM_SIZE_AVX
36+
{
37+
return unsafe { simd_avx::euclid_distance_avx(v1, v2) };
38+
}
39+
}
40+
41+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
42+
{
43+
if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD {
44+
return unsafe { simd_sse::euclid_distance_sse(v1, v2) };
45+
}
46+
}
47+
48+
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
49+
{
50+
if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD {
51+
return unsafe { simple_neon::euclid_distance_neon(v1, v2) };
52+
}
53+
}
54+
55+
euclid_distance(v1, v2)
56+
}
57+
}
58+
59+
/// Scalar Euclidean distance: the square root of the summed squared
/// component differences.
///
/// Components are paired positionally; if the slices differ in length the
/// surplus elements of the longer one are ignored.
pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 {
    let mut squared_sum = 0.0_f32;
    for (&a, &b) in v1.iter().zip(v2) {
        let delta = a - b;
        squared_sum += delta.powi(2);
    }
    squared_sum.abs().sqrt()
}
68+
69+
pub fn legacy_distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
70+
#[cfg(target_arch = "x86_64")]
71+
{
72+
use std::arch::x86_64::{
73+
_mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
74+
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_fmadd_ps,
75+
_mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
76+
};
77+
debug_assert_eq!(lhs.0.len() % 8, 4);
78+
79+
unsafe {
80+
let mut acc_8x = _mm256_setzero_ps();
81+
for (lh_slice, rh_slice) in lhs.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
82+
let lh_8x = _mm256_load_ps(lh_slice.as_ptr());
83+
let rh_8x = _mm256_load_ps(rh_slice.as_ptr());
84+
let diff = _mm256_sub_ps(lh_8x, rh_8x);
85+
acc_8x = _mm256_fmadd_ps(diff, diff, acc_8x);
86+
}
87+
88+
let mut acc_4x = _mm256_extractf128_ps(acc_8x, 1); // upper half
89+
let right = _mm256_castps256_ps128(acc_8x); // lower half
90+
acc_4x = _mm_add_ps(acc_4x, right); // sum halves
91+
92+
let lh_4x = _mm_load_ps(lhs.0[DIMENSIONS - 4..].as_ptr());
93+
let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr());
94+
let diff = _mm_sub_ps(lh_4x, rh_4x);
95+
acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);
96+
97+
let lower = _mm_movehl_ps(acc_4x, acc_4x);
98+
acc_4x = _mm_add_ps(acc_4x, lower);
99+
let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1);
100+
acc_4x = _mm_add_ss(acc_4x, upper);
101+
_mm_cvtss_f32(acc_4x)
102+
}
103+
}
104+
#[cfg(not(target_arch = "x86_64"))]
105+
lhs.0
106+
.iter()
107+
.zip(rhs.0.iter())
108+
.map(|(&a, &b)| (a - b).powi(2))
109+
.sum::<f32>()
110+
}
111+
112+
/// Number of `f32` components held by a [`FloatArray`].
const DIMENSIONS: usize = 300;

/// Fixed-size `f32` vector with 32-byte alignment.
///
/// `legacy_distance` reads it with aligned SIMD loads (`_mm256_load_ps`,
/// `_mm_load_ps`), which is why the alignment is part of the type.
#[repr(align(32))]
pub struct FloatArray(pub [f32; DIMENSIONS]);

distance-metrics/src/simd_avx.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
use std::arch::x86_64::*;
2+
3+
/// Adds up the eight `f32` lanes of `x`.
///
/// # Safety
///
/// The running CPU must support the `avx` and `fma` features.
#[target_feature(enable = "avx")]
#[target_feature(enable = "fma")]
unsafe fn hsum256_ps_avx(x: __m256) -> f32 {
    // 8 lanes -> 4 lanes: upper half plus lower half.
    let upper_half = _mm256_extractf128_ps(x, 1);
    let lower_half = _mm256_castps256_ps128(x);
    let quad = _mm_add_ps(upper_half, lower_half);
    // 4 lanes -> 2 lanes: lanes 2 and 3 are added onto lanes 0 and 1.
    let pair = _mm_add_ps(quad, _mm_movehl_ps(quad, quad));
    // 2 lanes -> 1 lane: broadcast lane 1 and add it to lane 0.
    let lane_one = _mm_shuffle_ps(pair, pair, 0x55);
    _mm_cvtss_f32(_mm_add_ss(pair, lane_one))
}
11+
12+
#[target_feature(enable = "avx")]
13+
#[target_feature(enable = "fma")]
14+
pub(crate) unsafe fn euclid_distance_avx(v1: &[f32], v2: &[f32]) -> f32 {
15+
let n = v1.len();
16+
let m = n - (n % 32);
17+
let mut ptr1: *const f32 = v1.as_ptr();
18+
let mut ptr2: *const f32 = v2.as_ptr();
19+
let mut sum256_1: __m256 = _mm256_setzero_ps();
20+
let mut sum256_2: __m256 = _mm256_setzero_ps();
21+
let mut sum256_3: __m256 = _mm256_setzero_ps();
22+
let mut sum256_4: __m256 = _mm256_setzero_ps();
23+
let mut i: usize = 0;
24+
while i < m {
25+
let sub256_1: __m256 =
26+
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(0)), _mm256_loadu_ps(ptr2.add(0)));
27+
sum256_1 = _mm256_fmadd_ps(sub256_1, sub256_1, sum256_1);
28+
29+
let sub256_2: __m256 =
30+
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(8)), _mm256_loadu_ps(ptr2.add(8)));
31+
sum256_2 = _mm256_fmadd_ps(sub256_2, sub256_2, sum256_2);
32+
33+
let sub256_3: __m256 =
34+
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(16)), _mm256_loadu_ps(ptr2.add(16)));
35+
sum256_3 = _mm256_fmadd_ps(sub256_3, sub256_3, sum256_3);
36+
37+
let sub256_4: __m256 =
38+
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(24)), _mm256_loadu_ps(ptr2.add(24)));
39+
sum256_4 = _mm256_fmadd_ps(sub256_4, sub256_4, sum256_4);
40+
41+
ptr1 = ptr1.add(32);
42+
ptr2 = ptr2.add(32);
43+
i += 32;
44+
}
45+
46+
let mut result = hsum256_ps_avx(sum256_1)
47+
+ hsum256_ps_avx(sum256_2)
48+
+ hsum256_ps_avx(sum256_3)
49+
+ hsum256_ps_avx(sum256_4);
50+
for i in 0..n - m {
51+
result += (*ptr1.add(i) - *ptr2.add(i)).powi(2);
52+
}
53+
result.abs().sqrt()
54+
}

distance-metrics/src/simd_neon.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#[cfg(target_feature = "neon")]
2+
use std::arch::aarch64::*;
3+
4+
/// Euclidean distance between `v1` and `v2` using NEON instructions.
///
/// Elements are processed 16 at a time in four independent 4-lane
/// accumulators; the remaining tail (fewer than 16 elements) is handled by a
/// scalar loop. If the slices differ in length, only their common prefix is
/// compared.
///
/// # Safety
///
/// The running CPU must support the `neon` feature.
#[cfg(target_feature = "neon")]
pub(crate) unsafe fn euclid_distance_neon(v1: &[f32], v2: &[f32]) -> f32 {
    // Both slices are read through raw pointers below, so every access must be
    // bounded by the shorter one to stay in bounds.
    let n = v1.len().min(v2.len());
    let m = n - (n % 16);
    let mut ptr1: *const f32 = v1.as_ptr();
    let mut ptr2: *const f32 = v2.as_ptr();
    let mut sum1 = vdupq_n_f32(0.);
    let mut sum2 = vdupq_n_f32(0.);
    let mut sum3 = vdupq_n_f32(0.);
    let mut sum4 = vdupq_n_f32(0.);

    let mut i: usize = 0;
    while i < m {
        let sub1 = vsubq_f32(vld1q_f32(ptr1), vld1q_f32(ptr2));
        sum1 = vfmaq_f32(sum1, sub1, sub1);

        let sub2 = vsubq_f32(vld1q_f32(ptr1.add(4)), vld1q_f32(ptr2.add(4)));
        sum2 = vfmaq_f32(sum2, sub2, sub2);

        let sub3 = vsubq_f32(vld1q_f32(ptr1.add(8)), vld1q_f32(ptr2.add(8)));
        sum3 = vfmaq_f32(sum3, sub3, sub3);

        let sub4 = vsubq_f32(vld1q_f32(ptr1.add(12)), vld1q_f32(ptr2.add(12)));
        sum4 = vfmaq_f32(sum4, sub4, sub4);

        ptr1 = ptr1.add(16);
        ptr2 = ptr2.add(16);
        i += 16;
    }
    let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4);
    // Scalar tail: `ptr1`/`ptr2` now point at element `m`.
    for i in 0..n - m {
        result += (*ptr1.add(i) - *ptr2.add(i)).powi(2);
    }
    result.abs().sqrt()
}

distance-metrics/src/simd_sse.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#[cfg(target_arch = "x86")]
2+
use std::arch::x86::*;
3+
#[cfg(target_arch = "x86_64")]
4+
use std::arch::x86_64::*;
5+
6+
/// Adds up the four `f32` lanes of `x`.
///
/// # Safety
///
/// The running CPU must support the `sse` feature.
#[target_feature(enable = "sse")]
unsafe fn hsum128_ps_sse(x: __m128) -> f32 {
    // 4 lanes -> 2 lanes: lanes 2 and 3 are added onto lanes 0 and 1.
    let high_pair = _mm_movehl_ps(x, x);
    let pair_sums = _mm_add_ps(x, high_pair);
    // 2 lanes -> 1 lane: broadcast lane 1 and add it to lane 0.
    let lane_one = _mm_shuffle_ps(pair_sums, pair_sums, 0x55);
    _mm_cvtss_f32(_mm_add_ss(pair_sums, lane_one))
}
12+
13+
#[target_feature(enable = "sse")]
14+
pub(crate) unsafe fn euclid_distance_sse(v1: &[f32], v2: &[f32]) -> f32 {
15+
let n = v1.len();
16+
let m = n - (n % 16);
17+
let mut ptr1: *const f32 = v1.as_ptr();
18+
let mut ptr2: *const f32 = v2.as_ptr();
19+
let mut sum128_1: __m128 = _mm_setzero_ps();
20+
let mut sum128_2: __m128 = _mm_setzero_ps();
21+
let mut sum128_3: __m128 = _mm_setzero_ps();
22+
let mut sum128_4: __m128 = _mm_setzero_ps();
23+
let mut i: usize = 0;
24+
while i < m {
25+
let sub128_1 = _mm_sub_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2));
26+
sum128_1 = _mm_add_ps(_mm_mul_ps(sub128_1, sub128_1), sum128_1);
27+
28+
let sub128_2 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4)));
29+
sum128_2 = _mm_add_ps(_mm_mul_ps(sub128_2, sub128_2), sum128_2);
30+
31+
let sub128_3 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8)));
32+
sum128_3 = _mm_add_ps(_mm_mul_ps(sub128_3, sub128_3), sum128_3);
33+
34+
let sub128_4 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12)));
35+
sum128_4 = _mm_add_ps(_mm_mul_ps(sub128_4, sub128_4), sum128_4);
36+
37+
ptr1 = ptr1.add(16);
38+
ptr2 = ptr2.add(16);
39+
i += 16;
40+
}
41+
42+
let mut result = hsum128_ps_sse(sum128_1)
43+
+ hsum128_ps_sse(sum128_2)
44+
+ hsum128_ps_sse(sum128_3)
45+
+ hsum128_ps_sse(sum128_4);
46+
for i in 0..n - m {
47+
result += (*ptr1.add(i) - *ptr2.add(i)).powi(2);
48+
}
49+
result.abs().sqrt()
50+
}

instant-distance-py/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ crate-type = ["cdylib"]
1616

1717
[dependencies]
1818
bincode = "1.3.1"
19+
distance-metrics = { version = "0.6", path = "../distance-metrics" }
1920
instant-distance = { version = "0.6", path = "../instant-distance", features = ["with-serde"] }
2021
pyo3 = { version = "0.18.0", features = ["extension-module"] }
2122
serde = { version = "1", features = ["derive"] }
22-
serde-big-array = "0.4.1"

0 commit comments

Comments
 (0)