2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[workspace]
-members = ["instant-distance", "instant-distance-py"]
+members = ["distance-metrics", "instant-distance", "instant-distance-py"]
Owner

Curious what your motivation for adding a new crate for this is?

Author

It's separate mostly because I couldn't get the benchmark to access instant-distance-py internals, and the metrics felt too specific for the instant-distance crate.

Author

I gave this another try. There are 2 things that prevent me from adding benches directly to instant-distance-py:

  1. the crate type is cdylib; that one is easy, since we could declare two types (cdylib and a regular lib),
  2. the name is instant-distance, just like the other crate; AFAIK that makes it impossible to use both crates at the same time.

Owner

Okay, let's just change the crate name to instant-distance-py in the Rust metadata (but should verify that the Python name will still be instant-distance), and then if we add the extra type we wouldn't need the extra crate anymore, right?

Author

It should be possible if we add a tiny Python wrapper module, which forces maturin to build the wheel as a mixed Rust/Python project. Without the wrapper, maturin uses lib.name as the top-level module name.


[profile.bench]
debug = true
21 changes: 21 additions & 0 deletions distance-metrics/Cargo.toml
@@ -0,0 +1,21 @@
[package]
name = "distance-metrics"
version = "0.6.0"
license = "MIT OR Apache-2.0"
edition = "2021"
rust-version = "1.58"
homepage = "https://github.com/InstantDomain/instant-distance"
repository = "https://github.com/InstantDomain/instant-distance"
documentation = "https://docs.rs/instant-distance"
workspace = ".."
readme = "../README.md"

[dependencies]

[dev-dependencies]
bencher = "0.1.5"
rand = { version = "0.8", features = ["small_rng"] }

[[bench]]
name = "all"
harness = false
33 changes: 33 additions & 0 deletions distance-metrics/benches/all.rs
@@ -0,0 +1,33 @@
use bencher::{benchmark_group, benchmark_main, Bencher};

use distance_metrics::{EuclidMetric, Metric};
use rand::{rngs::StdRng, Rng, SeedableRng};

benchmark_main!(benches);
benchmark_group!(benches, legacy, non_simd, metric::<EuclidMetric>,);

fn legacy(bench: &mut Bencher) {
Owner

Would be curious to see some benchmarking results in this PR for the legacy approach vs your new implementations!

Author

The most recent results:

test legacy               ... bench:          21 ns/iter (+/- 0)
test non_simd             ... bench:         183 ns/iter (+/- 2)
test simple<CosineMetric> ... bench:          22 ns/iter (+/- 0)
test simple<EuclidMetric> ... bench:          25 ns/iter (+/- 0)

Owner

So I think legacy is euclid except it doesn't bother taking the square root, right? If it's still 4ns faster than Euclid, we should maybe expose it under a name other than EuclidMetric, as I think we'd want to stick to it (since the difference between the square root and the squared value doesn't usually matter for our use case?).

Did you run any tests to show that the platform-specific implementations have the same result (or close to it) as the non-SIMD implementation?
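
One way such an equivalence check could look is sketched below. It is only an illustration under stated assumptions: `euclid_unrolled` is a portable stand-in that mimics the four-accumulator summation order of a SIMD kernel, not the AVX/SSE/NEON code from this PR, and `approx_eq`, `test_vector` and the tolerance are hypothetical helpers chosen for the example.

```rust
/// Straightforward reference implementation (no manual unrolling).
fn euclid_naive(v1: &[f32], v2: &[f32]) -> f32 {
    v1.iter()
        .zip(v2)
        .map(|(a, b)| (a - b).powi(2))
        .sum::<f32>()
        .sqrt()
}

/// Portable stand-in for a SIMD kernel: four independent accumulators over
/// chunks of four, then a scalar tail. This is NOT the PR's AVX/SSE/NEON code;
/// it only reproduces the different summation order such kernels use.
fn euclid_unrolled(v1: &[f32], v2: &[f32]) -> f32 {
    assert_eq!(v1.len(), v2.len());
    let mut acc = [0.0f32; 4];
    let chunks1 = v1.chunks_exact(4);
    let chunks2 = v2.chunks_exact(4);
    // Elements that do not fill a whole chunk are handled separately below.
    let (tail1, tail2) = (chunks1.remainder(), chunks2.remainder());
    for (c1, c2) in chunks1.zip(chunks2) {
        for lane in 0..4 {
            let d = c1[lane] - c2[lane];
            acc[lane] += d * d;
        }
    }
    let mut sum: f32 = acc.iter().sum();
    for (a, b) in tail1.iter().zip(tail2) {
        sum += (a - b).powi(2);
    }
    sum.sqrt()
}

/// Relative comparison; the tolerance passed in is an assumption for illustration.
fn approx_eq(a: f32, b: f32, rel_tol: f32) -> bool {
    (a - b).abs() <= rel_tol * a.abs().max(b.abs()).max(1.0)
}

/// Deterministic pseudo-random inputs in [0, 1) without external crates
/// (hypothetical helper based on a simple linear congruential generator).
fn test_vector(len: usize, seed: u32) -> Vec<f32> {
    let mut state = seed;
    (0..len)
        .map(|_| {
            state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
            (state >> 8) as f32 / (1u32 << 24) as f32
        })
        .collect()
}
```

A test would then loop over several lengths (including ones that are not a multiple of the chunk size, and ones below the SIMD minimum) and assert `approx_eq(euclid_naive(&a, &b), euclid_unrolled(&a, &b), tol)`. Exact equality is not expected, because the two versions add the squared differences in a different order.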

Author

Yes, legacy is Euclid but without the square root. I believe we don't mind working with squared values. If I drop the square root calculation, the two are almost the same (2-5% difference):

test legacy               ... bench:          54 ns/iter (+/- 6)
test metric<CosineMetric> ... bench:          56 ns/iter (+/- 11)
test metric<EuclidMetric> ... bench:          55 ns/iter (+/- 6)
test non_simd             ... bench:         441 ns/iter (+/- 40)

(the numbers are higher because I'm on battery and the CPU is throttled)

We could use that, since it has no limitation on the number of dimensions. Another idea would be to keep the current metric (for example as Euclid300) for the 300-dimension use case.
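
The reason dropping the square root is usually harmless for nearest-neighbour search is that the square root is monotonic on non-negative values, so squared and true Euclidean distances rank candidates identically. A minimal sketch of that argument follows; `squared_euclid`, `euclid` and `nearest` are hypothetical names for this example, not functions from the PR.

```rust
/// Sum of squared differences (what the discussion calls "legacy").
fn squared_euclid(v1: &[f32], v2: &[f32]) -> f32 {
    v1.iter().zip(v2).map(|(a, b)| (a - b).powi(2)).sum::<f32>()
}

/// True Euclidean distance: the same sum followed by a square root.
fn euclid(v1: &[f32], v2: &[f32]) -> f32 {
    squared_euclid(v1, v2).sqrt()
}

/// Index of the candidate closest to `query` under `dist`, or `None` if
/// there are no candidates.
fn nearest(
    query: &[f32],
    candidates: &[Vec<f32>],
    dist: fn(&[f32], &[f32]) -> f32,
) -> Option<usize> {
    let mut best: Option<(usize, f32)> = None;
    for (i, candidate) in candidates.iter().enumerate() {
        let d = dist(query, candidate);
        // Keep the first candidate with a strictly smaller distance.
        if best.map_or(true, |(_, best_d)| d < best_d) {
            best = Some((i, d));
        }
    }
    best.map(|(i, _)| i)
}
```

Calling `nearest` with either function should return the same index. The values only differ when a caller needs the distance itself, for example to compare against a threshold expressed in the original units.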

let mut rng = StdRng::seed_from_u64(SEED);
let point_a = distance_metrics::FloatArray([rng.gen(); 300]);
let point_b = distance_metrics::FloatArray([rng.gen(); 300]);

bench.iter(|| distance_metrics::legacy_distance(&point_a, &point_b))
}

fn non_simd(bench: &mut Bencher) {
let mut rng = StdRng::seed_from_u64(SEED);
let point_a = [rng.gen(); 300];
let point_b = [rng.gen(); 300];

bench.iter(|| distance_metrics::euclid_distance(&point_a, &point_b))
}

fn metric<M: Metric>(bench: &mut Bencher) {
let mut rng = StdRng::seed_from_u64(SEED);
let point_a = [rng.gen(); 300];
let point_b = [rng.gen(); 300];

bench.iter(|| M::distance(&point_a, &point_b))
}

const SEED: u64 = 123456789;
115 changes: 115 additions & 0 deletions distance-metrics/src/lib.rs
@@ -0,0 +1,115 @@
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub mod simd_sse;

#[cfg(target_arch = "x86_64")]
pub mod simd_avx;

#[cfg(target_arch = "aarch64")]
pub mod simd_neon;

/// Defines how to compare vectors
pub trait Metric {
/// The greater the value, the more distant the vectors
fn distance(v1: &[f32], v2: &[f32]) -> f32;
}

#[cfg(target_arch = "x86_64")]
const MIN_DIM_SIZE_AVX: usize = 32;

#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
all(target_arch = "aarch64", target_feature = "neon")
))]
const MIN_DIM_SIZE_SIMD: usize = 16;
Owner

Style: in order to match the top-down ordering in other crates, let's move Metric, MIN_DIM_SIZE_AVX and MIN_DIM_SIZE_SIMD down, just above the FloatArray definition?


#[derive(Clone, Copy)]
pub struct EuclidMetric {}

impl Metric for EuclidMetric {
fn distance(v1: &[f32], v2: &[f32]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx")
&& is_x86_feature_detected!("fma")
&& v1.len() >= MIN_DIM_SIZE_AVX
{
return unsafe { simd_avx::euclid_distance_avx(v1, v2) };
}
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD {
return unsafe { simd_sse::euclid_distance_sse(v1, v2) };
}
}

#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
{
if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD {
return unsafe { simd_neon::euclid_distance_neon(v1, v2) };
}
}

euclid_distance(v1, v2)
}
}

pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 {
let s: f32 = v1
.iter()
.copied()
.zip(v2.iter().copied())
.map(|(a, b)| (a - b).powi(2))
.sum();
Owner

Style nit: I usually prefer turbofishing over type annotations, so sum::<f32>() instead of s: f32. In this case, that would also allow you to chain abs() and sqrt() directly on the expression, which is nice IMO.
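
For illustration, the two styles side by side. The function names are hypothetical; the bodies follow the `euclid_distance` function in the diff.

```rust
/// Type annotation on the binding, as written in the diff.
fn euclid_annotated(v1: &[f32], v2: &[f32]) -> f32 {
    let s: f32 = v1.iter().zip(v2).map(|(a, b)| (a - b).powi(2)).sum();
    s.abs().sqrt()
}

/// Turbofish on `sum`, which lets `abs()` and `sqrt()` chain directly.
fn euclid_turbofish(v1: &[f32], v2: &[f32]) -> f32 {
    v1.iter()
        .zip(v2)
        .map(|(a, b)| (a - b).powi(2))
        .sum::<f32>()
        .abs()
        .sqrt()
}
```

Both compute the same value; the turbofish form just avoids the intermediate binding.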

s.abs().sqrt()
}
Comment on lines +128 to +136
Owner

Should this be a NaiveEuclid type with a Metric impl instead?
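
A rough sketch of what that suggestion might look like, assuming the `Metric` trait keeps the signature shown in this diff. `NaiveEuclid` is the name proposed in the comment; `measure` is a hypothetical generic caller in the style of the `metric::<M>` benchmark.

```rust
/// Mirrors the trait from the diff: an associated function, no `self`.
pub trait Metric {
    fn distance(v1: &[f32], v2: &[f32]) -> f32;
}

/// Scalar Euclidean distance without any platform-specific code.
#[derive(Clone, Copy)]
pub struct NaiveEuclid;

impl Metric for NaiveEuclid {
    fn distance(v1: &[f32], v2: &[f32]) -> f32 {
        v1.iter()
            .zip(v2)
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt()
    }
}

/// Generic caller: the metric is selected at compile time via the type parameter.
pub fn measure<M: Metric>(a: &[f32], b: &[f32]) -> f32 {
    M::distance(a, b)
}
```

With this shape the naive version could be benchmarked through the same generic function as the SIMD-backed metric, for example `measure::<NaiveEuclid>(&a, &b)`, instead of needing its own free function.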


pub fn legacy_distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
#[cfg(target_arch = "x86_64")]
{
use std::arch::x86_64::{
_mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_fmadd_ps,
_mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
};
debug_assert_eq!(lhs.0.len() % 8, 4);

unsafe {
let mut acc_8x = _mm256_setzero_ps();
for (lh_slice, rh_slice) in lhs.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
let lh_8x = _mm256_load_ps(lh_slice.as_ptr());
let rh_8x = _mm256_load_ps(rh_slice.as_ptr());
let diff = _mm256_sub_ps(lh_8x, rh_8x);
acc_8x = _mm256_fmadd_ps(diff, diff, acc_8x);
}

let mut acc_4x = _mm256_extractf128_ps(acc_8x, 1); // upper half
let right = _mm256_castps256_ps128(acc_8x); // lower half
acc_4x = _mm_add_ps(acc_4x, right); // sum halves

let lh_4x = _mm_load_ps(lhs.0[DIMENSIONS - 4..].as_ptr());
let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr());
let diff = _mm_sub_ps(lh_4x, rh_4x);
acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);

let lower = _mm_movehl_ps(acc_4x, acc_4x);
acc_4x = _mm_add_ps(acc_4x, lower);
let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1);
acc_4x = _mm_add_ss(acc_4x, upper);
_mm_cvtss_f32(acc_4x)
}
}
#[cfg(not(target_arch = "x86_64"))]
lhs.0
.iter()
.zip(rhs.0.iter())
.map(|(&a, &b)| (a - b).powi(2))
.sum::<f32>()
}

#[repr(align(32))]
pub struct FloatArray(pub [f32; DIMENSIONS]);

const DIMENSIONS: usize = 300;
54 changes: 54 additions & 0 deletions distance-metrics/src/simd_avx.rs
@@ -0,0 +1,54 @@
use std::arch::x86_64::*;

#[target_feature(enable = "avx")]
#[target_feature(enable = "fma")]
unsafe fn hsum256_ps_avx(x: __m256) -> f32 {
let x128: __m128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
let x64: __m128 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
_mm_cvtss_f32(x32)
}

#[target_feature(enable = "avx")]
#[target_feature(enable = "fma")]
pub(crate) unsafe fn euclid_distance_avx(v1: &[f32], v2: &[f32]) -> f32 {
let n = v1.len();
let m = n - (n % 32);
let mut ptr1: *const f32 = v1.as_ptr();
let mut ptr2: *const f32 = v2.as_ptr();
let mut sum256_1: __m256 = _mm256_setzero_ps();
let mut sum256_2: __m256 = _mm256_setzero_ps();
let mut sum256_3: __m256 = _mm256_setzero_ps();
let mut sum256_4: __m256 = _mm256_setzero_ps();
let mut i: usize = 0;
while i < m {
let sub256_1: __m256 =
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(0)), _mm256_loadu_ps(ptr2.add(0)));
sum256_1 = _mm256_fmadd_ps(sub256_1, sub256_1, sum256_1);

let sub256_2: __m256 =
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(8)), _mm256_loadu_ps(ptr2.add(8)));
sum256_2 = _mm256_fmadd_ps(sub256_2, sub256_2, sum256_2);

let sub256_3: __m256 =
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(16)), _mm256_loadu_ps(ptr2.add(16)));
sum256_3 = _mm256_fmadd_ps(sub256_3, sub256_3, sum256_3);

let sub256_4: __m256 =
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(24)), _mm256_loadu_ps(ptr2.add(24)));
sum256_4 = _mm256_fmadd_ps(sub256_4, sub256_4, sum256_4);

ptr1 = ptr1.add(32);
ptr2 = ptr2.add(32);
i += 32;
}

let mut result = hsum256_ps_avx(sum256_1)
+ hsum256_ps_avx(sum256_2)
+ hsum256_ps_avx(sum256_3)
+ hsum256_ps_avx(sum256_4);
for i in 0..n - m {
result += (*ptr1.add(i) - *ptr2.add(i)).powi(2);
}
result.abs().sqrt()
}
38 changes: 38 additions & 0 deletions distance-metrics/src/simd_neon.rs
@@ -0,0 +1,38 @@
#[cfg(target_feature = "neon")]
use std::arch::aarch64::*;

#[cfg(target_feature = "neon")]
pub(crate) unsafe fn euclid_distance_neon(v1: &[f32], v2: &[f32]) -> f32 {
let n = v1.len();
let m = n - (n % 16);
let mut ptr1: *const f32 = v1.as_ptr();
let mut ptr2: *const f32 = v2.as_ptr();
let mut sum1 = vdupq_n_f32(0.);
let mut sum2 = vdupq_n_f32(0.);
let mut sum3 = vdupq_n_f32(0.);
let mut sum4 = vdupq_n_f32(0.);

let mut i: usize = 0;
while i < m {
let sub1 = vsubq_f32(vld1q_f32(ptr1), vld1q_f32(ptr2));
sum1 = vfmaq_f32(sum1, sub1, sub1);

let sub2 = vsubq_f32(vld1q_f32(ptr1.add(4)), vld1q_f32(ptr2.add(4)));
sum2 = vfmaq_f32(sum2, sub2, sub2);

let sub3 = vsubq_f32(vld1q_f32(ptr1.add(8)), vld1q_f32(ptr2.add(8)));
sum3 = vfmaq_f32(sum3, sub3, sub3);

let sub4 = vsubq_f32(vld1q_f32(ptr1.add(12)), vld1q_f32(ptr2.add(12)));
sum4 = vfmaq_f32(sum4, sub4, sub4);

ptr1 = ptr1.add(16);
ptr2 = ptr2.add(16);
i += 16;
}
let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4);
for i in 0..n - m {
result += (*ptr1.add(i) - *ptr2.add(i)).powi(2);
}
result.abs().sqrt()
}
50 changes: 50 additions & 0 deletions distance-metrics/src/simd_sse.rs
@@ -0,0 +1,50 @@
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

#[target_feature(enable = "sse")]
unsafe fn hsum128_ps_sse(x: __m128) -> f32 {
let x64: __m128 = _mm_add_ps(x, _mm_movehl_ps(x, x));
let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
_mm_cvtss_f32(x32)
}

#[target_feature(enable = "sse")]
pub(crate) unsafe fn euclid_distance_sse(v1: &[f32], v2: &[f32]) -> f32 {
let n = v1.len();
let m = n - (n % 16);
let mut ptr1: *const f32 = v1.as_ptr();
let mut ptr2: *const f32 = v2.as_ptr();
let mut sum128_1: __m128 = _mm_setzero_ps();
let mut sum128_2: __m128 = _mm_setzero_ps();
let mut sum128_3: __m128 = _mm_setzero_ps();
let mut sum128_4: __m128 = _mm_setzero_ps();
let mut i: usize = 0;
while i < m {
let sub128_1 = _mm_sub_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2));
sum128_1 = _mm_add_ps(_mm_mul_ps(sub128_1, sub128_1), sum128_1);

let sub128_2 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4)));
sum128_2 = _mm_add_ps(_mm_mul_ps(sub128_2, sub128_2), sum128_2);

let sub128_3 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8)));
sum128_3 = _mm_add_ps(_mm_mul_ps(sub128_3, sub128_3), sum128_3);

let sub128_4 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12)));
sum128_4 = _mm_add_ps(_mm_mul_ps(sub128_4, sub128_4), sum128_4);

ptr1 = ptr1.add(16);
ptr2 = ptr2.add(16);
i += 16;
}

let mut result = hsum128_ps_sse(sum128_1)
+ hsum128_ps_sse(sum128_2)
+ hsum128_ps_sse(sum128_3)
+ hsum128_ps_sse(sum128_4);
for i in 0..n - m {
result += (*ptr1.add(i) - *ptr2.add(i)).powi(2);
}
result.abs().sqrt()
}
2 changes: 1 addition & 1 deletion instant-distance-py/Cargo.toml
@@ -16,7 +16,7 @@ crate-type = ["cdylib"]

[dependencies]
bincode = "1.3.1"
distance-metrics = { version = "0.6", path = "../distance-metrics" }
instant-distance = { version = "0.6", path = "../instant-distance", features = ["with-serde"] }
pyo3 = { version = "0.18.0", features = ["extension-module"] }
serde = { version = "1", features = ["derive"] }
serde-big-array = "0.4.1"