1 change: 1 addition & 0 deletions Makefile
@@ -8,6 +8,7 @@ test-python: instant-distance-py/test/instant_distance.so

bench-python: instant-distance-py/test/instant_distance.so
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config()' 'instant_distance.Hnsw.build(points, config)'
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config(); config.distance_metric = instant_distance.DistanceMetric.Cosine' 'instant_distance.Hnsw.build(points, config)'

clean:
cargo clean
18 changes: 14 additions & 4 deletions distance-metrics/benches/all.rs
@@ -1,10 +1,18 @@
use bencher::{benchmark_group, benchmark_main, Bencher};

use distance_metrics::{EuclidMetric, Metric};
use distance_metrics::{
Metric, {CosineMetric, EuclidMetric},
};
use rand::{rngs::StdRng, Rng, SeedableRng};

benchmark_main!(benches);
benchmark_group!(benches, legacy, non_simd, metric::<EuclidMetric>,);
benchmark_group!(
benches,
legacy,
non_simd,
metric::<EuclidMetric>,
metric::<CosineMetric>
);

fn legacy(bench: &mut Bencher) {
Owner:
Would be curious to see some benchmarking results in this PR for the legacy approach vs your new implementations!

Author:

The most recent results:

test legacy               ... bench:          21 ns/iter (+/- 0)
test non_simd             ... bench:         183 ns/iter (+/- 2)
test simple<CosineMetric> ... bench:          22 ns/iter (+/- 0)
test simple<EuclidMetric> ... bench:          25 ns/iter (+/- 0)
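
For context, these numbers come from the bencher harness in distance-metrics/benches/all.rs; assuming the crate layout in this PR, they should be reproducible with something like:

    cd distance-metrics && cargo bench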

Owner:

So I think legacy is euclid except it doesn't bother taking the square root, right? If it's still 4ns faster than Euclid, we should maybe expose it under a name other than EuclidMetric, as I think we'd want to stick to it (since the difference between the square root and the squared value doesn't usually matter for our use case?).

Did you run any tests to show that the platform-specific implementations have the same result (or close to it) as the non-SIMD implementation?
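
A minimal sketch of such an equivalence check, assuming the functions added in this PR (euclid_distance / euclid_distance_avx) and using a relative tolerance instead of exact equality, since FMA and the different summation order can change the low bits:

    #[test]
    fn avx_matches_scalar_within_tolerance() {
        // Hypothetical test; the 1e-5 tolerance is an assumption, not part of this PR.
        let v1: Vec<f32> = (0..300).map(|i| i as f32 * 0.01).collect();
        let v2: Vec<f32> = (0..300).map(|i| (300 - i) as f32 * 0.01).collect();
        let scalar = euclid_distance(&v1, &v2);
        if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
            let simd = unsafe { euclid_distance_avx(&v1, &v2) };
            // Relative comparison, falling back to absolute for values near zero.
            assert!((scalar - simd).abs() <= 1e-5 * scalar.abs().max(1.0));
        }
    }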

Author:

Yes, legacy is Euclid but without the square root. I believe we don't mind working with squared values. If I drop the square root calculation from the new metrics, they are almost the same (2-5% difference):

test legacy               ... bench:          54 ns/iter (+/- 6)
test metric<CosineMetric> ... bench:          56 ns/iter (+/- 11)
test metric<EuclidMetric> ... bench:          55 ns/iter (+/- 6)
test non_simd             ... bench:         441 ns/iter (+/- 40)

(the numbers are higher because I'm on battery and the CPU is throttled)

We could use that since it doesn't have a dimension limitation. Another idea would be to keep the current metric (named, for example, Euclid300) for the 300-dimension use case.
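
A minimal sketch of what exposing the sqrt-free variant could look like, reusing the Metric trait from this PR (the SquaredEuclidMetric name is hypothetical):

    #[derive(Clone, Copy)]
    pub struct SquaredEuclidMetric {}

    impl Metric for SquaredEuclidMetric {
        fn distance(v1: &[f32], v2: &[f32]) -> f32 {
            // Same ordering behaviour as Euclid, but without the final sqrt.
            v1.iter().zip(v2).map(|(a, b)| (a - b).powi(2)).sum()
        }

        fn preprocess(_vector: &mut [f32]) {
            // no-op, like EuclidMetric
        }
    }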

let mut rng = StdRng::seed_from_u64(SEED);
@@ -24,8 +32,10 @@ fn non_simd(bench: &mut Bencher) {

fn metric<M: Metric>(bench: &mut Bencher) {
let mut rng = StdRng::seed_from_u64(SEED);
let point_a = [rng.gen(); 300];
let point_b = [rng.gen(); 300];
let mut point_a = [rng.gen(); 300];
let mut point_b = [rng.gen(); 300];
M::preprocess(&mut point_a);
M::preprocess(&mut point_b);

bench.iter(|| M::distance(&point_a, &point_b))
}
84 changes: 84 additions & 0 deletions distance-metrics/src/lib.rs
@@ -11,6 +11,9 @@ pub mod simd_neon;
pub trait Metric {
/// The greater the value, the more distant the vectors
fn distance(v1: &[f32], v2: &[f32]) -> f32;

/// Vector transformations (such as normalization) applied before a vector is added to the collection
fn preprocess(vector: &mut [f32]);
}

#[cfg(target_arch = "x86_64")]
@@ -54,6 +57,72 @@ impl Metric for EuclidMetric {

euclid_distance(v1, v2)
}

fn preprocess(_vector: &mut [f32]) {
// no-op
}
}

#[derive(Clone, Copy)]
pub struct CosineMetric {}

impl Metric for CosineMetric {
fn distance(v1: &[f32], v2: &[f32]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx")
&& is_x86_feature_detected!("fma")
&& v1.len() >= MIN_DIM_SIZE_AVX
{
return 1.0 - unsafe { simd_avx::dot_similarity_avx(v1, v2) };
}
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD {
return 1.0 - unsafe { simd_sse::dot_similarity_sse(v1, v2) };
}
}

#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
{
if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD {
return 1.0 - unsafe { simd_neon::dot_similarity_neon(v1, v2) };
}
}

1.0 - dot_similarity(v1, v2)
}

fn preprocess(vector: &mut [f32]) {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx")
&& is_x86_feature_detected!("fma")
&& vector.len() >= MIN_DIM_SIZE_AVX
{
return unsafe { simd_avx::cosine_preprocess_avx(vector) };
}
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("sse") && vector.len() >= MIN_DIM_SIZE_SIMD {
return unsafe { simd_sse::cosine_preprocess_sse(vector) };
}
}

#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
{
if std::arch::is_aarch64_feature_detected!("neon") && vector.len() >= MIN_DIM_SIZE_SIMD
{
return unsafe { simd_neon::cosine_preprocess_neon(vector) };
}
}

cosine_preprocess(vector);
}
}

pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 {
@@ -66,6 +135,21 @@ pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 {
s.abs().sqrt()
}

pub fn cosine_preprocess(vector: &mut [f32]) {
let mut length: f32 = vector.iter().map(|x| x * x).sum();
if length < f32::EPSILON {
return;
}
length = length.sqrt();
for x in vector.iter_mut() {
*x /= length;
}
}

pub fn dot_similarity(v1: &[f32], v2: &[f32]) -> f32 {
v1.iter().zip(v2).map(|(a, b)| a * b).sum()
}

pub fn legacy_distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
#[cfg(target_arch = "x86_64")]
{
132 changes: 132 additions & 0 deletions distance-metrics/src/simd_avx.rs
@@ -52,3 +52,135 @@ pub(crate) unsafe fn euclid_distance_avx(v1: &[f32], v2: &[f32]) -> f32 {
}
result.abs().sqrt()
}

#[target_feature(enable = "avx")]
#[target_feature(enable = "fma")]
pub(crate) unsafe fn cosine_preprocess_avx(vector: &mut [f32]) {
let n = vector.len();
let m = n - (n % 32);
let mut ptr: *const f32 = vector.as_ptr();
let mut sum256_1: __m256 = _mm256_setzero_ps();
let mut sum256_2: __m256 = _mm256_setzero_ps();
let mut sum256_3: __m256 = _mm256_setzero_ps();
let mut sum256_4: __m256 = _mm256_setzero_ps();
let mut i: usize = 0;
while i < m {
let m256_1 = _mm256_loadu_ps(ptr);
sum256_1 = _mm256_fmadd_ps(m256_1, m256_1, sum256_1);

let m256_2 = _mm256_loadu_ps(ptr.add(8));
sum256_2 = _mm256_fmadd_ps(m256_2, m256_2, sum256_2);

let m256_3 = _mm256_loadu_ps(ptr.add(16));
sum256_3 = _mm256_fmadd_ps(m256_3, m256_3, sum256_3);

let m256_4 = _mm256_loadu_ps(ptr.add(24));
sum256_4 = _mm256_fmadd_ps(m256_4, m256_4, sum256_4);

ptr = ptr.add(32);
i += 32;
}

let mut length = hsum256_ps_avx(sum256_1)
+ hsum256_ps_avx(sum256_2)
+ hsum256_ps_avx(sum256_3)
+ hsum256_ps_avx(sum256_4);
for i in 0..n - m {
length += (*ptr.add(i)).powi(2);
}
if length < f32::EPSILON {
return;
}
length = length.sqrt();
for x in vector.iter_mut() {
*x /= length;
}
}

#[target_feature(enable = "avx")]
#[target_feature(enable = "fma")]
pub(crate) unsafe fn dot_similarity_avx(v1: &[f32], v2: &[f32]) -> f32 {
let n = v1.len();
let m = n - (n % 32);
let mut ptr1: *const f32 = v1.as_ptr();
let mut ptr2: *const f32 = v2.as_ptr();
let mut sum256_1: __m256 = _mm256_setzero_ps();
let mut sum256_2: __m256 = _mm256_setzero_ps();
let mut sum256_3: __m256 = _mm256_setzero_ps();
let mut sum256_4: __m256 = _mm256_setzero_ps();
let mut i: usize = 0;
while i < m {
sum256_1 = _mm256_fmadd_ps(_mm256_loadu_ps(ptr1), _mm256_loadu_ps(ptr2), sum256_1);
sum256_2 = _mm256_fmadd_ps(
_mm256_loadu_ps(ptr1.add(8)),
_mm256_loadu_ps(ptr2.add(8)),
sum256_2,
);
sum256_3 = _mm256_fmadd_ps(
_mm256_loadu_ps(ptr1.add(16)),
_mm256_loadu_ps(ptr2.add(16)),
sum256_3,
);
sum256_4 = _mm256_fmadd_ps(
_mm256_loadu_ps(ptr1.add(24)),
_mm256_loadu_ps(ptr2.add(24)),
sum256_4,
);

ptr1 = ptr1.add(32);
ptr2 = ptr2.add(32);
i += 32;
}

let mut result = hsum256_ps_avx(sum256_1)
+ hsum256_ps_avx(sum256_2)
+ hsum256_ps_avx(sum256_3)
+ hsum256_ps_avx(sum256_4);

for i in 0..n - m {
result += (*ptr1.add(i)) * (*ptr2.add(i));
}
result
}

#[cfg(test)]
mod tests {
#[test]
fn test_spaces_avx() {
use super::*;
use crate::*;

if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
let v1: Vec<f32> = vec![
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
26., 27., 28., 29., 30., 31.,
];
let v2: Vec<f32> = vec![
40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55.,
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
56., 57., 58., 59., 60., 61.,
];

let euclid_simd = unsafe { euclid_distance_avx(&v1, &v2) };
let euclid = euclid_distance(&v1, &v2);
assert_eq!(euclid_simd, euclid);

let dot_simd = unsafe { dot_similarity_avx(&v1, &v2) };
let dot = dot_similarity(&v1, &v2);
assert_eq!(dot_simd, dot);

let mut v1 = v1;
let mut v1_copy = v1.clone();
unsafe { cosine_preprocess_avx(&mut v1) };
cosine_preprocess(&mut v1_copy);
assert_eq!(v1, v1_copy);
} else {
println!("avx test skipped");
Comment on lines +146 to +183
Owner:

So it would be nice to split the addition of test code across the commits that add the code being tested (or alternatively, add the tests separately at the end).

}
}
}
105 changes: 105 additions & 0 deletions distance-metrics/src/simd_neon.rs
@@ -36,3 +36,108 @@ pub(crate) unsafe fn euclid_distance_neon(v1: &[f32], v2: &[f32]) -> f32 {
}
result.abs().sqrt()
}

#[cfg(target_feature = "neon")]
pub(crate) unsafe fn cosine_preprocess_neon(vector: &mut [f32]) {
let n = vector.len();
let m = n - (n % 16);
let mut ptr: *const f32 = vector.as_ptr();
let mut sum1 = vdupq_n_f32(0.);
let mut sum2 = vdupq_n_f32(0.);
let mut sum3 = vdupq_n_f32(0.);
let mut sum4 = vdupq_n_f32(0.);

let mut i: usize = 0;
while i < m {
let d1 = vld1q_f32(ptr);
sum1 = vfmaq_f32(sum1, d1, d1);

let d2 = vld1q_f32(ptr.add(4));
sum2 = vfmaq_f32(sum2, d2, d2);

let d3 = vld1q_f32(ptr.add(8));
sum3 = vfmaq_f32(sum3, d3, d3);

let d4 = vld1q_f32(ptr.add(12));
sum4 = vfmaq_f32(sum4, d4, d4);

ptr = ptr.add(16);
i += 16;
}
let mut length = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4);
for v in vector.iter().take(n).skip(m) {
length += v.powi(2);
}
if length < f32::EPSILON {
return;
}
let length = length.sqrt();
for x in vector.iter_mut() {
*x /= length;
}
}

#[cfg(target_feature = "neon")]
pub(crate) unsafe fn dot_similarity_neon(v1: &[f32], v2: &[f32]) -> f32 {
let n = v1.len();
let m = n - (n % 16);
let mut ptr1: *const f32 = v1.as_ptr();
let mut ptr2: *const f32 = v2.as_ptr();
let mut sum1 = vdupq_n_f32(0.);
let mut sum2 = vdupq_n_f32(0.);
let mut sum3 = vdupq_n_f32(0.);
let mut sum4 = vdupq_n_f32(0.);

let mut i: usize = 0;
while i < m {
sum1 = vfmaq_f32(sum1, vld1q_f32(ptr1), vld1q_f32(ptr2));
sum2 = vfmaq_f32(sum2, vld1q_f32(ptr1.add(4)), vld1q_f32(ptr2.add(4)));
sum3 = vfmaq_f32(sum3, vld1q_f32(ptr1.add(8)), vld1q_f32(ptr2.add(8)));
sum4 = vfmaq_f32(sum4, vld1q_f32(ptr1.add(12)), vld1q_f32(ptr2.add(12)));
ptr1 = ptr1.add(16);
ptr2 = ptr2.add(16);
i += 16;
}
let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4);
for i in 0..n - m {
result += (*ptr1.add(i)) * (*ptr2.add(i));
}
result
}

#[cfg(test)]
mod tests {
#[cfg(target_feature = "neon")]
#[test]
fn test_spaces_neon() {
use super::*;
use crate::*;

if std::arch::is_aarch64_feature_detected!("neon") {
let v1: Vec<f32> = vec![
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
26., 27., 28., 29., 30., 31.,
];
let v2: Vec<f32> = vec![
40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55.,
56., 57., 58., 59., 60., 61.,
];

let euclid_simd = unsafe { euclid_distance_neon(&v1, &v2) };
let euclid = euclid_distance(&v1, &v2);
assert_eq!(euclid_simd, euclid);

let dot_simd = unsafe { dot_similarity_neon(&v1, &v2) };
let dot = dot_similarity(&v1, &v2);
assert_eq!(dot_simd, dot);

let mut v1 = v1;
let mut v1_copy = v1.clone();
unsafe { cosine_preprocess_neon(&mut v1) };
cosine_preprocess(&mut v1_copy);
assert_eq!(v1, v1_copy);
} else {
println!("neon test skipped");
}
}
}