Skip to content

Commit 3ac4525

Browse files
committed
chore: BhattacharyyaDist
1 parent 9510657 commit 3ac4525

File tree

5 files changed

+25
-11
lines changed

5 files changed

+25
-11
lines changed

src/distance.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ pub use crate::generic::metric::{
55
};
66
pub use crate::generic::pre_metric::{GenKLDivergence, KLDivergence};
77
pub use crate::generic::semi_metric::{
8-
BrayCurtis, ChiSqDist, JSDivergence, MeanAbsDeviation, MeanSqDeviation, SqEuclidean,
8+
BhattacharyyaDist, BrayCurtis, ChiSqDist, JSDivergence, MeanAbsDeviation, MeanSqDeviation,
9+
SqEuclidean,
910
};
1011

11-
1212
/// Implement this trait for a distance metric. The trait provides a method to evaluate the distance
1313
/// between two arrays along a specified axis.
1414
pub trait Distance<T>
@@ -159,3 +159,15 @@ impl<T: 'static + num_traits::Float + num_traits::FromPrimitive> Distance<T> for
159159
(x - y).pow2().mean_axis(axis).unwrap().sqrt()
160160
}
161161
}
162+
163+
#[doc(hidden)]
164+
impl<T: 'static + num_traits::Float + num_traits::FromPrimitive + ndarray::ScalarOperand>
165+
Distance<T> for BhattacharyyaDist
166+
{
167+
unsafe fn distance(&self, x: &ArrayD<T>, y: &ArrayD<T>, axis: Axis) -> ArrayD<T> {
168+
let ln_sqrt_sum_x_y = (x * y).sqrt().sum_axis(axis).ln();
169+
let sqrt_sum_x = x.sum_axis(axis).ln() * T::from(0.5).unwrap();
170+
let sqrt_sum_y = y.sum_axis(axis).ln() * T::from(0.5).unwrap();
171+
sqrt_sum_x + sqrt_sum_y - ln_sqrt_sum_x_y
172+
}
173+
}

src/generic.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,23 @@ macro_rules! impl_metric {
1919
};
2020
}
2121

22-
2322
pub mod pre_metric {
2423
impl_metric! {
25-
KLDivergence,
2624
GenKLDivergence,
25+
KLDivergence,
2726
NormRMSDeviation,
2827
}
2928
}
3029
pub mod semi_metric {
3130
impl_metric! {
32-
SqEuclidean,
31+
BhattacharyyaDist,
3332
BrayCurtis,
3433
ChiSqDist,
3534
JSDivergence,
36-
SpanNormDist,
3735
MeanAbsDeviation,
3836
MeanSqDeviation,
37+
SpanNormDist,
38+
SqEuclidean,
3939
}
4040
}
4141
pub mod metric {

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ pub use generic::pre_metric::GenKLDivergence;
3434
/// `KLDivergence(x, y) = sum(x .* ln(x ./ y))`
3535
#[doc(inline)]
3636
pub use generic::pre_metric::KLDivergence;
37+
/// `BhattacharyyaDist(x, y) = -ln(sum(sqrt(x .* y)))+0.5*ln(sum(x))+0.5*ln(sum(y))`
38+
#[doc(inline)]
39+
pub use generic::semi_metric::BhattacharyyaDist;
3740
/// `BrayCurtis(x, y) = sum(|x - y|) / sum(|x + y|)`
3841
#[doc(inline)]
3942
pub use generic::semi_metric::BrayCurtis;

tests/common/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ macro_rules! test_triangular_inequality {
7373
};
7474
}
7575

76-
7776
#[macro_export]
7877
macro_rules! test_only_on_negative_values {
7978
($shape:expr, $axis:expr, $metric:expr, $name:ident) => {
@@ -89,7 +88,6 @@ macro_rules! test_only_on_negative_values {
8988
};
9089
}
9190

92-
9391
#[macro_export]
9492
macro_rules! test_on_negative_and_positive_values {
9593
($shape:expr, $axis:expr, $metric:expr, $name:ident) => {

tests/tests.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
mod common;
22
use faasle::{
3-
BrayCurtis, Chebyshev, ChiSqDist, Cityblock, Distance, Euclidean, GenKLDivergence, Hamming,
4-
JSDivergence, KLDivergence, MeanAbsDeviation, MeanSqDeviation, Minkowski, RMSDeviation,
5-
SqEuclidean, TotalVariation,
3+
BhattacharyyaDist, BrayCurtis, Chebyshev, ChiSqDist, Cityblock, Distance, Euclidean,
4+
GenKLDivergence, Hamming, JSDivergence, KLDivergence, MeanAbsDeviation, MeanSqDeviation,
5+
Minkowski, RMSDeviation, SqEuclidean, TotalVariation,
66
};
77
use ndarray::{Array, Axis};
88
use ndarray_rand::rand_distr::Uniform;
@@ -88,4 +88,5 @@ enumerate_tests! {
8888
semi_metric: (js_divergence, JSDivergence::new()),
8989
semi_metric: (mean_sq_deviation, MeanSqDeviation::new()),
9090
semi_metric: (sq_euclidean, SqEuclidean::new()),
91+
semi_metric: (bhattacharyya_dist, BhattacharyyaDist::new()),
9192
}

0 commit comments

Comments
 (0)