
Commit 3877ea4

richardreeve authored and KristofferC committed
Add in Rényi divergences (#49)
* Add in Rényi divergences.
* Add in Renyi divergences to README.md.
* test_dists.jl: p and q are prob distributions, so sum to 1.
* test_dists.jl: Move tests that don't use x and y out of the x,y for loop.
* Fix Rényi entropies to handle generalised probability distributions.
* Test Rényi divergences.
* Lining up |s
* Rewrote eval_op() for RenyiDivergence in function form.
* Corrected test for zero, and return result where computed.
* Removing unicode symbols from arguments.
* No need to define Renyi divergence for numbers rather than vectors.
* Add renyi_divergence() colwise and pairwise tests.
* Add renyi_divergence() tests for empty vectors and NaNs.
* Rearrange tests to create variables near where they are used.
* Document p for RenyiDivergence
* Replace 0 and 1 with correct type.
* Correct type for eval_reduce(::RenyiDivergence) and eval_op(::RenyiDivergence), and turn eval_reduce and eval_end into functions, handling q = Inf correctly and testing the code.
1 parent df02f3d commit 3877ea4

File tree: 4 files changed, +132 −25 lines

README.md · src/Distances.jl · src/metrics.jl · test/test_dists.jl

README.md (2 additions, 0 deletions)

@@ -24,6 +24,7 @@ This package also provides optimized functions to compute column-wise and pairwise
 * Correlation distance
 * Chi-square distance
 * Kullback-Leibler divergence
+* Rényi divergence
 * Jensen-Shannon divergence
 * Mahalanobis distance
 * Squared Mahalanobis distance

@@ -138,6 +139,7 @@ Each distance corresponds to a distance type. The type name and the corresponding
 | CorrDist          | corr_dist(x, y)           | cosine_dist(x - mean(x), y - mean(y)) |
 | ChiSqDist         | chisq_dist(x, y)          | sum((x - y).^2 / (x + y)) |
 | KLDivergence      | kl_divergence(x, y)       | sum(p .* log(p ./ q)) |
+| RenyiDivergence   | renyi_divergence(x, y, k) | log(sum( x .* (x ./ y) .^ (k - 1))) / (k - 1) |
 | JSDivergence      | js_divergence(x, y)       | KL(x, m) / 2 + KL(y, m) / 2 with m = (x + y) / 2 |
 | SpanNormDist      | spannorm_dist(x, y)       | max(x - y) - min(x - y) |
 | BhattacharyyaDist | bhattacharyya(x, y)       | -log(sum(sqrt(x .* y)) / sqrt(sum(x) * sum(y))) |

src/Distances.jl (2 additions, 0 deletions)

@@ -33,6 +33,7 @@ export
     ChiSqDist,
     KLDivergence,
     JSDivergence,
+    RenyiDivergence,
     SpanNormDist,

     WeightedEuclidean,

@@ -61,6 +62,7 @@ export
     chisq_dist,
     kl_divergence,
     js_divergence,
+    renyi_divergence,
     spannorm_dist,

     weuclidean,
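The two new exports are equivalent entry points; renyi_divergence(a, b, q) is defined in src/metrics.jl below as a thin wrapper over the type. A minimal sketch (the vectors are made up for illustration):

    using Distances

    p = [0.3, 0.3, 0.4]
    q = [0.2, 0.5, 0.3]

    # The wrapper and the evaluate() form compute the same thing:
    renyi_divergence(p, q, 0.5) == evaluate(RenyiDivergence(0.5), p, q)   # true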

src/metrics.jl (70 additions, 1 deletion)

@@ -24,12 +24,34 @@ type CorrDist <: SemiMetric end

 type ChiSqDist <: SemiMetric end
 type KLDivergence <: PreMetric end
+
+immutable RenyiDivergence{T <: Real} <: PreMetric
+    p::T # order of power mean (order of divergence - 1)
+    is_normal::Bool
+    is_zero::Bool
+    is_one::Bool
+    is_inf::Bool
+    function RenyiDivergence(q)
+        # There are four different cases:
+        # simpler to separate them out now, not over and over in eval_op()
+        is_zero = q ≈ zero(T)
+        is_one = q ≈ one(T)
+        is_inf = isinf(q)
+
+        # Only positive Rényi divergences are defined
+        !is_zero && q < zero(T) && throw(ArgumentError("Order of Rényi divergence not legal, $(q) < 0."))
+
+        new(q - 1, !(is_zero || is_one || is_inf), is_zero, is_one, is_inf)
+    end
+end
+RenyiDivergence{T}(q::T) = RenyiDivergence{T}(q)
+
 type JSDivergence <: SemiMetric end

 type SpanNormDist <: SemiMetric end

-typealias UnionMetrics Union{Euclidean, SqEuclidean, Chebyshev, Cityblock, Minkowski, Hamming, Jaccard, RogersTanimoto, CosineDist, CorrDist, ChiSqDist, KLDivergence, JSDivergence, SpanNormDist}
+typealias UnionMetrics Union{Euclidean, SqEuclidean, Chebyshev, Cityblock, Minkowski, Hamming, Jaccard, RogersTanimoto, CosineDist, CorrDist, ChiSqDist, KLDivergence, RenyiDivergence, JSDivergence, SpanNormDist}

 ###########################################################
 #
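For reference, the four constructor flags (is_zero, is_one, is_inf, is_normal) mirror the closed forms of the order-q Rényi divergence, stated here for probability vectors p and r; the implementation below additionally divides by the accumulated mass of p to cover generalised, unnormalised distributions:

    \[
    D_q(p \,\|\, r) = \frac{1}{q-1} \log \sum_{i \colon p_i > 0} p_i \Bigl(\frac{p_i}{r_i}\Bigr)^{q-1},
    \qquad q \notin \{0, 1, \infty\},
    \]
    \[
    D_0(p \,\|\, r) = -\log \sum_{i \colon p_i > 0} r_i, \qquad
    D_1(p \,\|\, r) = \sum_{i \colon p_i > 0} p_i \log \frac{p_i}{r_i} \ \text{(Kullback--Leibler)}, \qquad
    D_\infty(p \,\|\, r) = \log \max_{i \colon p_i > 0} \frac{p_i}{r_i}.
    \]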
@@ -141,6 +163,53 @@ chisq_dist(a::AbstractArray, b::AbstractArray) = evaluate(ChiSqDist(), a, b)
 @inline eval_reduce(::KLDivergence, s1, s2) = s1 + s2
 kl_divergence(a::AbstractArray, b::AbstractArray) = evaluate(KLDivergence(), a, b)

+# RenyiDivergence
+function eval_start{T<:AbstractFloat}(::RenyiDivergence, a::AbstractArray{T}, b::AbstractArray{T})
+    zero(T), zero(T)
+end
+
+@inline function eval_op{T<:AbstractFloat}(dist::RenyiDivergence, ai::T, bi::T)
+    if ai == zero(T)
+        return zero(T), zero(T)
+    elseif dist.is_normal
+        return ai, ai .* ((ai ./ bi) .^ dist.p)
+    elseif dist.is_zero
+        return ai, bi
+    elseif dist.is_one
+        return ai, ai * log(ai / bi)
+    else # otherwise q = ∞
+        return ai, ai / bi
+    end
+end
+
+@inline function eval_reduce{T<:AbstractFloat}(dist::RenyiDivergence,
+                                               s1::Tuple{T, T},
+                                               s2::Tuple{T, T})
+    if dist.is_inf
+        if s1[1] == zero(T)
+            return s2
+        elseif s2[1] == zero(T)
+            return s1
+        else
+            return s1[2] > s2[2] ? s1 : s2
+        end
+    else
+        return s1[1] + s2[1], s1[2] + s2[2]
+    end
+end
+
+function eval_end(dist::RenyiDivergence, s)
+    if dist.is_zero || dist.is_normal
+        log(s[2] / s[1]) / dist.p
+    elseif dist.is_one
+        return s[2] / s[1]
+    else # q = ∞
+        log(s[2])
+    end
+end
+
+renyi_divergence(a::AbstractArray, b::AbstractArray, q::Real) = evaluate(RenyiDivergence(q), a, b)
+
 # JSDivergence
 @inline function eval_op{T}(::JSDivergence, ai::T, bi::T)
     u = (ai + bi) / 2
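The additions follow the package's three-stage reduction protocol: eval_start initialises the accumulator pair, eval_op maps each coordinate pair, eval_reduce combines partial results, and eval_end finishes. A self-contained sketch of the same scheme for a finite order q outside {0, 1} — my own restatement with a hypothetical name, not the package internals:

    # Accumulate the pair (Σ aᵢ, Σ aᵢ·(aᵢ/bᵢ)^(q-1)) over nonzero entries of a,
    # then combine at the end; dividing by s1 normalises generalised
    # (unnormalised) distributions, matching eval_end above.
    function renyi_sketch(a::AbstractVector{Float64}, b::AbstractVector{Float64}, q::Real)
        s1 = 0.0
        s2 = 0.0
        for (ai, bi) in zip(a, b)
            ai == 0.0 && continue        # zero-probability entries contribute nothing
            s1 += ai
            s2 += ai * (ai / bi)^(q - 1)
        end
        return log(s2 / s1) / (q - 1)
    end

    renyi_sketch([0.25, 0.25, 0.5], [0.5, 0.25, 0.25], 2.0)  # ≈ renyi_divergence(..., 2.0)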

test/test_dists.jl (58 additions, 24 deletions)

@@ -29,12 +29,6 @@ bf = [false, true, true]
 @test rogerstanimoto(bt, bt) == 0
 @test rogerstanimoto(bt, bf) == 4./5

-
-p = rand(12)
-p[p .< 0.3] = 0.
-q = rand(12)
-a = [1., 2., 1., 3., 2., 1.]
-b = [1., 3., 0., 2., 2., 0.]
 for (x, y) in (([4., 5., 6., 7.], [3., 9., 8., 1.]),
                ([4., 5., 6., 7.], [3. 8.; 9. 1.]))
     @test sqeuclidean(x, x) == 0.

@@ -66,19 +60,6 @@ for (x, y) in (([4., 5., 6., 7.], [3., 9., 8., 1.]),
     @test chisq_dist(x, x) == 0.
     @test chisq_dist(x, y) == sum((x - vec(y)).^2 ./ (x + vec(y)))

-    klv = 0.
-    for i = 1 : length(p)
-        if p[i] > 0
-            klv += p[i] * log(p[i] / q[i])
-        end
-    end
-    @test kl_divergence(p, q) ≈ klv
-
-    pm = (p + q) / 2
-    jsv = kl_divergence(p, pm) / 2 + kl_divergence(q, pm) / 2
-    @test js_divergence(p, p) ≈ 0.0
-    @test js_divergence(p, q) ≈ jsv
-
     @test spannorm_dist(x, x) == 0.
     @test spannorm_dist(x, y) == maximum(x - vec(y)) - minimum(x - vec(y))
@@ -101,17 +82,57 @@ for (x, y) in (([4., 5., 6., 7.], [3., 9., 8., 1.]),

     @test wminkowski(x, x, w, 2) == 0.
     @test wminkowski(x, y, w, 2) ≈ weuclidean(x, y, w)
+end

-    w = rand(size(a))
+# Test weighted Hamming distances with even weights
+a = [1., 2., 1., 3., 2., 1.]
+b = [1., 3., 0., 2., 2., 0.]
+w = rand(size(a))

-    @test whamming(a, a, w) == 0.
-    @test whamming(a, b, w) == sum((a .!= b) .* w)
-end
+@test whamming(a, a, w) == 0.
+@test whamming(a, b, w) == sum((a .!= b) .* w)

+# Minimal test of Jaccard - test return type stability.
 @inferred evaluate(Jaccard(), rand(3), rand(3))
 @inferred evaluate(Jaccard(), [1,2,3], [1,2,3])
 @inferred evaluate(Jaccard(), [true, false, true], [false, true, true])

+# Test KL, Renyi and JS divergences
+p = r = rand(12)
+p[p .< 0.3] = 0.0
+scale = sum(p) / sum(r)
+r /= sum(r)
+p /= sum(p)
+q = rand(12)
+q /= sum(q)
+
+klv = 0.
+for i = 1 : length(p)
+    if p[i] > 0
+        klv += p[i] * log(p[i] / q[i])
+    end
+end
+@test kl_divergence(p, q) ≈ klv
+
+@test renyi_divergence(p, p, 0) ≈ 0
+@test renyi_divergence(p, p, 1) ≈ 0
+@test renyi_divergence(p, p, rand()) ≈ 0
+@test renyi_divergence(p, p, 1.0 + rand()) ≈ 0
+@test renyi_divergence(p, p, Inf) ≈ 0
+@test renyi_divergence(p, r, 0) ≈ -log(scale)
+@test renyi_divergence(p, r, 1) ≈ -log(scale)
+@test renyi_divergence(p, r, rand()) ≈ -log(scale)
+@test renyi_divergence(p, r, Inf) ≈ -log(scale)
+@test isinf(renyi_divergence([0.0, 0.5, 0.5], [0.0, 1.0, 0.0], Inf))
+@test renyi_divergence([0.0, 1.0, 0.0], [0.0, 0.5, 0.5], Inf) ≈ log(2.0)
+@test renyi_divergence(p, q, 1) ≈ kl_divergence(p, q)
+
+pm = (p + q) / 2
+jsv = kl_divergence(p, pm) / 2 + kl_divergence(q, pm) / 2
+@test js_divergence(p, p) ≈ 0.0
+@test js_divergence(p, q) ≈ jsv
+
 end # testset
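A hand check of the two q = Inf cases above (my arithmetic, not part of the commit): the order-∞ divergence is the log of the largest ratio taken over the support of the first argument, so a zero in the second argument inside that support forces an infinite result.

    x = [0.0, 1.0, 0.0]; y = [0.0, 0.5, 0.5]
    sup = x .> 0
    maximum(x[sup] ./ y[sup])   # 2.0, so renyi_divergence(x, y, Inf) ≈ log(2.0)

    # Swapping the arguments puts a zero of y inside the support of x:
    # 0.5 / 0.0 = Inf, hence isinf(renyi_divergence([0.0, 0.5, 0.5], [0.0, 1.0, 0.0], Inf)).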
@@ -121,7 +142,8 @@ a = [NaN, 0]; b = [0, NaN]
 @test isnan(chebyshev(a, b)) == isnan(maximum(a-b))
 a = [NaN, 0]; b = [0, 1]
 @test isnan(chebyshev(a, b)) == isnan(maximum(a-b))
-
+@test !isnan(renyi_divergence([0.5, 0.0, 0.5], [0.5, NaN, 0.5], 2))
+@test isnan(renyi_divergence([0.5, 0.0, 0.5], [0.5, 0.5, NaN], 2))
 end #testset

@@ -141,6 +163,8 @@ b = Float64[]
 @test isa(minkowski(a, b, 2), Float64)
 @test hamming(a, b) == 0.0
 @test isa(hamming(a, b), Int)
+@test renyi_divergence(a, b, 1.0) == 0.0
+@test isa(renyi_divergence(a, b, 2.0), Float64)

 w = Float64[]
 @test isa(whamming(a, b, w), Float64)

@@ -261,6 +285,11 @@ P[P .< 0.3] = 0.

 @test_colwise ChiSqDist() X Y
 @test_colwise KLDivergence() P Q
+@test_colwise RenyiDivergence(0.0) P Q
+@test_colwise RenyiDivergence(1.0) P Q
+@test_colwise RenyiDivergence(Inf) P Q
+@test_colwise RenyiDivergence(0.5) P Q
+@test_colwise RenyiDivergence(2) P Q
 @test_colwise JSDivergence() P Q
 @test_colwise SpanNormDist() X Y

@@ -329,6 +358,11 @@ Q = rand(m, ny)

 @test_pairwise ChiSqDist() X Y
 @test_pairwise KLDivergence() P Q
+@test_pairwise RenyiDivergence(0.0) P Q
+@test_pairwise RenyiDivergence(1.0) P Q
+@test_pairwise RenyiDivergence(Inf) P Q
+@test_pairwise RenyiDivergence(0.5) P Q
+@test_pairwise RenyiDivergence(2) P Q
 @test_pairwise JSDivergence() P Q

 @test_pairwise BhattacharyyaDist() X Y
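Finally, a usage sketch of what the colwise/pairwise macros exercise (the shapes and normalisation here are my own illustration, in the Julia 0.5-era syntax the test file uses):

    using Distances

    P = rand(5, 8)
    Q = rand(5, 8)
    P = P ./ sum(P, 1)   # make each column a probability distribution
    Q = Q ./ sum(Q, 1)

    colwise(RenyiDivergence(2.0), P, Q)    # length-8 vector, one divergence per column pair
    pairwise(RenyiDivergence(2.0), P, Q)   # 8×8 matrix over all column pairs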
