Skip to content

Commit dcb69c0

Browse files
committed
Implement distance metric selection
1 parent f1cb9ee commit dcb69c0

File tree

11 files changed

+932
-101
lines changed

11 files changed

+932
-101
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[workspace]
2-
members = ["instant-distance", "instant-distance-py"]
2+
members = ["distance-metrics", "instant-distance", "instant-distance-py"]
33

44
[profile.bench]
55
debug = true

Makefile

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1-
test-python:
2-
cargo build --release
3-
cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so
1+
instant-distance-py/test/instant_distance.so: instant-distance-py/src/lib.rs
2+
RUSTFLAGS="-C target-cpu=native" cargo build --release
3+
([ -f target/release/libinstant_distance.dylib ] && cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so) || \
4+
([ -f target/release/libinstant_distance.so ] && cp target/release/libinstant_distance.so instant-distance-py/test/instant_distance.so)
5+
6+
test-python: instant-distance-py/test/instant_distance.so
47
PYTHONPATH=instant-distance-py/test/ python3 -m test
58

9+
bench-python: instant-distance-py/test/instant_distance.so
10+
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config()' 'instant_distance.Hnsw.build(points, config)'
11+
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config(); config.distance_metric = instant_distance.DistanceMetric.Cosine' 'instant_distance.Hnsw.build(points, config)'
12+
613
clean:
714
cargo clean
815
rm -f instant-distance-py/test/instant_distance.so

distance-metrics/Cargo.toml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
[package]
2+
name = "distance-metrics"
3+
version = "0.6.0"
4+
license = "MIT OR Apache-2.0"
5+
edition = "2021"
6+
rust-version = "1.58"
7+
homepage = "https://github.com/InstantDomain/instant-distance"
8+
repository = "https://github.com/InstantDomain/instant-distance"
9+
documentation = "https://docs.rs/instant-distance"
10+
workspace = ".."
11+
readme = "../README.md"
12+
13+
[dependencies]
14+
15+
[dev-dependencies]
16+
bencher = "0.1.5"
17+
rand = { version = "0.8", features = ["small_rng"] }
18+
19+
[[bench]]
20+
name = "all"
21+
harness = false

distance-metrics/benches/all.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
use bencher::{benchmark_group, benchmark_main, Bencher};
2+
3+
use distance_metrics::{
4+
Metric, {CosineMetric, EuclidMetric},
5+
};
6+
use rand::{rngs::StdRng, Rng, SeedableRng};
7+
8+
benchmark_main!(benches);
9+
benchmark_group!(
10+
benches,
11+
legacy,
12+
non_simd,
13+
simple::<EuclidMetric>,
14+
simple::<CosineMetric>
15+
);
16+
17+
fn legacy(bench: &mut Bencher) {
18+
let mut rng = StdRng::seed_from_u64(SEED);
19+
let point_a = distance_metrics::FloatArray([rng.gen(); 300]);
20+
let point_b = distance_metrics::FloatArray([rng.gen(); 300]);
21+
22+
bench.iter(|| distance_metrics::legacy_distance(&point_a, &point_b))
23+
}
24+
25+
fn non_simd(bench: &mut Bencher) {
26+
let mut rng = StdRng::seed_from_u64(SEED);
27+
let point_a = [rng.gen(); 300];
28+
let point_b = [rng.gen(); 300];
29+
30+
bench.iter(|| distance_metrics::euclid_distance(&point_a, &point_b))
31+
}
32+
33+
fn simple<M: Metric>(bench: &mut Bencher) {
34+
let mut rng = StdRng::seed_from_u64(SEED);
35+
let mut point_a = [rng.gen(); 300];
36+
let mut point_b = [rng.gen(); 300];
37+
M::preprocess(&mut point_a);
38+
M::preprocess(&mut point_b);
39+
40+
bench.iter(|| M::distance(&point_a, &point_b))
41+
}
42+
43+
const SEED: u64 = 123456789;

distance-metrics/src/lib.rs

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2+
pub mod simd_sse;
3+
4+
#[cfg(target_arch = "x86_64")]
5+
pub mod simd_avx;
6+
7+
#[cfg(target_arch = "aarch64")]
8+
pub mod simd_neon;
9+
10+
/// Defines how to compare vectors
11+
pub trait Metric {
12+
/// Greater the value - more distant the vectors
13+
fn distance(v1: &[f32], v2: &[f32]) -> f32;
14+
15+
/// Necessary vector transformations performed before adding it to the collection (like normalization)
16+
fn preprocess(vector: &mut [f32]);
17+
}
18+
19+
#[cfg(target_arch = "x86_64")]
20+
const MIN_DIM_SIZE_AVX: usize = 32;
21+
22+
#[cfg(any(
23+
target_arch = "x86",
24+
target_arch = "x86_64",
25+
all(target_arch = "aarch64", target_feature = "neon")
26+
))]
27+
const MIN_DIM_SIZE_SIMD: usize = 16;
28+
29+
#[derive(Clone, Copy)]
30+
pub struct EuclidMetric {}
31+
32+
impl Metric for EuclidMetric {
33+
fn distance(v1: &[f32], v2: &[f32]) -> f32 {
34+
#[cfg(target_arch = "x86_64")]
35+
{
36+
if is_x86_feature_detected!("avx")
37+
&& is_x86_feature_detected!("fma")
38+
&& v1.len() >= MIN_DIM_SIZE_AVX
39+
{
40+
return unsafe { simd_avx::euclid_distance_avx(v1, v2) };
41+
}
42+
}
43+
44+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
45+
{
46+
if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD {
47+
return unsafe { simd_sse::euclid_distance_sse(v1, v2) };
48+
}
49+
}
50+
51+
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
52+
{
53+
if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD {
54+
return unsafe { simple_neon::euclid_distance_neon(v1, v2) };
55+
}
56+
}
57+
58+
euclid_distance(v1, v2)
59+
}
60+
61+
fn preprocess(_vector: &mut [f32]) {
62+
// no-op
63+
}
64+
}
65+
66+
#[derive(Clone, Copy)]
67+
pub struct CosineMetric {}
68+
69+
impl Metric for CosineMetric {
70+
fn distance(v1: &[f32], v2: &[f32]) -> f32 {
71+
#[cfg(target_arch = "x86_64")]
72+
{
73+
if is_x86_feature_detected!("avx")
74+
&& is_x86_feature_detected!("fma")
75+
&& v1.len() >= MIN_DIM_SIZE_AVX
76+
{
77+
return 1.0 - unsafe { simd_avx::dot_similarity_avx(v1, v2) };
78+
}
79+
}
80+
81+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
82+
{
83+
if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD {
84+
return 1.0 - unsafe { simd_sse::dot_similarity_sse(v1, v2) };
85+
}
86+
}
87+
88+
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
89+
{
90+
if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD {
91+
return 1.0 - unsafe { simd_neon::dot_similarity_neon(v1, v2) };
92+
}
93+
}
94+
95+
1.0 - dot_similarity(v1, v2)
96+
}
97+
98+
fn preprocess(vector: &mut [f32]) {
99+
#[cfg(target_arch = "x86_64")]
100+
{
101+
if is_x86_feature_detected!("avx")
102+
&& is_x86_feature_detected!("fma")
103+
&& vector.len() >= MIN_DIM_SIZE_AVX
104+
{
105+
return unsafe { simd_avx::cosine_preprocess_avx(vector) };
106+
}
107+
}
108+
109+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
110+
{
111+
if is_x86_feature_detected!("sse") && vector.len() >= MIN_DIM_SIZE_SIMD {
112+
return unsafe { simd_sse::cosine_preprocess_sse(vector) };
113+
}
114+
}
115+
116+
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
117+
{
118+
if std::arch::is_aarch64_feature_detected!("neon") && vector.len() >= MIN_DIM_SIZE_SIMD
119+
{
120+
return unsafe { simd_neon::cosine_preprocess_neon(vector) };
121+
}
122+
}
123+
124+
cosine_preprocess(vector);
125+
}
126+
}
127+
128+
pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 {
129+
let s: f32 = v1
130+
.iter()
131+
.copied()
132+
.zip(v2.iter().copied())
133+
.map(|(a, b)| (a - b).powi(2))
134+
.sum();
135+
s.abs().sqrt()
136+
}
137+
138+
pub fn cosine_preprocess(vector: &mut [f32]) {
139+
let mut length: f32 = vector.iter().map(|x| x * x).sum();
140+
if length < f32::EPSILON {
141+
return;
142+
}
143+
length = length.sqrt();
144+
for x in vector.iter_mut() {
145+
*x /= length;
146+
}
147+
}
148+
149+
pub fn dot_similarity(v1: &[f32], v2: &[f32]) -> f32 {
150+
v1.iter().zip(v2).map(|(a, b)| a * b).sum()
151+
}
152+
153+
pub fn legacy_distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
154+
#[cfg(target_arch = "x86_64")]
155+
{
156+
use std::arch::x86_64::{
157+
_mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
158+
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_fmadd_ps,
159+
_mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
160+
};
161+
debug_assert_eq!(lhs.0.len() % 8, 4);
162+
163+
unsafe {
164+
let mut acc_8x = _mm256_setzero_ps();
165+
for (lh_slice, rh_slice) in lhs.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
166+
let lh_8x = _mm256_load_ps(lh_slice.as_ptr());
167+
let rh_8x = _mm256_load_ps(rh_slice.as_ptr());
168+
let diff = _mm256_sub_ps(lh_8x, rh_8x);
169+
acc_8x = _mm256_fmadd_ps(diff, diff, acc_8x);
170+
}
171+
172+
let mut acc_4x = _mm256_extractf128_ps(acc_8x, 1); // upper half
173+
let right = _mm256_castps256_ps128(acc_8x); // lower half
174+
acc_4x = _mm_add_ps(acc_4x, right); // sum halves
175+
176+
let lh_4x = _mm_load_ps(lhs.0[DIMENSIONS - 4..].as_ptr());
177+
let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr());
178+
let diff = _mm_sub_ps(lh_4x, rh_4x);
179+
acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);
180+
181+
let lower = _mm_movehl_ps(acc_4x, acc_4x);
182+
acc_4x = _mm_add_ps(acc_4x, lower);
183+
let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1);
184+
acc_4x = _mm_add_ss(acc_4x, upper);
185+
_mm_cvtss_f32(acc_4x)
186+
}
187+
}
188+
#[cfg(not(target_arch = "x86_64"))]
189+
lhs.0
190+
.iter()
191+
.zip(rhs.0.iter())
192+
.map(|(&a, &b)| (a - b).powi(2))
193+
.sum::<f32>()
194+
}
195+
196+
#[repr(align(32))]
197+
pub struct FloatArray(pub [f32; DIMENSIONS]);
198+
199+
const DIMENSIONS: usize = 300;

0 commit comments

Comments
 (0)