-
-
Notifications
You must be signed in to change notification settings - Fork 30
Implement distance metric selection #35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
[workspace] | ||
members = ["instant-distance", "instant-distance-py"] | ||
members = ["distance-metrics", "instant-distance", "instant-distance-py"] | ||
|
||
[profile.bench] | ||
debug = true |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
[package] | ||
name = "distance-metrics" | ||
version = "0.6.0" | ||
license = "MIT OR Apache-2.0" | ||
edition = "2021" | ||
rust-version = "1.58" | ||
homepage = "https://github.com/InstantDomain/instant-distance" | ||
repository = "https://github.com/InstantDomain/instant-distance" | ||
documentation = "https://docs.rs/instant-distance" | ||
workspace = ".." | ||
readme = "../README.md" | ||
|
||
[dependencies] | ||
|
||
[dev-dependencies] | ||
bencher = "0.1.5" | ||
rand = { version = "0.8", features = ["small_rng"] } | ||
|
||
[[bench]] | ||
name = "all" | ||
harness = false |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
use bencher::{benchmark_group, benchmark_main, Bencher}; | ||
|
||
use distance_metrics::{EuclidMetric, Metric}; | ||
use rand::{rngs::StdRng, Rng, SeedableRng}; | ||
|
||
benchmark_main!(benches); | ||
benchmark_group!(benches, legacy, non_simd, metric::<EuclidMetric>,); | ||
|
||
fn legacy(bench: &mut Bencher) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would be curious to see some benchmarking results in this PR for the legacy approach vs your new implementations! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The most recent results:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So I think Did you run any tests to show that the platform-specific implementations have the same result (or close to it) as the non-SIMD implementation? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, the
(the numbers are higher because I'm on battery and the CPU is throttled) We could use that since it doesn't have dimensions limitation. Another idea would be to have the current metric (as for example |
||
let mut rng = StdRng::seed_from_u64(SEED); | ||
let point_a = distance_metrics::FloatArray([rng.gen(); 300]); | ||
let point_b = distance_metrics::FloatArray([rng.gen(); 300]); | ||
|
||
bench.iter(|| distance_metrics::legacy_distance(&point_a, &point_b)) | ||
} | ||
|
||
fn non_simd(bench: &mut Bencher) { | ||
let mut rng = StdRng::seed_from_u64(SEED); | ||
let point_a = [rng.gen(); 300]; | ||
let point_b = [rng.gen(); 300]; | ||
|
||
bench.iter(|| distance_metrics::euclid_distance(&point_a, &point_b)) | ||
} | ||
|
||
fn metric<M: Metric>(bench: &mut Bencher) { | ||
let mut rng = StdRng::seed_from_u64(SEED); | ||
let point_a = [rng.gen(); 300]; | ||
let point_b = [rng.gen(); 300]; | ||
|
||
bench.iter(|| M::distance(&point_a, &point_b)) | ||
} | ||
|
||
const SEED: u64 = 123456789; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] | ||
pub mod simd_sse; | ||
|
||
#[cfg(target_arch = "x86_64")] | ||
pub mod simd_avx; | ||
|
||
#[cfg(target_arch = "aarch64")] | ||
pub mod simd_neon; | ||
|
||
/// Defines how to compare vectors | ||
pub trait Metric { | ||
/// Greater the value - more distant the vectors | ||
fn distance(v1: &[f32], v2: &[f32]) -> f32; | ||
} | ||
|
||
#[cfg(target_arch = "x86_64")] | ||
const MIN_DIM_SIZE_AVX: usize = 32; | ||
|
||
#[cfg(any( | ||
target_arch = "x86", | ||
target_arch = "x86_64", | ||
all(target_arch = "aarch64", target_feature = "neon") | ||
))] | ||
const MIN_DIM_SIZE_SIMD: usize = 16; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Style: in order to match the top-down ordering in other crates, let's move |
||
|
||
#[derive(Clone, Copy)] | ||
pub struct EuclidMetric {} | ||
|
||
impl Metric for EuclidMetric { | ||
fn distance(v1: &[f32], v2: &[f32]) -> f32 { | ||
#[cfg(target_arch = "x86_64")] | ||
{ | ||
if is_x86_feature_detected!("avx") | ||
&& is_x86_feature_detected!("fma") | ||
&& v1.len() >= MIN_DIM_SIZE_AVX | ||
{ | ||
return unsafe { simd_avx::euclid_distance_avx(v1, v2) }; | ||
} | ||
} | ||
|
||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] | ||
{ | ||
if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD { | ||
return unsafe { simd_sse::euclid_distance_sse(v1, v2) }; | ||
} | ||
} | ||
|
||
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] | ||
{ | ||
if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD { | ||
return unsafe { simple_neon::euclid_distance_neon(v1, v2) }; | ||
} | ||
} | ||
|
||
euclid_distance(v1, v2) | ||
} | ||
} | ||
|
||
pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 { | ||
let s: f32 = v1 | ||
.iter() | ||
.copied() | ||
.zip(v2.iter().copied()) | ||
.map(|(a, b)| (a - b).powi(2)) | ||
.sum(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Style nit: I usually prefer turbofishing over type annotations, so |
||
s.abs().sqrt() | ||
} | ||
Comment on lines
+128
to
+136
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be a |
||
|
||
pub fn legacy_distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 { | ||
#[cfg(target_arch = "x86_64")] | ||
{ | ||
use std::arch::x86_64::{ | ||
_mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps, | ||
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_fmadd_ps, | ||
_mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps, | ||
}; | ||
debug_assert_eq!(lhs.0.len() % 8, 4); | ||
|
||
unsafe { | ||
let mut acc_8x = _mm256_setzero_ps(); | ||
for (lh_slice, rh_slice) in lhs.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) { | ||
let lh_8x = _mm256_load_ps(lh_slice.as_ptr()); | ||
let rh_8x = _mm256_load_ps(rh_slice.as_ptr()); | ||
let diff = _mm256_sub_ps(lh_8x, rh_8x); | ||
acc_8x = _mm256_fmadd_ps(diff, diff, acc_8x); | ||
} | ||
|
||
let mut acc_4x = _mm256_extractf128_ps(acc_8x, 1); // upper half | ||
let right = _mm256_castps256_ps128(acc_8x); // lower half | ||
acc_4x = _mm_add_ps(acc_4x, right); // sum halves | ||
|
||
let lh_4x = _mm_load_ps(lhs.0[DIMENSIONS - 4..].as_ptr()); | ||
let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr()); | ||
let diff = _mm_sub_ps(lh_4x, rh_4x); | ||
acc_4x = _mm_fmadd_ps(diff, diff, acc_4x); | ||
|
||
let lower = _mm_movehl_ps(acc_4x, acc_4x); | ||
acc_4x = _mm_add_ps(acc_4x, lower); | ||
let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1); | ||
acc_4x = _mm_add_ss(acc_4x, upper); | ||
_mm_cvtss_f32(acc_4x) | ||
} | ||
} | ||
#[cfg(not(target_arch = "x86_64"))] | ||
lhs.0 | ||
.iter() | ||
.zip(rhs.0.iter()) | ||
.map(|(&a, &b)| (a - b).powi(2)) | ||
.sum::<f32>() | ||
} | ||
|
||
#[repr(align(32))] | ||
pub struct FloatArray(pub [f32; DIMENSIONS]); | ||
|
||
const DIMENSIONS: usize = 300; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
use std::arch::x86_64::*; | ||
|
||
#[target_feature(enable = "avx")] | ||
#[target_feature(enable = "fma")] | ||
unsafe fn hsum256_ps_avx(x: __m256) -> f32 { | ||
let x128: __m128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); | ||
let x64: __m128 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); | ||
let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); | ||
_mm_cvtss_f32(x32) | ||
} | ||
|
||
#[target_feature(enable = "avx")] | ||
#[target_feature(enable = "fma")] | ||
pub(crate) unsafe fn euclid_distance_avx(v1: &[f32], v2: &[f32]) -> f32 { | ||
let n = v1.len(); | ||
let m = n - (n % 32); | ||
let mut ptr1: *const f32 = v1.as_ptr(); | ||
let mut ptr2: *const f32 = v2.as_ptr(); | ||
let mut sum256_1: __m256 = _mm256_setzero_ps(); | ||
let mut sum256_2: __m256 = _mm256_setzero_ps(); | ||
let mut sum256_3: __m256 = _mm256_setzero_ps(); | ||
let mut sum256_4: __m256 = _mm256_setzero_ps(); | ||
let mut i: usize = 0; | ||
while i < m { | ||
let sub256_1: __m256 = | ||
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(0)), _mm256_loadu_ps(ptr2.add(0))); | ||
sum256_1 = _mm256_fmadd_ps(sub256_1, sub256_1, sum256_1); | ||
|
||
let sub256_2: __m256 = | ||
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(8)), _mm256_loadu_ps(ptr2.add(8))); | ||
sum256_2 = _mm256_fmadd_ps(sub256_2, sub256_2, sum256_2); | ||
|
||
let sub256_3: __m256 = | ||
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(16)), _mm256_loadu_ps(ptr2.add(16))); | ||
sum256_3 = _mm256_fmadd_ps(sub256_3, sub256_3, sum256_3); | ||
|
||
let sub256_4: __m256 = | ||
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(24)), _mm256_loadu_ps(ptr2.add(24))); | ||
sum256_4 = _mm256_fmadd_ps(sub256_4, sub256_4, sum256_4); | ||
|
||
ptr1 = ptr1.add(32); | ||
ptr2 = ptr2.add(32); | ||
i += 32; | ||
} | ||
|
||
let mut result = hsum256_ps_avx(sum256_1) | ||
+ hsum256_ps_avx(sum256_2) | ||
+ hsum256_ps_avx(sum256_3) | ||
+ hsum256_ps_avx(sum256_4); | ||
for i in 0..n - m { | ||
result += (*ptr1.add(i) - *ptr2.add(i)).powi(2); | ||
} | ||
result.abs().sqrt() | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#[cfg(target_feature = "neon")] | ||
use std::arch::aarch64::*; | ||
|
||
#[cfg(target_feature = "neon")] | ||
pub(crate) unsafe fn euclid_distance_neon(v1: &[f32], v2: &[f32]) -> f32 { | ||
let n = v1.len(); | ||
let m = n - (n % 16); | ||
let mut ptr1: *const f32 = v1.as_ptr(); | ||
let mut ptr2: *const f32 = v2.as_ptr(); | ||
let mut sum1 = vdupq_n_f32(0.); | ||
let mut sum2 = vdupq_n_f32(0.); | ||
let mut sum3 = vdupq_n_f32(0.); | ||
let mut sum4 = vdupq_n_f32(0.); | ||
|
||
let mut i: usize = 0; | ||
while i < m { | ||
let sub1 = vsubq_f32(vld1q_f32(ptr1), vld1q_f32(ptr2)); | ||
sum1 = vfmaq_f32(sum1, sub1, sub1); | ||
|
||
let sub2 = vsubq_f32(vld1q_f32(ptr1.add(4)), vld1q_f32(ptr2.add(4))); | ||
sum2 = vfmaq_f32(sum2, sub2, sub2); | ||
|
||
let sub3 = vsubq_f32(vld1q_f32(ptr1.add(8)), vld1q_f32(ptr2.add(8))); | ||
sum3 = vfmaq_f32(sum3, sub3, sub3); | ||
|
||
let sub4 = vsubq_f32(vld1q_f32(ptr1.add(12)), vld1q_f32(ptr2.add(12))); | ||
sum4 = vfmaq_f32(sum4, sub4, sub4); | ||
|
||
ptr1 = ptr1.add(16); | ||
ptr2 = ptr2.add(16); | ||
i += 16; | ||
} | ||
let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4); | ||
for i in 0..n - m { | ||
result += (*ptr1.add(i) - *ptr2.add(i)).powi(2); | ||
} | ||
result.abs().sqrt() | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
#[cfg(target_arch = "x86")] | ||
use std::arch::x86::*; | ||
#[cfg(target_arch = "x86_64")] | ||
use std::arch::x86_64::*; | ||
|
||
#[target_feature(enable = "sse")] | ||
unsafe fn hsum128_ps_sse(x: __m128) -> f32 { | ||
let x64: __m128 = _mm_add_ps(x, _mm_movehl_ps(x, x)); | ||
let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); | ||
_mm_cvtss_f32(x32) | ||
} | ||
|
||
#[target_feature(enable = "sse")] | ||
pub(crate) unsafe fn euclid_distance_sse(v1: &[f32], v2: &[f32]) -> f32 { | ||
let n = v1.len(); | ||
let m = n - (n % 16); | ||
let mut ptr1: *const f32 = v1.as_ptr(); | ||
let mut ptr2: *const f32 = v2.as_ptr(); | ||
let mut sum128_1: __m128 = _mm_setzero_ps(); | ||
let mut sum128_2: __m128 = _mm_setzero_ps(); | ||
let mut sum128_3: __m128 = _mm_setzero_ps(); | ||
let mut sum128_4: __m128 = _mm_setzero_ps(); | ||
let mut i: usize = 0; | ||
while i < m { | ||
let sub128_1 = _mm_sub_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2)); | ||
sum128_1 = _mm_add_ps(_mm_mul_ps(sub128_1, sub128_1), sum128_1); | ||
|
||
let sub128_2 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4))); | ||
sum128_2 = _mm_add_ps(_mm_mul_ps(sub128_2, sub128_2), sum128_2); | ||
|
||
let sub128_3 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8))); | ||
sum128_3 = _mm_add_ps(_mm_mul_ps(sub128_3, sub128_3), sum128_3); | ||
|
||
let sub128_4 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12))); | ||
sum128_4 = _mm_add_ps(_mm_mul_ps(sub128_4, sub128_4), sum128_4); | ||
|
||
ptr1 = ptr1.add(16); | ||
ptr2 = ptr2.add(16); | ||
i += 16; | ||
} | ||
|
||
let mut result = hsum128_ps_sse(sum128_1) | ||
+ hsum128_ps_sse(sum128_2) | ||
+ hsum128_ps_sse(sum128_3) | ||
+ hsum128_ps_sse(sum128_4); | ||
for i in 0..n - m { | ||
result += (*ptr1.add(i) - *ptr2.add(i)).powi(2); | ||
} | ||
result.abs().sqrt() | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious what your motivation for adding a new crate for this is?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Separate mostly because I couldn't get benchmark to access
instant-distance-py
internals and the metrics felt to specific forinstant-distance
crate.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I gave this another try. There are 2 things that prevent me from adding benches directly to
instant-distance-py
:cdylib
- that one is easy we could have 2 types (cdylib
and regularlib
),instance-distance
, just like the other crate - AFAIK it makes it impossible to use both crates at the same time.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, let's just change the crate name to instant-distance-py in the Rust metadata (but should verify that the Python name will still be instant-distance), and then if we add the extra type we wouldn't need the extra crate anymore, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be possible if we add a tiny python wrapper module. That's to force maturin to build wheel for mixed rust/python project. Without the wrapper maturin uses lib.name as top level module name.