Skip to content

Commit 1f1d180

Browse files
committed
Implement Adjusted Mutual Information
1 parent fb2c187 commit 1f1d180

File tree

6 files changed

+155
-14
lines changed

6 files changed

+155
-14
lines changed

Project.toml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,20 @@ NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
99
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1010
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
12+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1213
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1314
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1415

16+
[weakdeps]
17+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
18+
19+
[extensions]
20+
SpecialFunctionsExt = "SpecialFunctions"
21+
1522
[compat]
1623
Distances = "0.10.9"
1724
NearestNeighbors = "0.4"
25+
SpecialFunctions = ">= 0.8"
1826
Statistics = "1"
1927
StatsBase = "0.25, 0.26, 0.27, 0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34"
2028
julia = "1"
@@ -26,9 +34,10 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
2634
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2735
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2836
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
37+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2938
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
3039
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3140
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3241

3342
[targets]
34-
test = ["CodecZlib", "Statistics", "LinearAlgebra", "SparseArrays", "Distances", "Random", "DelimitedFiles", "StableRNGs", "Test"]
43+
test = ["CodecZlib", "Statistics", "LinearAlgebra", "SparseArrays", "Distances", "Random", "DelimitedFiles", "SpecialFunctions", "StableRNGs", "Test"]

ext/SpecialFunctionsExt.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
module SpecialFunctionsExt # Should be same name as the file (just like a normal package)
2+
3+
using SpecialFunctions: loggamma
4+
using StatsBase: counts
5+
using Statistics: middle
6+
7+
import Clustering: _mutualinfo
8+
9+
# Adjusted mutual information between the assignments `a` and `b`.
#
# `aggregate` selects how the two terms `hc` and `hk` supplied by the generic
# `_mutualinfo(f, A)` kernel are combined into the normalizer:
# `:mean` (arithmetic mean), `:geomean` (geometric mean), `:max` or `:min`.
# Any other symbol raises an `ArgumentError` before the contingency table is built.
function _mutualinfo(::Val{:adjusted}, a, b; aggregate::Symbol = :mean)
    if !(aggregate in (:mean, :geomean, :max, :min))
        throw(ArgumentError("Valid options for `aggregate` are: `mean`, `geomean`, `max`, `min`"))
    end

    return _mutualinfo(counts(a, b)) do hck, hc, hk, rows, cols, N
        observed = hc - hck
        expected = _expectedmutualinfo(rows, cols, N)
        normalizer = if aggregate === :mean
            middle(hc, hk)
        elseif aggregate === :geomean
            sqrt(hc * hk)
        elseif aggregate === :max
            max(hc, hk)
        else
            min(hc, hk)
        end
        # Both the numerator and the normalizer are corrected by the expected value.
        (observed - expected) / (normalizer - expected)
    end
end
30+
31+
# Adjusted Mutual Information
32+
33+
# Expected mutual information term used by the adjusted score.
#
# `a` and `b` hold the marginal counts and `n_samples` the total count. For every
# pair of marginals `(a[i], b[j])` and every feasible cell count `nij` the sum
# accumulates
#
#     (nij / N) * log(N * nij / (a[i] * b[j])) * exp(log_prob)
#
# where `log_prob` is the log-gamma (log-factorial) combination
# `log(a! b! (N-a)! (N-b)! / (N! nij! (a-nij)! (b-nij)! (N-a-b+nij)!))`.
function _expectedmutualinfo(a, b, n_samples)
    nij_range = 1:max(maximum(a), maximum(b))

    # Lookup table indexed by a candidate cell count `nij`.
    weight = nij_range ./ n_samples

    # Logs of the marginals, indexed like `a` and `b`.
    log_a = log.(a)
    log_b = log.(b)

    log_n_nij = log(n_samples) .+ log.(nij_range)

    # Log-factorial tables, computed once so that the inner loop only has to
    # evaluate the three `nij`-dependent log-gamma terms.
    lgam_a = loggamma.(a .+ 1)
    lgam_b = loggamma.(b .+ 1)
    lgam_n_a = loggamma.(n_samples .- a .+ 1)
    lgam_n_b = loggamma.(n_samples .- b .+ 1)
    lgam_n_nij = loggamma.(nij_range .+ 1) .+ loggamma.(n_samples + 1)

    total = 0.0
    for i in eachindex(a)
        ai = a[i]
        for j in eachindex(b)
            bj = b[j]
            log_ai_bj = log_a[i] + log_b[j]
            # Feasible cell counts for this pair of marginals (zero is skipped).
            lowest = max(1, ai - n_samples + bj)
            highest = min(ai, bj)
            for nij in lowest:highest
                log_ratio = log_n_nij[nij] - log_ai_bj
                log_prob = (
                    lgam_a[i] + lgam_b[j] + lgam_n_a[i] + lgam_n_b[j] - lgam_n_nij[nij] -
                    loggamma(ai - nij + 1) - loggamma(bj - nij + 1) -
                    loggamma(n_samples - ai - bj + nij + 1)
                )
                total += (weight[nij] * log_ratio * exp(log_prob))
            end
        end
    end
    return total
end
63+
64+
end # module

src/Clustering.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,8 @@ module Clustering
100100
include("hclust.jl")
101101

102102
include("deprecate.jl")
103+
104+
if !isdefined(Base, :get_extension)
105+
include("../ext/SpecialFunctionsExt.jl")
106+
end
103107
end

src/mutualinfo.jl

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Mutual Information
22

3-
function _mutualinfo(A::AbstractMatrix{<:Integer}, normed::Bool)
3+
@inline function _mutualinfo(f, A::AbstractMatrix{<:Integer})
44
N = sum(A)
55
(N == 0.0) && return 0.0
66

@@ -14,28 +14,73 @@ function _mutualinfo(A::AbstractMatrix{<:Integer}, normed::Bool)
1414
hc = entArows/N + log(N)
1515
hk = entAcols/N + log(N)
1616

17-
mi = hc - hck
18-
return if normed
19-
2*mi/(hc+hk)
17+
f(hck,hc,hk,rows,cols,N)
18+
end
19+
20+
# Fallback for every `Val(method)` / kwargs combination that no more specific
# `_mutualinfo` method accepts. It never returns a value; it only selects the
# error to raise:
# - `:adjusted`: the specialised method is provided by the SpecialFunctions
#   extension, so landing here means that extension is not available. This holds
#   whether or not kwargs such as `aggregate` were passed, hence no kwargs check.
# - any method other than `:classic` / `:normalized`: unknown method name.
# - `:classic` / `:normalized`: those methods take no kwargs, so the call only
#   reaches this fallback when unsupported kwargs were supplied.
function _mutualinfo(::Val{T}, a, b; kwargs...) where T
    if T === :adjusted
        error("Error: mutualinfo(): `method=:adjusted` requires SpecialFunctions package to be loaded")
    elseif T !== :classic && T !== :normalized
        throw(ArgumentError("mutualinfo(): `method=:$(T)` is not supported"))
    else
        throw(ArgumentError("mutualinfo(): unsupported kwargs used. See the `mutualinfo` docstring for more information"))
    end
end
2429

2530
"""
26-
mutualinfo(a, b; normed=true) -> Float64
31+
mutualinfo(a, b; method=:normalized, kwargs...) -> Float64
2732
2833
Compute the *mutual information* between the two clusterings of the same
2934
data points.
3035
3136
`a` and `b` can be either [`ClusteringResult`](@ref) instances or
3237
assignments vectors (`AbstractVector{<:Integer}`).
3338
34-
If `normed` parameter is `true` the return value is the normalized mutual information (symmetric uncertainty),
35-
see "Data Mining Practical Machine Tools and Techniques", Witten & Frank 2005.
39+
`method` can be one of `:classic`, `:normalized` (default), or `:adjusted`, to calculate the
40+
original mutual information score, the normalized mutual information, or the adjusted mutual
41+
information respectively.
42+
43+
When `method=:adjusted`, the `aggregate` kwarg determines how the normalizer
44+
in the denominator is computed. It can be one of:
45+
- `:mean`: The arithmetic mean of two values
46+
- `:geomean`: The geometric mean of two values
47+
- `:max`: The highest of two values
48+
- `:min`: The lowest of two values
3649
3750
# References
3851
> Vinh, Epps, and Bailey, (2009). *Information theoretic measures for clusterings comparison*.
3952
> Proceedings of the 26th Annual International Conference on Machine Learning - ICML ‘09.
53+
54+
> "Data Mining Practical Machine Tools and Techniques", Witten & Frank 2005.
4055
"""
41-
mutualinfo(a, b; normed::Bool=true) = _mutualinfo(counts(a, b), normed)
56+
function mutualinfo(a, b; method::Union{Nothing, Symbol} = nothing,
                    normed::Union{Nothing, Bool} = nothing, kwargs...)
    if isnothing(method)
        # Legacy path: `normed` alone still works but is deprecated.
        # With neither kwarg given the default is `:normalized`.
        if !isnothing(normed)
            Base.depwarn("`normed` kwarg is deprecated, please use `method=:normalized` instead of `normed=true`, and `method=:classic` instead of `normed=false`", :mutualinfo)
        end
        method = (isnothing(normed) || normed) ? :normalized : :classic
    elseif !isnothing(normed)
        # `method` and `normed` must not be combined.
        throw(ArgumentError("`normed` kwarg is not compatible with `method` kwarg"))
    end

    # `:adjusted` called with extra kwargs but without `aggregate` is rerouted
    # to `:classic`, which takes no kwargs, so that the call is reported as using
    # unsupported kwargs.
    if method === :adjusted && !isempty(kwargs) && :aggregate ∉ keys(kwargs)
        method = :classic
    end

    return _mutualinfo(Val(method), a, b; kwargs...)
end
75+
76+
# Normalized mutual information: twice the mutual information `hc - hck`
# divided by `hc + hk`, evaluated on the contingency table of `a` and `b`.
function _mutualinfo(::Val{:normalized}, a, b)
    normalized(hck, hc, hk, _rows, _cols, _total) = 2 * (hc - hck) / (hc + hk)
    return _mutualinfo(normalized, counts(a, b))
end
82+
# Plain (unnormalized) mutual information `hc - hck`, evaluated on the
# contingency table of `a` and `b`.
function _mutualinfo(::Val{:classic}, a, b)
    classic(hck, hc, _hk, _rows, _cols, _total) = hc - hck
    return _mutualinfo(classic, counts(a, b))
end

test/mutualinfo.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,31 @@ using Clustering
66
# https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html
77
a1 = [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3]
88
a2 = [1, 1, 1, 1, 1, 2, 3, 3, 1, 2, 2, 2, 2, 2, 3, 3, 3]
9-
@test mutualinfo(a1, a2, normed=false) ≈ 0.39 atol=1.0e-2
10-
@test mutualinfo(a1, a2) ≈ 0.36 atol=1.0e-2
9+
@test mutualinfo(a1, a2; method=:classic) ≈ 0.39 atol=1.0e-2
10+
@test mutualinfo(a1, a2;) ≈ 0.36 atol=1.0e-2
11+
@test mutualinfo(a1, a2; method=:adjusted) ≈ 0.2602 atol=1.0e-4
12+
@test mutualinfo(a1, a2; method=:adjusted, aggregate=:geomean) ≈ 0.2602 atol=1.0e-4
13+
@test mutualinfo(a1, a2; method=:adjusted, aggregate=:max) ≈ 0.2547 atol=1.0e-4
14+
@test mutualinfo(a1, a2; method=:adjusted, aggregate=:min) ≈ 0.2659 atol=1.0e-4
15+
16+
# test deprecated kwarg
17+
@test mutualinfo(a1, a2; normed=false) ≈ 0.39 atol=1.0e-2
18+
@test mutualinfo(a1, a2; normed=true) ≈ 0.36 atol=1.0e-2
1119

1220
# https://doi.org/10.1186/1471-2105-7-380
1321
a1 = [1, 1, 1, 1, 1, 3, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 1, 2]
1422
a2 = [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4]
15-
@test mutualinfo(a1, a2, normed=false) ≈ 0.6 atol=0.1
16-
@test mutualinfo(a1, a2) ≈ 0.5 atol=0.1
23+
@test mutualinfo(a1, a2; method=:classic) ≈ 0.6 atol=0.1
24+
@test mutualinfo(a1, a2; method=:normalized) ≈ 0.5 atol=0.1
25+
@test mutualinfo(a1, a2; method=:adjusted, aggregate=:mean) ≈ 0.3839 atol=1.0e-4
26+
@test mutualinfo(a1, a2; method=:adjusted, aggregate=:geomean) ≈ 0.3861 atol=1.0e-4
27+
@test mutualinfo(a1, a2; method=:adjusted, aggregate=:max) ≈ 0.3437 atol=1.0e-4
28+
@test mutualinfo(a1, a2; method=:adjusted, aggregate=:min) ≈ 0.4348 atol=1.0e-4
29+
30+
# test errors
31+
@test_throws "ArgumentError: `normed` kwarg is not compatible with `method` kwarg" mutualinfo(a1, a2; method=:adjusfted, normed=false)
32+
@test_throws "ArgumentError: mutualinfo(): `method=:adjusfted` is not supported" mutualinfo(a1, a2; method=:adjusfted, aggregate=:min)
33+
@test_throws "ArgumentError: mutualinfo(): unsupported kwargs used." mutualinfo(a1, a2; method=:adjusted, notaggregate=:min)
34+
@test_throws "ArgumentError: mutualinfo(): unsupported kwargs used." mutualinfo(a1, a2; method=:classic, notaggregate=:min)
1735

1836
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using Test
33
using Random
44
using LinearAlgebra
55
using SparseArrays
6+
using SpecialFunctions
67
using StableRNGs
78
using Statistics
89

0 commit comments

Comments
 (0)