diff --git a/Cargo.toml b/Cargo.toml index 0f7eac6..38984ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["instant-distance", "instant-distance-py"] +members = ["distance-metrics", "instant-distance", "instant-distance-py"] [profile.bench] debug = true diff --git a/Makefile b/Makefile index ef56aec..1b40721 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,15 @@ -test-python: - cargo build --release - cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so +instant-distance-py/test/instant_distance.so: instant-distance-py/src/lib.rs + RUSTFLAGS="-C target-cpu=native" cargo build --release + ([ -f target/release/libinstant_distance.dylib ] && cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so) || \ + ([ -f target/release/libinstant_distance.so ] && cp target/release/libinstant_distance.so instant-distance-py/test/instant_distance.so) + +test-python: instant-distance-py/test/instant_distance.so PYTHONPATH=instant-distance-py/test/ python3 -m test +bench-python: instant-distance-py/test/instant_distance.so + PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config()' 'instant_distance.Hnsw.build(points, config)' + PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config(); config.distance_metric = instant_distance.DistanceMetric.Cosine' 'instant_distance.Hnsw.build(points, config)' + clean: cargo clean rm -f instant-distance-py/test/instant_distance.so diff --git a/distance-metrics/Cargo.toml b/distance-metrics/Cargo.toml new file mode 100644 index 0000000..7d0b580 --- /dev/null +++ b/distance-metrics/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "distance-metrics" +version = "0.6.0" +license = "MIT OR Apache-2.0" +edition = "2021" +rust-version = "1.58" +homepage = "https://github.com/InstantDomain/instant-distance" +repository = "https://github.com/InstantDomain/instant-distance" +documentation = "https://docs.rs/instant-distance" +workspace = ".." +readme = "../README.md" + +[dependencies] + +[dev-dependencies] +bencher = "0.1.5" +rand = { version = "0.8", features = ["small_rng"] } + +[[bench]] +name = "all" +harness = false diff --git a/distance-metrics/benches/all.rs b/distance-metrics/benches/all.rs new file mode 100644 index 0000000..b3405f7 --- /dev/null +++ b/distance-metrics/benches/all.rs @@ -0,0 +1,43 @@ +use bencher::{benchmark_group, benchmark_main, Bencher}; + +use distance_metrics::{ + Metric, {CosineMetric, EuclidMetric}, +}; +use rand::{rngs::StdRng, Rng, SeedableRng}; + +benchmark_main!(benches); +benchmark_group!( + benches, + legacy, + non_simd, + metric::, + metric:: +); + +fn legacy(bench: &mut Bencher) { + let mut rng = StdRng::seed_from_u64(SEED); + let point_a = distance_metrics::FloatArray([rng.gen(); 300]); + let point_b = distance_metrics::FloatArray([rng.gen(); 300]); + + bench.iter(|| distance_metrics::legacy_distance(&point_a, &point_b)) +} + +fn non_simd(bench: &mut Bencher) { + let mut rng = StdRng::seed_from_u64(SEED); + let point_a = [rng.gen(); 300]; + let point_b = [rng.gen(); 300]; + + bench.iter(|| distance_metrics::euclid_distance(&point_a, &point_b)) +} + +fn metric(bench: &mut Bencher) { + let mut rng = StdRng::seed_from_u64(SEED); + let mut point_a = [rng.gen(); 300]; + let mut point_b = [rng.gen(); 300]; + M::preprocess(&mut point_a); + M::preprocess(&mut point_b); + + bench.iter(|| M::distance(&point_a, &point_b)) +} + +const SEED: u64 = 123456789; diff --git a/distance-metrics/src/lib.rs b/distance-metrics/src/lib.rs new file mode 100644 index 0000000..e197918 --- /dev/null +++ b/distance-metrics/src/lib.rs @@ -0,0 +1,199 @@ +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +pub mod simd_sse; + +#[cfg(target_arch = "x86_64")] +pub mod simd_avx; + +#[cfg(target_arch = "aarch64")] +pub mod simd_neon; + +/// Defines how to compare vectors +pub trait Metric { + /// Greater the value - more distant the vectors + fn distance(v1: &[f32], v2: &[f32]) -> f32; + + /// Necessary vector transformations performed before adding it to the collection (like normalization) + fn preprocess(vector: &mut [f32]); +} + +#[cfg(target_arch = "x86_64")] +const MIN_DIM_SIZE_AVX: usize = 32; + +#[cfg(any( + target_arch = "x86", + target_arch = "x86_64", + all(target_arch = "aarch64", target_feature = "neon") +))] +const MIN_DIM_SIZE_SIMD: usize = 16; + +#[derive(Clone, Copy)] +pub struct EuclidMetric {} + +impl Metric for EuclidMetric { + fn distance(v1: &[f32], v2: &[f32]) -> f32 { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx") + && is_x86_feature_detected!("fma") + && v1.len() >= MIN_DIM_SIZE_AVX + { + return unsafe { simd_avx::euclid_distance_avx(v1, v2) }; + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD { + return unsafe { simd_sse::euclid_distance_sse(v1, v2) }; + } + } + + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + { + if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD { + return unsafe { simple_neon::euclid_distance_neon(v1, v2) }; + } + } + + euclid_distance(v1, v2) + } + + fn preprocess(_vector: &mut [f32]) { + // no-op + } +} + +#[derive(Clone, Copy)] +pub struct CosineMetric {} + +impl Metric for CosineMetric { + fn distance(v1: &[f32], v2: &[f32]) -> f32 { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx") + && is_x86_feature_detected!("fma") + && v1.len() >= MIN_DIM_SIZE_AVX + { + return 1.0 - unsafe { simd_avx::dot_similarity_avx(v1, v2) }; + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD { + return 1.0 - unsafe { simd_sse::dot_similarity_sse(v1, v2) }; + } + } + + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + { + if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD { + return 1.0 - unsafe { simd_neon::dot_similarity_neon(v1, v2) }; + } + } + + 1.0 - dot_similarity(v1, v2) + } + + fn preprocess(vector: &mut [f32]) { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx") + && is_x86_feature_detected!("fma") + && vector.len() >= MIN_DIM_SIZE_AVX + { + return unsafe { simd_avx::cosine_preprocess_avx(vector) }; + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("sse") && vector.len() >= MIN_DIM_SIZE_SIMD { + return unsafe { simd_sse::cosine_preprocess_sse(vector) }; + } + } + + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + { + if std::arch::is_aarch64_feature_detected!("neon") && vector.len() >= MIN_DIM_SIZE_SIMD + { + return unsafe { simd_neon::cosine_preprocess_neon(vector) }; + } + } + + cosine_preprocess(vector); + } +} + +pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 { + let s: f32 = v1 + .iter() + .copied() + .zip(v2.iter().copied()) + .map(|(a, b)| (a - b).powi(2)) + .sum(); + s.abs().sqrt() +} + +pub fn cosine_preprocess(vector: &mut [f32]) { + let mut length: f32 = vector.iter().map(|x| x * x).sum(); + if length < f32::EPSILON { + return; + } + length = length.sqrt(); + for x in vector.iter_mut() { + *x /= length; + } +} + +pub fn dot_similarity(v1: &[f32], v2: &[f32]) -> f32 { + v1.iter().zip(v2).map(|(a, b)| a * b).sum() +} + +pub fn legacy_distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 { + #[cfg(target_arch = "x86_64")] + { + use std::arch::x86_64::{ + _mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps, + _mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_fmadd_ps, + _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps, + }; + debug_assert_eq!(lhs.0.len() % 8, 4); + + unsafe { + let mut acc_8x = _mm256_setzero_ps(); + for (lh_slice, rh_slice) in lhs.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) { + let lh_8x = _mm256_load_ps(lh_slice.as_ptr()); + let rh_8x = _mm256_load_ps(rh_slice.as_ptr()); + let diff = _mm256_sub_ps(lh_8x, rh_8x); + acc_8x = _mm256_fmadd_ps(diff, diff, acc_8x); + } + + let mut acc_4x = _mm256_extractf128_ps(acc_8x, 1); // upper half + let right = _mm256_castps256_ps128(acc_8x); // lower half + acc_4x = _mm_add_ps(acc_4x, right); // sum halves + + let lh_4x = _mm_load_ps(lhs.0[DIMENSIONS - 4..].as_ptr()); + let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr()); + let diff = _mm_sub_ps(lh_4x, rh_4x); + acc_4x = _mm_fmadd_ps(diff, diff, acc_4x); + + let lower = _mm_movehl_ps(acc_4x, acc_4x); + acc_4x = _mm_add_ps(acc_4x, lower); + let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1); + acc_4x = _mm_add_ss(acc_4x, upper); + _mm_cvtss_f32(acc_4x) + } + } + #[cfg(not(target_arch = "x86_64"))] + lhs.0 + .iter() + .zip(rhs.0.iter()) + .map(|(&a, &b)| (a - b).powi(2)) + .sum::() +} + +#[repr(align(32))] +pub struct FloatArray(pub [f32; DIMENSIONS]); + +const DIMENSIONS: usize = 300; diff --git a/distance-metrics/src/simd_avx.rs b/distance-metrics/src/simd_avx.rs new file mode 100644 index 0000000..ea8a566 --- /dev/null +++ b/distance-metrics/src/simd_avx.rs @@ -0,0 +1,186 @@ +use std::arch::x86_64::*; + +#[target_feature(enable = "avx")] +#[target_feature(enable = "fma")] +unsafe fn hsum256_ps_avx(x: __m256) -> f32 { + let x128: __m128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); + let x64: __m128 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); + let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + _mm_cvtss_f32(x32) +} + +#[target_feature(enable = "avx")] +#[target_feature(enable = "fma")] +pub(crate) unsafe fn euclid_distance_avx(v1: &[f32], v2: &[f32]) -> f32 { + let n = v1.len(); + let m = n - (n % 32); + let mut ptr1: *const f32 = v1.as_ptr(); + let mut ptr2: *const f32 = v2.as_ptr(); + let mut sum256_1: __m256 = _mm256_setzero_ps(); + let mut sum256_2: __m256 = _mm256_setzero_ps(); + let mut sum256_3: __m256 = _mm256_setzero_ps(); + let mut sum256_4: __m256 = _mm256_setzero_ps(); + let mut i: usize = 0; + while i < m { + let sub256_1: __m256 = + _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(0)), _mm256_loadu_ps(ptr2.add(0))); + sum256_1 = _mm256_fmadd_ps(sub256_1, sub256_1, sum256_1); + + let sub256_2: __m256 = + _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(8)), _mm256_loadu_ps(ptr2.add(8))); + sum256_2 = _mm256_fmadd_ps(sub256_2, sub256_2, sum256_2); + + let sub256_3: __m256 = + _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(16)), _mm256_loadu_ps(ptr2.add(16))); + sum256_3 = _mm256_fmadd_ps(sub256_3, sub256_3, sum256_3); + + let sub256_4: __m256 = + _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(24)), _mm256_loadu_ps(ptr2.add(24))); + sum256_4 = _mm256_fmadd_ps(sub256_4, sub256_4, sum256_4); + + ptr1 = ptr1.add(32); + ptr2 = ptr2.add(32); + i += 32; + } + + let mut result = hsum256_ps_avx(sum256_1) + + hsum256_ps_avx(sum256_2) + + hsum256_ps_avx(sum256_3) + + hsum256_ps_avx(sum256_4); + for i in 0..n - m { + result += (*ptr1.add(i) - *ptr2.add(i)).powi(2); + } + result.abs().sqrt() +} + +#[target_feature(enable = "avx")] +#[target_feature(enable = "fma")] +pub(crate) unsafe fn cosine_preprocess_avx(vector: &mut [f32]) { + let n = vector.len(); + let m = n - (n % 32); + let mut ptr: *const f32 = vector.as_ptr(); + let mut sum256_1: __m256 = _mm256_setzero_ps(); + let mut sum256_2: __m256 = _mm256_setzero_ps(); + let mut sum256_3: __m256 = _mm256_setzero_ps(); + let mut sum256_4: __m256 = _mm256_setzero_ps(); + let mut i: usize = 0; + while i < m { + let m256_1 = _mm256_loadu_ps(ptr); + sum256_1 = _mm256_fmadd_ps(m256_1, m256_1, sum256_1); + + let m256_2 = _mm256_loadu_ps(ptr.add(8)); + sum256_2 = _mm256_fmadd_ps(m256_2, m256_2, sum256_2); + + let m256_3 = _mm256_loadu_ps(ptr.add(16)); + sum256_3 = _mm256_fmadd_ps(m256_3, m256_3, sum256_3); + + let m256_4 = _mm256_loadu_ps(ptr.add(24)); + sum256_4 = _mm256_fmadd_ps(m256_4, m256_4, sum256_4); + + ptr = ptr.add(32); + i += 32; + } + + let mut length = hsum256_ps_avx(sum256_1) + + hsum256_ps_avx(sum256_2) + + hsum256_ps_avx(sum256_3) + + hsum256_ps_avx(sum256_4); + for i in 0..n - m { + length += (*ptr.add(i)).powi(2); + } + if length < f32::EPSILON { + return; + } + length = length.sqrt(); + for x in vector.iter_mut() { + *x /= length; + } +} + +#[target_feature(enable = "avx")] +#[target_feature(enable = "fma")] +pub(crate) unsafe fn dot_similarity_avx(v1: &[f32], v2: &[f32]) -> f32 { + let n = v1.len(); + let m = n - (n % 32); + let mut ptr1: *const f32 = v1.as_ptr(); + let mut ptr2: *const f32 = v2.as_ptr(); + let mut sum256_1: __m256 = _mm256_setzero_ps(); + let mut sum256_2: __m256 = _mm256_setzero_ps(); + let mut sum256_3: __m256 = _mm256_setzero_ps(); + let mut sum256_4: __m256 = _mm256_setzero_ps(); + let mut i: usize = 0; + while i < m { + sum256_1 = _mm256_fmadd_ps(_mm256_loadu_ps(ptr1), _mm256_loadu_ps(ptr2), sum256_1); + sum256_2 = _mm256_fmadd_ps( + _mm256_loadu_ps(ptr1.add(8)), + _mm256_loadu_ps(ptr2.add(8)), + sum256_2, + ); + sum256_3 = _mm256_fmadd_ps( + _mm256_loadu_ps(ptr1.add(16)), + _mm256_loadu_ps(ptr2.add(16)), + sum256_3, + ); + sum256_4 = _mm256_fmadd_ps( + _mm256_loadu_ps(ptr1.add(24)), + _mm256_loadu_ps(ptr2.add(24)), + sum256_4, + ); + + ptr1 = ptr1.add(32); + ptr2 = ptr2.add(32); + i += 32; + } + + let mut result = hsum256_ps_avx(sum256_1) + + hsum256_ps_avx(sum256_2) + + hsum256_ps_avx(sum256_3) + + hsum256_ps_avx(sum256_4); + + for i in 0..n - m { + result += (*ptr1.add(i)) * (*ptr2.add(i)); + } + result +} + +#[cfg(test)] +mod tests { + #[test] + fn test_spaces_avx() { + use super::*; + use crate::*; + + if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") { + let v1: Vec = vec![ + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 26., 27., 28., 29., 30., 31., + ]; + let v2: Vec = vec![ + 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 56., 57., 58., 59., 60., 61., + ]; + + let euclid_simd = unsafe { euclid_distance_avx(&v1, &v2) }; + let euclid = euclid_distance(&v1, &v2); + assert_eq!(euclid_simd, euclid); + + let dot_simd = unsafe { dot_similarity_avx(&v1, &v2) }; + let dot = dot_similarity(&v1, &v2); + assert_eq!(dot_simd, dot); + + let mut v1 = v1; + let mut v1_copy = v1.clone(); + unsafe { cosine_preprocess_avx(&mut v1) }; + cosine_preprocess(&mut v1_copy); + assert_eq!(v1, v1_copy); + } else { + println!("avx test skipped"); + } + } +} diff --git a/distance-metrics/src/simd_neon.rs b/distance-metrics/src/simd_neon.rs new file mode 100644 index 0000000..7141d85 --- /dev/null +++ b/distance-metrics/src/simd_neon.rs @@ -0,0 +1,143 @@ +#[cfg(target_feature = "neon")] +use std::arch::aarch64::*; + +#[cfg(target_feature = "neon")] +pub(crate) unsafe fn euclid_distance_neon(v1: &[f32], v2: &[f32]) -> f32 { + let n = v1.len(); + let m = n - (n % 16); + let mut ptr1: *const f32 = v1.as_ptr(); + let mut ptr2: *const f32 = v2.as_ptr(); + let mut sum1 = vdupq_n_f32(0.); + let mut sum2 = vdupq_n_f32(0.); + let mut sum3 = vdupq_n_f32(0.); + let mut sum4 = vdupq_n_f32(0.); + + let mut i: usize = 0; + while i < m { + let sub1 = vsubq_f32(vld1q_f32(ptr1), vld1q_f32(ptr2)); + sum1 = vfmaq_f32(sum1, sub1, sub1); + + let sub2 = vsubq_f32(vld1q_f32(ptr1.add(4)), vld1q_f32(ptr2.add(4))); + sum2 = vfmaq_f32(sum2, sub2, sub2); + + let sub3 = vsubq_f32(vld1q_f32(ptr1.add(8)), vld1q_f32(ptr2.add(8))); + sum3 = vfmaq_f32(sum3, sub3, sub3); + + let sub4 = vsubq_f32(vld1q_f32(ptr1.add(12)), vld1q_f32(ptr2.add(12))); + sum4 = vfmaq_f32(sum4, sub4, sub4); + + ptr1 = ptr1.add(16); + ptr2 = ptr2.add(16); + i += 16; + } + let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4); + for i in 0..n - m { + result += (*ptr1.add(i) - *ptr2.add(i)).powi(2); + } + result.abs().sqrt() +} + +#[cfg(target_feature = "neon")] +pub(crate) unsafe fn cosine_preprocess_neon(vector: &mut [f32]) { + let n = vector.len(); + let m = n - (n % 16); + let mut ptr: *const f32 = vector.as_ptr(); + let mut sum1 = vdupq_n_f32(0.); + let mut sum2 = vdupq_n_f32(0.); + let mut sum3 = vdupq_n_f32(0.); + let mut sum4 = vdupq_n_f32(0.); + + let mut i: usize = 0; + while i < m { + let d1 = vld1q_f32(ptr); + sum1 = vfmaq_f32(sum1, d1, d1); + + let d2 = vld1q_f32(ptr.add(4)); + sum2 = vfmaq_f32(sum2, d2, d2); + + let d3 = vld1q_f32(ptr.add(8)); + sum3 = vfmaq_f32(sum3, d3, d3); + + let d4 = vld1q_f32(ptr.add(12)); + sum4 = vfmaq_f32(sum4, d4, d4); + + ptr = ptr.add(16); + i += 16; + } + let mut length = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4); + for v in vector.iter().take(n).skip(m) { + length += v.powi(2); + } + if length < f32::EPSILON { + return; + } + let length = length.sqrt(); + for x in vector.iter_mut() { + *x /= length; + } +} + +#[cfg(target_feature = "neon")] +pub(crate) unsafe fn dot_similarity_neon(v1: &[f32], v2: &[f32]) -> f32 { + let n = v1.len(); + let m = n - (n % 16); + let mut ptr1: *const f32 = v1.as_ptr(); + let mut ptr2: *const f32 = v2.as_ptr(); + let mut sum1 = vdupq_n_f32(0.); + let mut sum2 = vdupq_n_f32(0.); + let mut sum3 = vdupq_n_f32(0.); + let mut sum4 = vdupq_n_f32(0.); + + let mut i: usize = 0; + while i < m { + sum1 = vfmaq_f32(sum1, vld1q_f32(ptr1), vld1q_f32(ptr2)); + sum2 = vfmaq_f32(sum2, vld1q_f32(ptr1.add(4)), vld1q_f32(ptr2.add(4))); + sum3 = vfmaq_f32(sum3, vld1q_f32(ptr1.add(8)), vld1q_f32(ptr2.add(8))); + sum4 = vfmaq_f32(sum4, vld1q_f32(ptr1.add(12)), vld1q_f32(ptr2.add(12))); + ptr1 = ptr1.add(16); + ptr2 = ptr2.add(16); + i += 16; + } + let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4); + for i in 0..n - m { + result += (*ptr1.add(i)) * (*ptr2.add(i)); + } + result +} + +#[cfg(test)] +mod tests { + #[cfg(target_feature = "neon")] + #[test] + fn test_spaces_neon() { + use super::*; + use crate::*; + + if std::arch::is_aarch64_feature_detected!("neon") { + let v1: Vec = vec![ + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 26., 27., 28., 29., 30., 31., + ]; + let v2: Vec = vec![ + 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., + 56., 57., 58., 59., 60., 61., + ]; + + let euclid_simd = unsafe { euclid_distance_neon(&v1, &v2) }; + let euclid = euclid_distance(&v1, &v2); + assert_eq!(euclid_simd, euclid); + + let dot_simd = unsafe { dot_similarity_neon(&v1, &v2) }; + let dot = dot_similarity(&v1, &v2); + assert_eq!(dot_simd, dot); + + let mut v1 = v1; + let mut v1_copy = v1.clone(); + unsafe { cosine_preprocess_neon(&mut v1) }; + cosine_preprocess(&mut v1_copy); + assert_eq!(v1, v1_copy); + } else { + println!("neon test skipped"); + } + } +} diff --git a/distance-metrics/src/simd_sse.rs b/distance-metrics/src/simd_sse.rs new file mode 100644 index 0000000..596afba --- /dev/null +++ b/distance-metrics/src/simd_sse.rs @@ -0,0 +1,179 @@ +#[cfg(target_arch = "x86")] +use std::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +#[target_feature(enable = "sse")] +unsafe fn hsum128_ps_sse(x: __m128) -> f32 { + let x64: __m128 = _mm_add_ps(x, _mm_movehl_ps(x, x)); + let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + _mm_cvtss_f32(x32) +} + +#[target_feature(enable = "sse")] +pub(crate) unsafe fn euclid_distance_sse(v1: &[f32], v2: &[f32]) -> f32 { + let n = v1.len(); + let m = n - (n % 16); + let mut ptr1: *const f32 = v1.as_ptr(); + let mut ptr2: *const f32 = v2.as_ptr(); + let mut sum128_1: __m128 = _mm_setzero_ps(); + let mut sum128_2: __m128 = _mm_setzero_ps(); + let mut sum128_3: __m128 = _mm_setzero_ps(); + let mut sum128_4: __m128 = _mm_setzero_ps(); + let mut i: usize = 0; + while i < m { + let sub128_1 = _mm_sub_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2)); + sum128_1 = _mm_add_ps(_mm_mul_ps(sub128_1, sub128_1), sum128_1); + + let sub128_2 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4))); + sum128_2 = _mm_add_ps(_mm_mul_ps(sub128_2, sub128_2), sum128_2); + + let sub128_3 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8))); + sum128_3 = _mm_add_ps(_mm_mul_ps(sub128_3, sub128_3), sum128_3); + + let sub128_4 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12))); + sum128_4 = _mm_add_ps(_mm_mul_ps(sub128_4, sub128_4), sum128_4); + + ptr1 = ptr1.add(16); + ptr2 = ptr2.add(16); + i += 16; + } + + let mut result = hsum128_ps_sse(sum128_1) + + hsum128_ps_sse(sum128_2) + + hsum128_ps_sse(sum128_3) + + hsum128_ps_sse(sum128_4); + for i in 0..n - m { + result += (*ptr1.add(i) - *ptr2.add(i)).powi(2); + } + result.abs().sqrt() +} + +#[target_feature(enable = "sse")] +pub(crate) unsafe fn cosine_preprocess_sse(vector: &mut [f32]) { + let n = vector.len(); + let m = n - (n % 16); + let mut ptr: *const f32 = vector.as_ptr(); + let mut sum128_1: __m128 = _mm_setzero_ps(); + let mut sum128_2: __m128 = _mm_setzero_ps(); + let mut sum128_3: __m128 = _mm_setzero_ps(); + let mut sum128_4: __m128 = _mm_setzero_ps(); + + let mut i: usize = 0; + while i < m { + let m128_1 = _mm_loadu_ps(ptr); + sum128_1 = _mm_add_ps(_mm_mul_ps(m128_1, m128_1), sum128_1); + + let m128_2 = _mm_loadu_ps(ptr.add(4)); + sum128_2 = _mm_add_ps(_mm_mul_ps(m128_2, m128_2), sum128_2); + + let m128_3 = _mm_loadu_ps(ptr.add(8)); + sum128_3 = _mm_add_ps(_mm_mul_ps(m128_3, m128_3), sum128_3); + + let m128_4 = _mm_loadu_ps(ptr.add(12)); + sum128_4 = _mm_add_ps(_mm_mul_ps(m128_4, m128_4), sum128_4); + + ptr = ptr.add(16); + i += 16; + } + + let mut length = hsum128_ps_sse(sum128_1) + + hsum128_ps_sse(sum128_2) + + hsum128_ps_sse(sum128_3) + + hsum128_ps_sse(sum128_4); + for i in 0..n - m { + length += (*ptr.add(i)).powi(2); + } + if length < f32::EPSILON { + return; + } + length = length.sqrt(); + for x in vector.iter_mut() { + *x /= length; + } +} + +#[target_feature(enable = "sse")] +pub(crate) unsafe fn dot_similarity_sse(v1: &[f32], v2: &[f32]) -> f32 { + let n = v1.len(); + let m = n - (n % 16); + let mut ptr1: *const f32 = v1.as_ptr(); + let mut ptr2: *const f32 = v2.as_ptr(); + let mut sum128_1: __m128 = _mm_setzero_ps(); + let mut sum128_2: __m128 = _mm_setzero_ps(); + let mut sum128_3: __m128 = _mm_setzero_ps(); + let mut sum128_4: __m128 = _mm_setzero_ps(); + + let mut i: usize = 0; + while i < m { + sum128_1 = _mm_add_ps(_mm_mul_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2)), sum128_1); + + sum128_2 = _mm_add_ps( + _mm_mul_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4))), + sum128_2, + ); + + sum128_3 = _mm_add_ps( + _mm_mul_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8))), + sum128_3, + ); + + sum128_4 = _mm_add_ps( + _mm_mul_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12))), + sum128_4, + ); + + ptr1 = ptr1.add(16); + ptr2 = ptr2.add(16); + i += 16; + } + + let mut result = hsum128_ps_sse(sum128_1) + + hsum128_ps_sse(sum128_2) + + hsum128_ps_sse(sum128_3) + + hsum128_ps_sse(sum128_4); + for i in 0..n - m { + result += (*ptr1.add(i)) * (*ptr2.add(i)); + } + result +} + +#[cfg(test)] +mod tests { + #[test] + fn test_spaces_sse() { + use super::*; + use crate::*; + + if is_x86_feature_detected!("sse") { + let v1: Vec = vec![ + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 26., 27., 28., 29., 30., 31., + ]; + let v2: Vec = vec![ + 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 56., 57., 58., 59., 60., 61., + ]; + + let euclid_simd = unsafe { euclid_distance_sse(&v1, &v2) }; + let euclid = euclid_distance(&v1, &v2); + assert_eq!(euclid_simd, euclid); + + let dot_simd = unsafe { dot_similarity_sse(&v1, &v2) }; + let dot = dot_similarity(&v1, &v2); + assert_eq!(dot_simd, dot); + + let mut v1 = v1; + let mut v1_copy = v1.clone(); + unsafe { cosine_preprocess_sse(&mut v1) }; + cosine_preprocess(&mut v1_copy); + assert_eq!(v1, v1_copy); + } else { + println!("sse test skipped"); + } + } +} diff --git a/instant-distance-py/Cargo.toml b/instant-distance-py/Cargo.toml index eafab40..d749518 100644 --- a/instant-distance-py/Cargo.toml +++ b/instant-distance-py/Cargo.toml @@ -16,7 +16,7 @@ crate-type = ["cdylib"] [dependencies] bincode = "1.3.1" +distance-metrics = { version = "0.6", path = "../distance-metrics" } instant-distance = { version = "0.6", path = "../instant-distance", features = ["with-serde"] } pyo3 = { version = "0.18.0", features = ["extension-module"] } serde = { version = "1", features = ["derive"] } -serde-big-array = "0.4.1" diff --git a/instant-distance-py/src/lib.rs b/instant-distance-py/src/lib.rs index bc35090..6f40449 100644 --- a/instant-distance-py/src/lib.rs +++ b/instant-distance-py/src/lib.rs @@ -5,15 +5,17 @@ use std::convert::TryFrom; use std::fs::File; use std::io::{BufReader, BufWriter}; use std::iter::FromIterator; +use std::marker::PhantomData; +use distance_metrics::Metric; +use distance_metrics::{CosineMetric, EuclidMetric}; use instant_distance::Point; use pyo3::conversion::IntoPy; -use pyo3::exceptions::{PyTypeError, PyValueError}; +use pyo3::exceptions::PyValueError; use pyo3::types::{PyList, PyModule, PyString}; use pyo3::{pyclass, pymethods, pymodule}; use pyo3::{Py, PyAny, PyErr, PyObject, PyRef, PyRefMut, PyResult, Python}; use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; #[pymodule] #[pyo3(name = "instant_distance")] @@ -24,12 +26,46 @@ fn instant_distance_py(_: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } +#[pyclass] +#[derive(Copy, Clone)] +enum DistanceMetric { + Euclid, + Cosine, +} + +impl Default for DistanceMetric { + fn default() -> Self { + Self::Euclid + } +} + +// Helper macro for dispatching to inner implementation +macro_rules! impl_for_each_hnsw_with_metric { + ($type:ident, $instance:expr, $inner:ident, $($tokens:tt)+) => { + match $instance { + $type::Euclid($inner) => { + $($tokens)+ + } + $type::Cosine($inner) => { + $($tokens)+ + } + } + }; +} + #[pyclass] struct HnswMap { - inner: instant_distance::HnswMap, + inner: HnswMapWithMetric, +} + +#[derive(Deserialize, Serialize)] +enum HnswMapWithMetric { + Euclid(instant_distance::HnswMap, MapValue>), + Cosine(instant_distance::HnswMap, MapValue>), } #[pymethods] @@ -37,28 +73,32 @@ impl HnswMap { /// Build the index #[staticmethod] fn build(points: &PyList, values: &PyList, config: &Config) -> PyResult { - let points = points - .into_iter() - .map(FloatArray::try_from) - .collect::, PyErr>>()?; - let values = values .into_iter() .map(MapValue::try_from) .collect::, PyErr>>()?; - - let hsnw_map = instant_distance::Builder::from(config).build(points, values); - Ok(Self { inner: hsnw_map }) + let builder = instant_distance::Builder::from(config); + let inner = match config.distance_metric { + DistanceMetric::Euclid => { + let points = FloatArray::try_from_pylist(points)?; + HnswMapWithMetric::Euclid(builder.build(points, values)) + } + DistanceMetric::Cosine => { + let points = FloatArray::try_from_pylist(points)?; + HnswMapWithMetric::Cosine(builder.build(points, values)) + } + }; + Ok(Self { inner }) } /// Load an index from the given file name #[staticmethod] fn load(fname: &str) -> PyResult { - let hnsw_map = - bincode::deserialize_from::<_, instant_distance::HnswMap>( - BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?), - ) - .map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?; + let hnsw_map = bincode::deserialize_from::<_, HnswMapWithMetric>(BufReader::with_capacity( + 32 * 1024 * 1024, + File::open(fname)?, + )) + .map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?; Ok(Self { inner: hnsw_map }) } @@ -78,20 +118,25 @@ impl HnswMap { /// /// For best performance, reusing `Search` objects is recommended. fn search(slf: Py, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> { - let point = FloatArray::try_from(point)?; - let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner); + impl_for_each_hnsw_with_metric!(HnswMapWithMetric, &slf.try_borrow(py)?.inner, hnsw, { + let point = FloatArray::try_from(point)?; + let _ = hnsw.search(&point, &mut search.inner); + }); search.cur = Some((HnswType::Map(slf.clone_ref(py)), 0)); Ok(()) } } /// An instance of hierarchical navigable small worlds -/// -/// For now, this is specialized to only support 300-element (32-bit) float vectors -/// with a squared Euclidean distance metric. #[pyclass] struct Hnsw { - inner: instant_distance::Hnsw, + inner: HnswWithMetric, +} + +#[derive(Deserialize, Serialize)] +enum HnswWithMetric { + Euclid(instant_distance::Hnsw>), + Cosine(instant_distance::Hnsw>), } #[pymethods] @@ -99,12 +144,19 @@ impl Hnsw { /// Build the index #[staticmethod] fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec)> { - let points = input - .into_iter() - .map(FloatArray::try_from) - .collect::, PyErr>>()?; - - let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points); + let builder = instant_distance::Builder::from(config); + let (inner, ids) = match config.distance_metric { + DistanceMetric::Euclid => { + let points = FloatArray::try_from_pylist(input)?; + let (hnsw, ids) = builder.build_hnsw(points); + (HnswWithMetric::Euclid(hnsw), ids) + } + DistanceMetric::Cosine => { + let points = FloatArray::try_from_pylist(input)?; + let (hnsw, ids) = builder.build_hnsw(points); + (HnswWithMetric::Cosine(hnsw), ids) + } + }; let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner())); Ok((Self { inner }, ids)) } @@ -112,9 +164,10 @@ impl Hnsw { /// Load an index from the given file name #[staticmethod] fn load(fname: &str) -> PyResult { - let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw>( - BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?), - ) + let hnsw = bincode::deserialize_from::<_, HnswWithMetric>(BufReader::with_capacity( + 32 * 1024 * 1024, + File::open(fname)?, + )) .map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?; Ok(Self { inner: hnsw }) } @@ -135,8 +188,10 @@ impl Hnsw { /// /// For best performance, reusing `Search` objects is recommended. fn search(slf: Py, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> { - let point = FloatArray::try_from(point)?; - let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner); + impl_for_each_hnsw_with_metric!(HnswWithMetric, &slf.try_borrow(py)?.inner, hnsw, { + let point = FloatArray::try_from(point)?; + let _ = hnsw.search(&point, &mut search.inner); + }); search.cur = Some((HnswType::Hnsw(slf.clone_ref(py)), 0)); Ok(()) } @@ -175,20 +230,24 @@ impl Search { let neighbor = match &index { HnswType::Hnsw(hnsw) => { let hnsw = hnsw.as_ref(py).borrow(); - let item = hnsw.inner.get(idx, &slf.inner); - item.map(|item| Neighbor { - distance: item.distance, - pid: item.pid.into_inner(), - value: py.None(), + impl_for_each_hnsw_with_metric!(HnswWithMetric, &hnsw.inner, hnsw, { + let item = hnsw.get(idx, &slf.inner); + item.map(|item| Neighbor { + distance: item.distance, + pid: item.pid.into_inner(), + value: py.None(), + }) }) } HnswType::Map(map) => { let map = map.as_ref(py).borrow(); - let item = map.inner.get(idx, &slf.inner); - item.map(|item| Neighbor { - distance: item.distance, - pid: item.pid.into_inner(), - value: item.value.into_py(py), + impl_for_each_hnsw_with_metric!(HnswMapWithMetric, &map.inner, map, { + let item = map.get(idx, &slf.inner); + item.map(|item| Neighbor { + distance: item.distance, + pid: item.pid.into_inner(), + value: item.value.into_py(py), + }) }) } }; @@ -226,6 +285,11 @@ struct Config { /// in order to get better results on clustered data points. #[pyo3(get, set)] heuristic: Option, + /// Distance metric to use + /// + /// Defaults to Euclidean distance + #[pyo3(get, set)] + distance_metric: DistanceMetric, } #[pymethods] @@ -235,12 +299,14 @@ impl Config { let builder = instant_distance::Builder::default(); let (ef_search, ef_construction, ml, seed) = builder.into_parts(); let heuristic = Some(Heuristic::default()); + let distance_metric = DistanceMetric::default(); Self { ef_search, ef_construction, ml, seed, heuristic, + distance_metric, } } } @@ -253,6 +319,7 @@ impl From<&Config> for instant_distance::Builder { ml, seed, heuristic, + distance_metric: _, } = *py; Self::default() .ef_search(ef_search) @@ -346,67 +413,43 @@ impl Neighbor { } } -#[repr(align(32))] #[derive(Clone, Deserialize, Serialize)] -struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]); +struct FloatArray { + array: Vec, + phantom: PhantomData, +} -impl TryFrom<&PyAny> for FloatArray { +impl FloatArray { + fn try_from_pylist(list: &PyList) -> Result, PyErr> { + list.into_iter().map(FloatArray::try_from).collect() + } +} + +impl From> for FloatArray { + fn from(mut array: Vec) -> Self { + M::preprocess(&mut array); + Self { + array, + phantom: PhantomData, + } + } +} + +impl TryFrom<&PyAny> for FloatArray { type Error = PyErr; fn try_from(value: &PyAny) -> Result { - let mut new = FloatArray([0.0; DIMENSIONS]); - for (i, val) in value.iter()?.enumerate() { - match i >= DIMENSIONS { - true => return Err(PyTypeError::new_err("point array too long")), - false => new.0[i] = val?.extract::()?, - } - } - Ok(new) + let array: Vec = value + .iter()? + .map(|val| val.and_then(|v| v.extract::())) + .collect::>()?; + Ok(Self::from(array)) } } -impl Point for FloatArray { +impl Point for FloatArray { fn distance(&self, rhs: &Self) -> f32 { - #[cfg(target_arch = "x86_64")] - { - use std::arch::x86_64::{ - _mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps, - _mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, - _mm_fmadd_ps, _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps, - }; - debug_assert_eq!(self.0.len() % 8, 4); - - unsafe { - let mut acc_8x = _mm256_setzero_ps(); - for (lh_slice, rh_slice) in self.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) { - let lh_8x = _mm256_load_ps(lh_slice.as_ptr()); - let rh_8x = _mm256_load_ps(rh_slice.as_ptr()); - let diff = _mm256_sub_ps(lh_8x, rh_8x); - acc_8x = _mm256_fmadd_ps(diff, diff, acc_8x); - } - - let mut acc_4x = _mm256_extractf128_ps(acc_8x, 1); // upper half - let right = _mm256_castps256_ps128(acc_8x); // lower half - acc_4x = _mm_add_ps(acc_4x, right); // sum halves - - let lh_4x = _mm_load_ps(self.0[DIMENSIONS - 4..].as_ptr()); - let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr()); - let diff = _mm_sub_ps(lh_4x, rh_4x); - acc_4x = _mm_fmadd_ps(diff, diff, acc_4x); - - let lower = _mm_movehl_ps(acc_4x, acc_4x); - acc_4x = _mm_add_ps(acc_4x, lower); - let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1); - acc_4x = _mm_add_ss(acc_4x, upper); - _mm_cvtss_f32(acc_4x) - } - } - #[cfg(not(target_arch = "x86_64"))] - self.0 - .iter() - .zip(rhs.0.iter()) - .map(|(&a, &b)| (a - b).powi(2)) - .sum::() + M::distance(&self.array, &rhs.array) } } @@ -430,5 +473,3 @@ impl IntoPy> for &'_ MapValue { } } } - -const DIMENSIONS: usize = 300; diff --git a/instant-distance-py/test/test.py b/instant-distance-py/test/test.py index 3ec2e23..a855d2f 100644 --- a/instant-distance-py/test/test.py +++ b/instant-distance-py/test/test.py @@ -1,9 +1,10 @@ import instant_distance, random -def test_hsnw(): +def test_hsnw(distance_metric=instant_distance.DistanceMetric.Euclid): points = [[random.random() for _ in range(300)] for _ in range(1024)] config = instant_distance.Config() + config.distance_metric = distance_metric (hnsw, ids) = instant_distance.Hnsw.build(points, config) p = [random.random() for _ in range(300)] search = instant_distance.Search() @@ -12,7 +13,7 @@ def test_hsnw(): print(candidate) -def test_hsnw_map(): +def test_hsnw_map(distance_metric=instant_distance.DistanceMetric.Euclid): the_chosen_one = 123 embeddings = [[random.random() for _ in range(300)] for _ in range(1024)] @@ -20,6 +21,7 @@ def test_hsnw_map(): values = f.read().splitlines()[1024:] config = instant_distance.Config() + config.distance_metric = distance_metric hnsw_map = instant_distance.HnswMap.build(embeddings, values, config) search = instant_distance.Search() @@ -38,3 +40,5 @@ def test_hsnw_map(): if __name__ == "__main__": test_hsnw() test_hsnw_map() + test_hsnw(instant_distance.DistanceMetric.Cosine) + test_hsnw_map(instant_distance.DistanceMetric.Cosine)