Skip to content

Commit 0257f33

Browse files
Arnav SoodKristofferC
authored andcommitted
Bregman Divergence (#99)
* bregman * squash bugs * add commas * fix typo * fix other typo * fix bregman test function * fix other sqeuclidean call: bregman * fix colwise test * fix colwise test again * move \del * add Fs * this build actually passes * remove faulty pairwise test * foo-bar * add back premetric checks * modernize * docs + coverage * cache size * new tests + type signature * unindent * suppress some unnecessary output * rand fix * Add Bregman to README
1 parent b05f5c8 commit 0257f33

File tree

4 files changed

+84
-2
lines changed

4 files changed

+84
-2
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ This package also provides optimized functions to compute column-wise and pairwi
3737
* Root mean squared deviation
3838
* Normalized root mean squared deviation
3939
* Bray-Curtis dissimilarity
40+
* Bregman divergence
4041

4142
For ``Euclidean distance``, ``Squared Euclidean distance``, ``Cityblock distance``, ``Minkowski distance``, and ``Hamming distance``, a weighted version is also provided.
4243

@@ -163,6 +164,7 @@ Each distance corresponds to a distance type. The type name and the correspondin
163164
| WeightedCityblock | `wcityblock(x, y, w)` | `sum(abs(x - y) .* w)` |
164165
| WeightedMinkowski | `wminkowski(x, y, w, p)` | `sum(abs(x - y).^p .* w) ^ (1/p)` |
165166
| WeightedHamming | `whamming(x, y, w)` | `sum((x .!= y) .* w)` |
167+
| Bregman | `bregman(F, ∇, x, y; inner = LinearAlgebra.dot)` | `F(x) - F(y) - inner(∇(y), x - y)` |
166168

167169
**Note:** The formulas above are using *Julia*'s functions. These formulas are mainly for conveying the math concepts in a concise way. The actual implementation may use a faster way. The arguments `x` and `y` are arrays of real numbers; `k` and `l` are arrays of distinct elements of any kind; `a` and `b` are arrays of Bools; and finally, `p` and `q` are arrays forming a discrete probability distribution and are therefore both expected to sum to one.
168170

src/Distances.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ export
5454
MeanSqDeviation,
5555
RMSDeviation,
5656
NormRMSDeviation,
57+
Bregman,
5758

5859
# convenient functions
5960
euclidean,
@@ -84,6 +85,7 @@ export
8485
mahalanobis,
8586
bhattacharyya,
8687
hellinger,
88+
bregman,
8789

8890
haversine,
8991

@@ -99,5 +101,6 @@ include("wmetrics.jl")
99101
include("haversine.jl")
100102
include("mahalanobis.jl")
101103
include("bhattacharyya.jl")
104+
include("bregman.jl")
102105

103106
end # module end

src/bregman.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Bregman divergence

"""
    Bregman(F, ∇, inner = LinearAlgebra.dot)

Implements the Bregman divergence, a friendly introduction to which can be found
[here](http://mark.reid.name/blog/meet-the-bregman-divergences.html).
Bregman divergences are a minimal implementation of the "mean-minimizer" property.

It is assumed that the (convex differentiable) function `F` maps vectors (of any type or
size) to real numbers, and that `∇` returns the gradient of `F` as a vector of the same
size as its argument.
The inner product used by default is `LinearAlgebra.dot`; a different one can be supplied
as the third argument of the constructor, or through the `inner` keyword argument of the
`bregman` convenience function. If an analytic gradient isn't available, Julia offers a
suite of good automatic differentiation packages.

The divergence between `p` and `q` is computed with

    evaluate(dist::Bregman, p::AbstractVector, q::AbstractVector)

and equals `F(p) - F(q) - inner(∇(q), p - q)`.
"""
struct Bregman{T1 <: Function, T2 <: Function, T3 <: Function} <: PreMetric
    F::T1      # convex, differentiable generator function
    ∇::T2      # gradient of `F`
    inner::T3  # inner product applied to `(∇(q), p - q)`
end
20+
21+
# Default constructor: when no inner product is given, fall back to `LinearAlgebra.dot`.
function Bregman(F, ∇)
    return Bregman(F, ∇, LinearAlgebra.dot)
end
23+
24+
"""
    evaluate(dist::Bregman, p::AbstractVector, q::AbstractVector)

Return the Bregman divergence `F(p) - F(q) - inner(∇(q), p - q)`, where `F`, `∇` and
`inner` are the functions stored in `dist`.

# Throws
- `ArgumentError` if `F(p)` or `F(q)` is not a `Real`.
- `DimensionMismatch` if `p` and `q` differ in size, or if `∇(q)` does not have the same
  size as `p`.
"""
function evaluate(dist::Bregman, p::AbstractVector, q::AbstractVector)
    # Create cache vals: each user-supplied function is called exactly once per argument.
    FP_val = dist.F(p)
    FQ_val = dist.F(q)
    DQ_val = dist.∇(q)
    p_size = size(p)
    # Check F codomain.
    if !(isa(FP_val, Real) && isa(FQ_val, Real))
        throw(ArgumentError("F Codomain Error: F doesn't map the vectors to real numbers"))
    end
    # Check vector size.
    if !(p_size == size(q))
        throw(DimensionMismatch("The vector p ($(size(p))) and q ($(size(q))) are different sizes."))
    end
    # Check gradient size.
    if !(size(DQ_val) == p_size)
        throw(DimensionMismatch("The gradient result is not the same size as p and q"))
    end
    # Return the Bregman divergence.
    return FP_val - FQ_val - dist.inner(DQ_val, p - q)
end
46+
47+
# Convenience wrapper: build a `Bregman` from `F`, `∇` and `inner`, then evaluate it
# on the pair `(x, y)`.
function bregman(F, ∇, x, y; inner = LinearAlgebra.dot)
    dist = Bregman(F, ∇, inner)
    return evaluate(dist, x, y)
end

test/test_dists.jl

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ function test_metricity(dist, x, y, z)
88
if isa(dist, PreMetric)
99
# Unfortunately small non-zero numbers (~10^-16) are appearing
1010
# in our tests due to accumulating floating point rounding errors.
11-
# We either need to allow small errors in our tests or change the
11+
# We either need to allow small errors in our tests or change the
1212
# way we do accumulations...
1313
@test evaluate(dist, x, x) + one(eltype(x)) one(eltype(x))
1414
@test evaluate(dist, y, y) + one(eltype(y)) one(eltype(y))
@@ -59,6 +59,8 @@ end
5959

6060
test_metricity(BhattacharyyaDist(), x, y, z)
6161
test_metricity(HellingerDist(), x, y, z)
62+
test_metricity(Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), x, y, z);
63+
6264

6365
x₁ = rand(T, 2)
6466
x₂ = rand(T, 2)
@@ -276,6 +278,9 @@ end # testset
276278
@test_throws DimensionMismatch colwise!(mat23, Euclidean(), mat23, mat23)
277279
@test_throws DimensionMismatch colwise!(mat23, Euclidean(), mat23, q)
278280
@test_throws DimensionMismatch colwise!(mat23, Euclidean(), mat23, mat22)
281+
@test_throws DimensionMismatch colwise!(mat23, Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), mat23, mat22)
282+
@test_throws DimensionMismatch evaluate(Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), [1, 2, 3], [1, 2])
283+
@test_throws DimensionMismatch evaluate(Bregman(x -> sqeuclidean(x, zero(x)), x -> [1, 2]), [1, 2, 3], [1, 2, 3])
279284
end # testset
280285

281286
@testset "mahalanobis" begin
@@ -382,6 +387,7 @@ end
382387
test_colwise(Chebyshev(), X, Y, T)
383388
test_colwise(Minkowski(2.5), X, Y, T)
384389
test_colwise(Hamming(), A, B, T)
390+
test_colwise(Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), X, Y, T);
385391

386392
test_colwise(CosineDist(), X, Y, T)
387393
test_colwise(CorrDist(), X, Y, T)
@@ -416,7 +422,6 @@ end
416422
test_colwise(Mahalanobis(Q), X, Y, T)
417423
end
418424

419-
420425
function test_pairwise(dist, x, y, T)
421426
@testset "Pairwise test for $(typeof(dist))" begin
422427
nx = size(x, 2)
@@ -472,6 +477,7 @@ end
472477
test_pairwise(BhattacharyyaDist(), X, Y, T)
473478
test_pairwise(HellingerDist(), X, Y, T)
474479
test_pairwise(BrayCurtis(), X, Y, T)
480+
test_pairwise(Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), X, Y, T)
475481

476482
w = rand(m)
477483

@@ -503,3 +509,26 @@ end
503509
@test pd[1, 1] == 0
504510
@test pd[2, 2] == 0
505511
end
512+
513+
@testset "Bregman Divergence" begin
    # Some basic tests: an `F` that returns a vector instead of a real number is rejected.
    @test_throws ArgumentError bregman(x -> x, x -> 2*x, [1, 2, 3], [1, 2, 3])

    # Test if Bregman() correctly implements the gkl divergence between two random vectors.
    # Generator: negative entropy; its gradient is `log(x) + 1` elementwise.
    F_gkl(p) = LinearAlgebra.dot(p, log.(p))
    ∇_gkl(p) = map(x -> log(x) + 1, p)
    testDist = Bregman(F_gkl, ∇_gkl)
    p = rand(4)
    q = rand(4)
    # Normalize so that both vectors sum to one.
    p = p/sum(p)
    q = q/sum(q)
    @test evaluate(testDist, p, q) ≈ gkl_divergence(p, q)

    # Test if Bregman() correctly implements the squared euclidean dist. between them.
    @test bregman(x -> norm(x)^2, x -> 2*x, p, q) ≈ sqeuclidean(p, q)

    # Test if Bregman() correctly implements the IS distance.
    # Distinct names are used so the generator/gradient pair above is not redefined
    # within the same scope.
    F_is(p) = -1 * sum(log.(p))
    ∇_is(p) = map(x -> -1 * x^(-1), p)
    function ISdist(p::AbstractVector, q::AbstractVector)
        return sum(p[i]/q[i] - log(p[i]/q[i]) - 1 for i in eachindex(p, q))
    end
    @test bregman(F_is, ∇_is, p, q) ≈ ISdist(p, q)
end

0 commit comments

Comments
 (0)