-
-
Notifications
You must be signed in to change notification settings - Fork 30
Implement distance metric selection #35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
[workspace] | ||
members = ["instant-distance", "instant-distance-py"] | ||
members = ["distance-metrics", "instant-distance", "instant-distance-py"] | ||
|
||
[profile.bench] | ||
debug = true |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,15 @@ | ||
test-python: | ||
cargo build --release | ||
cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so | ||
instant-distance-py/test/instant_distance.so: instant-distance-py/src/lib.rs | ||
RUSTFLAGS="-C target-cpu=native" cargo build --release | ||
([ -f target/release/libinstant_distance.dylib ] && cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so) || \ | ||
([ -f target/release/libinstant_distance.so ] && cp target/release/libinstant_distance.so instant-distance-py/test/instant_distance.so) | ||
|
||
test-python: instant-distance-py/test/instant_distance.so | ||
PYTHONPATH=instant-distance-py/test/ python3 -m test | ||
|
||
bench-python: instant-distance-py/test/instant_distance.so | ||
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config()' 'instant_distance.Hnsw.build(points, config)' | ||
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config(); config.distance_metric = instant_distance.DistanceMetric.Cosine' 'instant_distance.Hnsw.build(points, config)' | ||
|
||
clean: | ||
cargo clean | ||
rm -f instant-distance-py/test/instant_distance.so |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
[package] | ||
name = "distance-metrics" | ||
version = "0.6.0" | ||
license = "MIT OR Apache-2.0" | ||
edition = "2021" | ||
rust-version = "1.58" | ||
homepage = "https://github.com/InstantDomain/instant-distance" | ||
repository = "https://github.com/InstantDomain/instant-distance" | ||
documentation = "https://docs.rs/instant-distance" | ||
workspace = ".." | ||
readme = "../README.md" | ||
|
||
[dependencies] | ||
|
||
[dev-dependencies] | ||
bencher = "0.1.5" | ||
rand = { version = "0.8", features = ["small_rng"] } | ||
|
||
[[bench]] | ||
name = "all" | ||
harness = false |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
use bencher::{benchmark_group, benchmark_main, Bencher}; | ||
|
||
use distance_metrics::{ | ||
Metric, {CosineMetric, EuclidMetric}, | ||
}; | ||
use rand::{rngs::StdRng, Rng, SeedableRng}; | ||
|
||
benchmark_main!(benches); | ||
benchmark_group!( | ||
benches, | ||
legacy, | ||
non_simd, | ||
metric::<EuclidMetric>, | ||
metric::<CosineMetric> | ||
); | ||
|
||
fn legacy(bench: &mut Bencher) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would be curious to see some benchmarking results in this PR for the legacy approach vs your new implementations! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The most recent results:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So I think Did you run any tests to show that the platform-specific implementations have the same result (or close to it) as the non-SIMD implementation? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, the
(the numbers are higher because I'm on battery and the CPU is throttled) We could use that since it doesn't have dimensions limitation. Another idea would be to have the current metric (as for example |
||
let mut rng = StdRng::seed_from_u64(SEED); | ||
let point_a = distance_metrics::FloatArray([rng.gen(); 300]); | ||
let point_b = distance_metrics::FloatArray([rng.gen(); 300]); | ||
|
||
bench.iter(|| distance_metrics::legacy_distance(&point_a, &point_b)) | ||
} | ||
|
||
fn non_simd(bench: &mut Bencher) { | ||
let mut rng = StdRng::seed_from_u64(SEED); | ||
let point_a = [rng.gen(); 300]; | ||
let point_b = [rng.gen(); 300]; | ||
|
||
bench.iter(|| distance_metrics::euclid_distance(&point_a, &point_b)) | ||
} | ||
|
||
fn metric<M: Metric>(bench: &mut Bencher) { | ||
let mut rng = StdRng::seed_from_u64(SEED); | ||
let mut point_a = [rng.gen(); 300]; | ||
let mut point_b = [rng.gen(); 300]; | ||
M::preprocess(&mut point_a); | ||
M::preprocess(&mut point_b); | ||
|
||
bench.iter(|| M::distance(&point_a, &point_b)) | ||
} | ||
|
||
const SEED: u64 = 123456789; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] | ||
pub mod simd_sse; | ||
|
||
#[cfg(target_arch = "x86_64")] | ||
pub mod simd_avx; | ||
|
||
#[cfg(target_arch = "aarch64")] | ||
pub mod simd_neon; | ||
|
||
/// Defines how to compare vectors | ||
pub trait Metric { | ||
/// Greater the value - more distant the vectors | ||
fn distance(v1: &[f32], v2: &[f32]) -> f32; | ||
|
||
/// Necessary vector transformations performed before adding it to the collection (like normalization) | ||
fn preprocess(vector: &mut [f32]); | ||
} | ||
|
||
#[cfg(target_arch = "x86_64")] | ||
const MIN_DIM_SIZE_AVX: usize = 32; | ||
|
||
#[cfg(any( | ||
target_arch = "x86", | ||
target_arch = "x86_64", | ||
all(target_arch = "aarch64", target_feature = "neon") | ||
))] | ||
const MIN_DIM_SIZE_SIMD: usize = 16; | ||
|
||
#[derive(Clone, Copy)] | ||
pub struct EuclidMetric {} | ||
|
||
impl Metric for EuclidMetric { | ||
fn distance(v1: &[f32], v2: &[f32]) -> f32 { | ||
#[cfg(target_arch = "x86_64")] | ||
{ | ||
if is_x86_feature_detected!("avx") | ||
&& is_x86_feature_detected!("fma") | ||
&& v1.len() >= MIN_DIM_SIZE_AVX | ||
{ | ||
return unsafe { simd_avx::euclid_distance_avx(v1, v2) }; | ||
} | ||
} | ||
|
||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] | ||
{ | ||
if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD { | ||
return unsafe { simd_sse::euclid_distance_sse(v1, v2) }; | ||
} | ||
} | ||
|
||
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] | ||
{ | ||
if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD { | ||
return unsafe { simple_neon::euclid_distance_neon(v1, v2) }; | ||
} | ||
} | ||
|
||
euclid_distance(v1, v2) | ||
} | ||
|
||
fn preprocess(_vector: &mut [f32]) { | ||
// no-op | ||
} | ||
} | ||
|
||
#[derive(Clone, Copy)] | ||
pub struct CosineMetric {} | ||
|
||
impl Metric for CosineMetric { | ||
fn distance(v1: &[f32], v2: &[f32]) -> f32 { | ||
#[cfg(target_arch = "x86_64")] | ||
{ | ||
if is_x86_feature_detected!("avx") | ||
&& is_x86_feature_detected!("fma") | ||
&& v1.len() >= MIN_DIM_SIZE_AVX | ||
{ | ||
return 1.0 - unsafe { simd_avx::dot_similarity_avx(v1, v2) }; | ||
} | ||
} | ||
|
||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] | ||
{ | ||
if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD { | ||
return 1.0 - unsafe { simd_sse::dot_similarity_sse(v1, v2) }; | ||
} | ||
} | ||
|
||
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] | ||
{ | ||
if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD { | ||
return 1.0 - unsafe { simd_neon::dot_similarity_neon(v1, v2) }; | ||
} | ||
} | ||
|
||
1.0 - dot_similarity(v1, v2) | ||
} | ||
|
||
fn preprocess(vector: &mut [f32]) { | ||
#[cfg(target_arch = "x86_64")] | ||
{ | ||
if is_x86_feature_detected!("avx") | ||
&& is_x86_feature_detected!("fma") | ||
&& vector.len() >= MIN_DIM_SIZE_AVX | ||
{ | ||
return unsafe { simd_avx::cosine_preprocess_avx(vector) }; | ||
} | ||
} | ||
|
||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] | ||
{ | ||
if is_x86_feature_detected!("sse") && vector.len() >= MIN_DIM_SIZE_SIMD { | ||
return unsafe { simd_sse::cosine_preprocess_sse(vector) }; | ||
} | ||
} | ||
|
||
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] | ||
{ | ||
if std::arch::is_aarch64_feature_detected!("neon") && vector.len() >= MIN_DIM_SIZE_SIMD | ||
{ | ||
return unsafe { simd_neon::cosine_preprocess_neon(vector) }; | ||
} | ||
} | ||
|
||
cosine_preprocess(vector); | ||
} | ||
} | ||
|
||
pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 { | ||
let s: f32 = v1 | ||
.iter() | ||
.copied() | ||
.zip(v2.iter().copied()) | ||
.map(|(a, b)| (a - b).powi(2)) | ||
.sum(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Style nit: I usually prefer turbofishing over type annotations, so |
||
s.abs().sqrt() | ||
} | ||
Comment on lines
+128
to
+136
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be a |
||
|
||
pub fn cosine_preprocess(vector: &mut [f32]) { | ||
let mut length: f32 = vector.iter().map(|x| x * x).sum(); | ||
if length < f32::EPSILON { | ||
return; | ||
} | ||
length = length.sqrt(); | ||
for x in vector.iter_mut() { | ||
*x /= length; | ||
} | ||
} | ||
|
||
pub fn dot_similarity(v1: &[f32], v2: &[f32]) -> f32 { | ||
v1.iter().zip(v2).map(|(a, b)| a * b).sum() | ||
} | ||
|
||
pub fn legacy_distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 { | ||
#[cfg(target_arch = "x86_64")] | ||
{ | ||
use std::arch::x86_64::{ | ||
_mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps, | ||
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_fmadd_ps, | ||
_mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps, | ||
}; | ||
debug_assert_eq!(lhs.0.len() % 8, 4); | ||
|
||
unsafe { | ||
let mut acc_8x = _mm256_setzero_ps(); | ||
for (lh_slice, rh_slice) in lhs.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) { | ||
let lh_8x = _mm256_load_ps(lh_slice.as_ptr()); | ||
let rh_8x = _mm256_load_ps(rh_slice.as_ptr()); | ||
let diff = _mm256_sub_ps(lh_8x, rh_8x); | ||
acc_8x = _mm256_fmadd_ps(diff, diff, acc_8x); | ||
} | ||
|
||
let mut acc_4x = _mm256_extractf128_ps(acc_8x, 1); // upper half | ||
let right = _mm256_castps256_ps128(acc_8x); // lower half | ||
acc_4x = _mm_add_ps(acc_4x, right); // sum halves | ||
|
||
let lh_4x = _mm_load_ps(lhs.0[DIMENSIONS - 4..].as_ptr()); | ||
let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr()); | ||
let diff = _mm_sub_ps(lh_4x, rh_4x); | ||
acc_4x = _mm_fmadd_ps(diff, diff, acc_4x); | ||
|
||
let lower = _mm_movehl_ps(acc_4x, acc_4x); | ||
acc_4x = _mm_add_ps(acc_4x, lower); | ||
let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1); | ||
acc_4x = _mm_add_ss(acc_4x, upper); | ||
_mm_cvtss_f32(acc_4x) | ||
} | ||
} | ||
#[cfg(not(target_arch = "x86_64"))] | ||
lhs.0 | ||
.iter() | ||
.zip(rhs.0.iter()) | ||
.map(|(&a, &b)| (a - b).powi(2)) | ||
.sum::<f32>() | ||
} | ||
|
||
#[repr(align(32))] | ||
pub struct FloatArray(pub [f32; DIMENSIONS]); | ||
|
||
const DIMENSIONS: usize = 300; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious what your motivation for adding a new crate for this is?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Separate mostly because I couldn't get benchmark to access
instant-distance-py
internals and the metrics felt to specific forinstant-distance
crate.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I gave this another try. There are 2 things that prevent me from adding benches directly to
instant-distance-py
:cdylib
- that one is easy we could have 2 types (cdylib
and regularlib
),instance-distance
, just like the other crate - AFAIK it makes it impossible to use both crates at the same time.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, let's just change the crate name to instant-distance-py in the Rust metadata (but should verify that the Python name will still be instant-distance), and then if we add the extra type we wouldn't need the extra crate anymore, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be possible if we add a tiny python wrapper module. That's to force maturin to build wheel for mixed rust/python project. Without the wrapper maturin uses lib.name as top level module name.