Skip to content

Commit 264dd95

Browse files
committed
Implement distance metric selection
1 parent f1cb9ee commit 264dd95

File tree

12 files changed

+933
-101
lines changed

12 files changed

+933
-101
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[workspace]
2-
members = ["instant-distance", "instant-distance-py"]
2+
members = ["distance-metrics", "instant-distance", "instant-distance-py"]
33

44
[profile.bench]
55
debug = true

Makefile

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1-
test-python:
2-
cargo build --release
3-
cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so
1+
instant-distance-py/test/instant_distance.so: instant-distance-py/src/lib.rs
2+
RUSTFLAGS="-C target-cpu=native" cargo build --release
3+
([ -f target/release/libinstant_distance.dylib ] && cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so) || \
4+
([ -f target/release/libinstant_distance.so ] && cp target/release/libinstant_distance.so instant-distance-py/test/instant_distance.so)
5+
6+
test-python: instant-distance-py/test/instant_distance.so
47
PYTHONPATH=instant-distance-py/test/ python3 -m test
58

9+
bench-python: instant-distance-py/test/instant_distance.so
10+
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config()' 'instant_distance.Hnsw.build(points, config)'
11+
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config(); config.distance_metric = instant_distance.DistanceMetric.Cosine' 'instant_distance.Hnsw.build(points, config)'
12+
613
clean:
714
cargo clean
815
rm -f instant-distance-py/test/instant_distance.so

distance-metrics/Cargo.toml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
[package]
2+
name = "distance-metrics"
3+
version = "0.6.0"
4+
license = "MIT OR Apache-2.0"
5+
edition = "2021"
6+
rust-version = "1.58"
7+
homepage = "https://github.com/InstantDomain/instant-distance"
8+
repository = "https://github.com/InstantDomain/instant-distance"
9+
documentation = "https://docs.rs/instant-distance"
10+
workspace = ".."
11+
readme = "../README.md"
12+
13+
[dependencies]
14+
15+
[dev-dependencies]
16+
bencher = "0.1.5"
17+
rand = { version = "0.8", features = ["small_rng"] }
18+
19+
[[bench]]
20+
name = "all"
21+
harness = false

distance-metrics/benches/all.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
use bencher::{benchmark_group, benchmark_main, Bencher};
2+
3+
use distance_metrics::{
4+
Metric, {CosineMetric, EuclidMetric},
5+
};
6+
use rand::{rngs::StdRng, Rng, SeedableRng};
7+
8+
benchmark_main!(benches);
9+
benchmark_group!(
10+
benches,
11+
legacy,
12+
non_simd,
13+
simple::<EuclidMetric>,
14+
simple::<CosineMetric>
15+
);
16+
17+
/// Benchmarks `distance_metrics::legacy::distance` on two seeded random 300-element points.
fn legacy(bench: &mut Bencher) {
    let mut rng = StdRng::seed_from_u64(SEED);

    // An array repeat expression (`[rng.gen(); 300]`) evaluates `rng.gen()` exactly once and
    // copies that single value into every slot, which would benchmark two constant vectors.
    // Each component is therefore drawn individually.
    let mut values_a = [0.0f32; 300];
    for value in values_a.iter_mut() {
        *value = rng.gen();
    }
    let mut values_b = [0.0f32; 300];
    for value in values_b.iter_mut() {
        *value = rng.gen();
    }

    let point_a = distance_metrics::legacy::FloatArray(values_a);
    let point_b = distance_metrics::legacy::FloatArray(values_b);

    bench.iter(|| distance_metrics::legacy::distance(&point_a, &point_b))
}
24+
25+
fn non_simd(bench: &mut Bencher) {
26+
let mut rng = StdRng::seed_from_u64(SEED);
27+
let point_a = [rng.gen(); 300];
28+
let point_b = [rng.gen(); 300];
29+
30+
bench.iter(|| distance_metrics::euclid_distance(&point_a, &point_b))
31+
}
32+
33+
/// Benchmarks `M::distance` on two seeded random 300-element points that were first passed
/// through `M::preprocess`.
fn simple<M: Metric>(bench: &mut Bencher) {
    let mut rng = StdRng::seed_from_u64(SEED);

    // An array repeat expression (`[rng.gen(); 300]`) evaluates `rng.gen()` exactly once and
    // copies that single value into every slot, which would benchmark two constant vectors.
    // Each component is therefore drawn individually.
    let mut point_a = [0.0f32; 300];
    for value in point_a.iter_mut() {
        *value = rng.gen();
    }
    let mut point_b = [0.0f32; 300];
    for value in point_b.iter_mut() {
        *value = rng.gen();
    }

    M::preprocess(&mut point_a);
    M::preprocess(&mut point_b);

    bench.iter(|| M::distance(&point_a, &point_b))
}
42+
43+
const SEED: u64 = 123456789;

distance-metrics/src/legacy.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/// Number of `f32` components stored in every [`FloatArray`].
const DIMENSIONS: usize = 300;

/// Fixed-size vector of `DIMENSIONS` floats.
///
/// The 32-byte alignment is what allows the AVX kernel to use aligned loads on every
/// 8-float chunk (and an aligned 16-byte load on the 4-float tail).
#[repr(align(32))]
pub struct FloatArray(pub [f32; DIMENSIONS]);

/// Returns the sum of squared component differences between `lhs` and `rhs`
/// (no square root is taken).
///
/// On x86_64 the AVX/FMA kernel is used only when the running CPU reports both features;
/// every other case uses the portable scalar loop. When the features are enabled at compile
/// time (e.g. `-C target-cpu=native`), `is_x86_feature_detected!` resolves without a runtime
/// check.
pub fn distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
            // SAFETY: both target features required by `distance_avx_fma` were just
            // detected, and `FloatArray` is 32-byte aligned as the aligned loads require.
            return unsafe { distance_avx_fma(lhs, rhs) };
        }
    }

    lhs.0
        .iter()
        .zip(rhs.0.iter())
        .map(|(&a, &b)| (a - b).powi(2))
        .sum::<f32>()
}

/// AVX + FMA implementation of [`distance`].
///
/// # Safety
///
/// The caller must ensure the CPU supports the `avx` and `fma` target features.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx,fma")]
unsafe fn distance_avx_fma(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
    use std::arch::x86_64::{
        _mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
        _mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_fmadd_ps,
        _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
    };

    // The kernel handles whole 8-float chunks plus exactly one trailing 4-float group.
    debug_assert_eq!(lhs.0.len() % 8, 4);

    // Accumulate squared differences eight lanes at a time.
    let mut acc_8x = _mm256_setzero_ps();
    for (lh_slice, rh_slice) in lhs.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
        let lh_8x = _mm256_load_ps(lh_slice.as_ptr());
        let rh_8x = _mm256_load_ps(rh_slice.as_ptr());
        let diff = _mm256_sub_ps(lh_8x, rh_8x);
        acc_8x = _mm256_fmadd_ps(diff, diff, acc_8x);
    }

    let mut acc_4x = _mm256_extractf128_ps(acc_8x, 1); // upper half
    let right = _mm256_castps256_ps128(acc_8x); // lower half
    acc_4x = _mm_add_ps(acc_4x, right); // sum halves

    // Fold in the last four components that did not fill an 8-float chunk.
    let lh_4x = _mm_load_ps(lhs.0[DIMENSIONS - 4..].as_ptr());
    let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr());
    let diff = _mm_sub_ps(lh_4x, rh_4x);
    acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);

    // Horizontal sum of the four remaining lanes.
    let lower = _mm_movehl_ps(acc_4x, acc_4x);
    acc_4x = _mm_add_ps(acc_4x, lower);
    let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1);
    acc_4x = _mm_add_ss(acc_4x, upper);
    _mm_cvtss_f32(acc_4x)
}

distance-metrics/src/lib.rs

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
pub mod legacy;
2+
3+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
4+
pub mod simd_sse;
5+
6+
#[cfg(target_arch = "x86_64")]
7+
pub mod simd_avx;
8+
9+
#[cfg(target_arch = "aarch64")]
10+
pub mod simd_neon;
11+
12+
/// Defines how to compare vectors
13+
pub trait Metric {
14+
/// The greater the value, the more distant the vectors
15+
fn distance(v1: &[f32], v2: &[f32]) -> f32;
16+
17+
/// Necessary vector transformations performed before adding it to the collection (like normalization)
18+
fn preprocess(vector: &mut [f32]);
19+
}
20+
21+
#[cfg(target_arch = "x86_64")]
22+
const MIN_DIM_SIZE_AVX: usize = 32;
23+
24+
#[cfg(any(
25+
target_arch = "x86",
26+
target_arch = "x86_64",
27+
all(target_arch = "aarch64", target_feature = "neon")
28+
))]
29+
const MIN_DIM_SIZE_SIMD: usize = 16;
30+
31+
#[derive(Clone, Copy)]
32+
pub struct EuclidMetric {}
33+
34+
impl Metric for EuclidMetric {
35+
fn distance(v1: &[f32], v2: &[f32]) -> f32 {
36+
#[cfg(target_arch = "x86_64")]
37+
{
38+
if is_x86_feature_detected!("avx")
39+
&& is_x86_feature_detected!("fma")
40+
&& v1.len() >= MIN_DIM_SIZE_AVX
41+
{
42+
return unsafe { simd_avx::euclid_distance_avx(v1, v2) };
43+
}
44+
}
45+
46+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
47+
{
48+
if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD {
49+
return unsafe { simd_sse::euclid_distance_sse(v1, v2) };
50+
}
51+
}
52+
53+
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
54+
{
55+
if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD {
56+
return unsafe { simple_neon::euclid_distance_neon(v1, v2) };
57+
}
58+
}
59+
60+
euclid_distance(v1, v2)
61+
}
62+
63+
fn preprocess(_vector: &mut [f32]) {
64+
// no-op
65+
}
66+
}
67+
68+
#[derive(Clone, Copy)]
69+
pub struct CosineMetric {}
70+
71+
impl Metric for CosineMetric {
72+
fn distance(v1: &[f32], v2: &[f32]) -> f32 {
73+
#[cfg(target_arch = "x86_64")]
74+
{
75+
if is_x86_feature_detected!("avx")
76+
&& is_x86_feature_detected!("fma")
77+
&& v1.len() >= MIN_DIM_SIZE_AVX
78+
{
79+
return 1.0 - unsafe { simd_avx::dot_similarity_avx(v1, v2) };
80+
}
81+
}
82+
83+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
84+
{
85+
if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD {
86+
return 1.0 - unsafe { simd_sse::dot_similarity_sse(v1, v2) };
87+
}
88+
}
89+
90+
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
91+
{
92+
if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD {
93+
return 1.0 - unsafe { simd_neon::dot_similarity_neon(v1, v2) };
94+
}
95+
}
96+
97+
1.0 - dot_similarity(v1, v2)
98+
}
99+
100+
fn preprocess(vector: &mut [f32]) {
101+
#[cfg(target_arch = "x86_64")]
102+
{
103+
if is_x86_feature_detected!("avx")
104+
&& is_x86_feature_detected!("fma")
105+
&& vector.len() >= MIN_DIM_SIZE_AVX
106+
{
107+
return unsafe { simd_avx::cosine_preprocess_avx(vector) };
108+
}
109+
}
110+
111+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
112+
{
113+
if is_x86_feature_detected!("sse") && vector.len() >= MIN_DIM_SIZE_SIMD {
114+
return unsafe { simd_sse::cosine_preprocess_sse(vector) };
115+
}
116+
}
117+
118+
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
119+
{
120+
if std::arch::is_aarch64_feature_detected!("neon") && vector.len() >= MIN_DIM_SIZE_SIMD
121+
{
122+
return unsafe { simd_neon::cosine_preprocess_neon(vector) };
123+
}
124+
}
125+
126+
cosine_preprocess(vector);
127+
}
128+
}
129+
130+
/// Scalar Euclidean distance: the square root of the summed squared component differences.
///
/// Components are paired positionally; if the slices differ in length, the extra trailing
/// components of the longer one are ignored.
pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 {
    let mut squared_sum = 0.0f32;
    for (&left, &right) in v1.iter().zip(v2.iter()) {
        let delta = left - right;
        squared_sum += delta.powi(2);
    }
    squared_sum.abs().sqrt()
}
139+
140+
/// Scales `vector` in place to unit Euclidean length.
///
/// The vector is left untouched when the sum of its squared components is below
/// `f32::EPSILON`, which avoids dividing by a (near-)zero length.
pub fn cosine_preprocess(vector: &mut [f32]) {
    let squared_norm: f32 = vector.iter().map(|component| component * component).sum();
    if squared_norm < f32::EPSILON {
        return;
    }

    let norm = squared_norm.sqrt();
    vector.iter_mut().for_each(|component| *component /= norm);
}
150+
151+
/// Scalar dot product of `v1` and `v2`.
///
/// Components are paired positionally; if the slices differ in length, the extra trailing
/// components of the longer one are ignored.
pub fn dot_similarity(v1: &[f32], v2: &[f32]) -> f32 {
    let products = v1.iter().zip(v2.iter()).map(|(&left, &right)| left * right);
    products.sum::<f32>()
}

0 commit comments

Comments
 (0)