Skip to content

Commit bc17903

Browse files
authored
Add distances, divergences, and deviations from StatsBase (#58)
1 parent dd59e65 commit bc17903

File tree

4 files changed

+62
-2
lines changed

4 files changed

+62
-2
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,17 @@ This package also provides optimized functions to compute column-wise and pairwi
2323
* Correlation distance
2424
* Chi-square distance
2525
* Kullback-Leibler divergence
26+
* Generalized Kullback-Leibler divergence
2627
* Rényi divergence
2728
* Jensen-Shannon divergence
2829
* Mahalanobis distance
2930
* Squared Mahalanobis distance
3031
* Bhattacharyya distance
3132
* Hellinger distance
33+
* Mean absolute deviation
34+
* Mean squared deviation
35+
* Root mean squared deviation
36+
* Normalized root mean squared deviation
3237

3338
For ``Euclidean distance``, ``Squared Euclidean distance``, ``Cityblock distance``, ``Minkowski distance``, and ``Hamming distance``, a weighted version is also provided.
3439

@@ -138,13 +143,18 @@ Each distance corresponds to a distance type. The type name and the correspondin
138143
| CorrDist | `corr_dist(x, y)` | `cosine_dist(x - mean(x), y - mean(y))` |
139144
| ChiSqDist | `chisq_dist(x, y)` | `sum((x - y).^2 ./ (x + y))` |
140145
| KLDivergence | `kl_divergence(p, q)` | `sum(p .* log(p ./ q))` |
146+
| GenKLDivergence | `gkl_divergence(p, q)` | `sum(p .* log(p ./ q) - p + q)` |
141147
| RenyiDivergence | `renyi_divergence(p, q, k)`| `log(sum( p .* (p ./ q) .^ (k - 1))) / (k - 1)` |
142148
| JSDivergence | `js_divergence(p, q)` | `KL(p, m) / 2 + KL(q, m) / 2 with m = (p + q) / 2` |
143149
| SpanNormDist | `spannorm_dist(x, y)` | `maximum(x - y) - minimum(x - y)` |
144150
| BhattacharyyaDist | `bhattacharyya(x, y)` | `-log(sum(sqrt(x .* y) / sqrt(sum(x) * sum(y))))` |
145151
| HellingerDist | `hellinger(x, y) ` | `sqrt(1 - sum(sqrt(x .* y) / sqrt(sum(x) * sum(y))))` |
146152
| Mahalanobis | `mahalanobis(x, y, Q)` | `sqrt((x - y)' * Q * (x - y))` |
147153
| SqMahalanobis | `sqmahalanobis(x, y, Q)` | ` (x - y)' * Q * (x - y)` |
154+
| MeanAbsDeviation | `meanad(x, y)` | `mean(abs.(x - y))` |
155+
| MeanSqDeviation | `msd(x, y)` | `mean(abs2.(x - y))` |
156+
| RMSDeviation | `rmsd(x, y)` | `sqrt(msd(x, y))` |
157+
| NormRMSDeviation | `nrmsd(x, y)` | `rmsd(x, y) / (maximum(x) - minimum(x))` |
148158
| WeightedEuclidean | `weuclidean(x, y, w)` | `sqrt(sum((x - y).^2 .* w))` |
149159
| WeightedSqEuclidean | `wsqeuclidean(x, y, w)` | `sum((x - y).^2 .* w)` |
150160
| WeightedCityblock | `wcityblock(x, y, w)` | `sum(abs.(x - y) .* w)` |

src/Distances.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ export
3030
CorrDist,
3131
ChiSqDist,
3232
KLDivergence,
33+
GenKLDivergence,
3334
JSDivergence,
3435
RenyiDivergence,
3536
SpanNormDist,
@@ -44,6 +45,11 @@ export
4445
BhattacharyyaDist,
4546
HellingerDist,
4647

48+
MeanAbsDeviation,
49+
MeanSqDeviation,
50+
RMSDeviation,
51+
NormRMSDeviation,
52+
4753
# convenient functions
4854
euclidean,
4955
sqeuclidean,
@@ -59,6 +65,7 @@ export
5965
corr_dist,
6066
chisq_dist,
6167
kl_divergence,
68+
gkl_divergence,
6269
js_divergence,
6370
renyi_divergence,
6471
spannorm_dist,
@@ -71,7 +78,12 @@ export
7178
sqmahalanobis,
7279
mahalanobis,
7380
bhattacharyya,
74-
hellinger
81+
hellinger,
82+
83+
meanad,
84+
msd,
85+
rmsd,
86+
nrmsd
7587

7688
include("common.jl")
7789
include("generic.jl")

src/metrics.jl

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ struct CorrDist <: SemiMetric end
2828

2929
# Chi-square distance tag type; it has no fields and is declared a `SemiMetric`.
struct ChiSqDist <: SemiMetric end
3030
# Kullback-Leibler divergence tag type; it has no fields and is declared a `PreMetric`.
struct KLDivergence <: PreMetric end
31+
# Generalized Kullback-Leibler divergence tag type; it has no fields and is
# declared a `PreMetric`.
struct GenKLDivergence <: PreMetric end
3133

3234
"""
3335
RenyiDivergence(α::Real)
@@ -89,8 +91,15 @@ struct JSDivergence <: SemiMetric end
8991

9092
# Span-norm distance tag type; it has no fields and is declared a `SemiMetric`.
struct SpanNormDist <: SemiMetric end
9193

94+
# Deviations are handled separately from the other distances/divergences and
# are excluded from `UnionMetrics`
struct MeanAbsDeviation <: Metric end
struct MeanSqDeviation <: SemiMetric end
struct RMSDeviation <: Metric end
# The normalisation divides by the range of the *first* argument only, so
# swapping the arguments generally changes the result. The type is therefore
# declared a `PreMetric` rather than a (symmetric) `Metric`.
struct NormRMSDeviation <: PreMetric end
100+
92101

93-
const UnionMetrics = Union{Euclidean, SqEuclidean, Chebyshev, Cityblock, Minkowski, Hamming, Jaccard, RogersTanimoto, CosineDist, CorrDist, ChiSqDist, KLDivergence, RenyiDivergence, JSDivergence, SpanNormDist}
102+
# Union of the distance/divergence tag types listed below. The deviation types
# are not members of this union.
const UnionMetrics = Union{
    Euclidean,
    SqEuclidean,
    Chebyshev,
    Cityblock,
    Minkowski,
    Hamming,
    Jaccard,
    RogersTanimoto,
    CosineDist,
    CorrDist,
    ChiSqDist,
    KLDivergence,
    RenyiDivergence,
    JSDivergence,
    SpanNormDist,
    GenKLDivergence,
}
94103

95104
"""
96105
Euclidean([thresh])
@@ -240,6 +249,11 @@ chisq_dist(a::AbstractArray, b::AbstractArray) = evaluate(ChiSqDist(), a, b)
240249
# Combine two partial results for `KLDivergence` by adding them.
@inline function eval_reduce(::KLDivergence, s1, s2)
    return s1 + s2
end

# Convenience wrapper: evaluate `KLDivergence()` on the arrays `a` and `b`.
function kl_divergence(a::AbstractArray, b::AbstractArray)
    return evaluate(KLDivergence(), a, b)
end
242251

252+
# GenKLDivergence

# Per-element term: `ai * log(ai / bi) - ai + bi` when `ai > 0`, otherwise
# just `bi` (the guard avoids evaluating the logarithm for non-positive `ai`).
@inline function eval_op(::GenKLDivergence, ai, bi)
    ai > 0 || return bi
    return ai * log(ai / bi) - ai + bi
end

# Combine two partial results by adding them.
@inline function eval_reduce(::GenKLDivergence, s1, s2)
    return s1 + s2
end

# Convenience wrapper: evaluate `GenKLDivergence()` on the arrays `a` and `b`.
function gkl_divergence(a::AbstractArray, b::AbstractArray)
    return evaluate(GenKLDivergence(), a, b)
end
256+
243257
# RenyiDivergence
244258
function eval_start(::RenyiDivergence, a::AbstractArray{T}, b::AbstractArray{T}) where {T <: AbstractFloat}
245259
zero(T), zero(T), sum(a), sum(b)
@@ -369,6 +383,23 @@ end
369383
end
370384
# Convenience wrapper: evaluate `RogersTanimoto()` on two `Bool` arrays.
function rogerstanimoto(a::AbstractArray{T}, b::AbstractArray{T}) where {T <: Bool}
    return evaluate(RogersTanimoto(), a, b)
end
371385

386+
# Deviations
387+
388+
# Mean absolute deviation: the cityblock distance between `a` and `b`
# divided by the number of elements of `a`.
function evaluate(::MeanAbsDeviation, a, b)
    total = cityblock(a, b)
    return total / length(a)
end

# Convenience wrapper around `evaluate(MeanAbsDeviation(), a, b)`.
function meanad(a, b)
    return evaluate(MeanAbsDeviation(), a, b)
end
390+
391+
# Mean squared deviation: the squared Euclidean distance between `a` and `b`
# divided by the number of elements of `a`.
function evaluate(::MeanSqDeviation, a, b)
    total = sqeuclidean(a, b)
    return total / length(a)
end

# Convenience wrapper around `evaluate(MeanSqDeviation(), a, b)`.
function msd(a, b)
    return evaluate(MeanSqDeviation(), a, b)
end
393+
394+
# Root mean squared deviation: the square root of the mean squared deviation.
function evaluate(::RMSDeviation, a, b)
    meansq = evaluate(MeanSqDeviation(), a, b)
    return sqrt(meansq)
end

# Convenience wrapper around `evaluate(RMSDeviation(), a, b)`.
function rmsd(a, b)
    return evaluate(RMSDeviation(), a, b)
end
396+
397+
# Normalized root mean squared deviation: the RMS deviation of `a` and `b`
# divided by the range (maximum minus minimum) of `a`.
function evaluate(::NormRMSDeviation, a, b)
    lo, hi = extrema(a)
    spread = hi - lo
    return evaluate(RMSDeviation(), a, b) / spread
end

# Convenience wrapper around `evaluate(NormRMSDeviation(), a, b)`.
function nrmsd(a, b)
    return evaluate(NormRMSDeviation(), a, b)
end
402+
372403

373404
###########################################################
374405
#

test/test_dists.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,13 @@ end
134134
@test chisq_dist(x, y) == sum((x - vec(y)).^2 ./ (x + vec(y)))
135135
@test spannorm_dist(x, y) == maximum(x - vec(y)) - minimum(x - vec(y))
136136

137+
# Compare against the element-wise formula; `≈` because the result is floating point.
@test gkl_divergence(x, y) ≈ sum(i -> x[i] * log(x[i] / y[i]) - x[i] + y[i], 1:length(x))
138+
139+
# Mean absolute deviation against a direct element-wise computation.
@test meanad(x, y) ≈ mean(Float64[abs(x[i] - y[i]) for i in 1:length(x)])
140+
# Mean squared deviation against a direct element-wise computation.
@test msd(x, y) ≈ mean(Float64[abs2(x[i] - y[i]) for i in 1:length(x)])
141+
# RMS deviation is the square root of the mean squared deviation.
@test rmsd(x, y) ≈ sqrt(msd(x, y))
142+
# Normalized RMS deviation divides by the range of the first argument.
@test nrmsd(x, y) ≈ sqrt(msd(x, y)) / (maximum(x) - minimum(x))
143+
137144
w = ones(4)
138145
# With unit weights the weighted squared Euclidean distance matches the unweighted one.
@test sqeuclidean(x, y) ≈ wsqeuclidean(x, y, w)
139146

0 commit comments

Comments
 (0)