Skip to content

Commit bfdd686

Browse files
authored
Merge pull request #243 from JuliaStats/ast/fix_rnadindex
randindex(): fix int overflow
2 parents 9903ccc + 2d5209b commit bfdd686

File tree

5 files changed

+54
-24
lines changed

5 files changed

+54
-24
lines changed

src/confusion.jl

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
"""
2-
confusion(a::Union{ClusteringResult, AbstractVector},
3-
b::Union{ClusteringResult, AbstractVector}) -> Matrix{Int}
2+
confusion([T = Int],
3+
a::Union{ClusteringResult, AbstractVector},
4+
b::Union{ClusteringResult, AbstractVector}) -> Matrix{T}
45
5-
Return 2×2 confusion matrix `C` that represents partition co-occurrence or
6-
similarity matrix between two clusterings by considering all pairs of samples
7-
and counting pairs that are assigned into the same or into different clusters
8-
under the true and predicted clusterings.
6+
Calculate the *confusion matrix* of the two clusterings.
7+
8+
Returns the 2×2 confusion matrix `C` of type `T` (`Int` by default) that
9+
represents partition co-occurrence or similarity matrix between two clusterings
10+
`a` and `b` by considering all pairs of samples and counting pairs that are
11+
assigned into the same or into different clusters.
912
1013
Considering a pair of samples that is in the same group as a **positive pair**,
1114
and a pair is in the different group as a **negative pair**, then the count of
@@ -17,23 +20,27 @@ true negatives is `C₂₂`:
1720
|Positive|C₁₁|C₁₂|
1821
|Negative|C₂₁|C₂₂|
1922
"""
20-
function confusion(a::AbstractVector{<:Integer}, b::AbstractVector{<:Integer})
21-
c = counts(a, b)
23+
function confusion(::Type{T}, a::AbstractVector{<:Integer}, b::AbstractVector{<:Integer}) where T<:Union{Integer, AbstractFloat}
24+
cc = counts(a, b)
25+
c = eltype(cc) === T ? cc : convert(Matrix{T}, cc)
2226

2327
n = sum(c)
24-
nis = sum(abs2, sum(c, dims=2)) # sum of squares of sums of rows
25-
njs = sum(abs2, sum(c, dims=1)) # sum of squares of sums of columns
28+
nis = sum(abs2, sum!(zeros(T, (size(c, 1), 1)), c))
29+
(nis < 0) && OverflowError("sum of squares of sums of rows overflowed")
30+
njs = sum(abs2, sum!(zeros(T, (1, size(c, 2))), c))
31+
(njs < 0) && OverflowError("sum of squares of sums of columns overflowed")
2632

27-
t2 = sum(abs2, c) # sum over rows & columns of nij^2
33+
t2 = sum(abs2, c)
34+
(t2 < 0) && OverflowError("sum of squares of matrix elements overflowed")
2835
t3 = nis + njs
2936
C = [(t2 - n)÷2 (nis - t2)÷2; (njs - t2)÷2 (t2 + n^2 - t3)÷2]
3037
return C
3138
end
3239

33-
confusion(a::ClusteringResult, b::ClusteringResult) =
34-
confusion(assignments(a), assignments(b))
35-
confusion(a::AbstractVector{<:Integer}, b::ClusteringResult) =
36-
confusion(a, assignments(b))
37-
confusion(a::ClusteringResult, b::AbstractVector{<:Integer}) =
38-
confusion(assignments(a), b)
40+
confusion(T, a::ClusteringResultOrAssignments,
41+
b::ClusteringResultOrAssignments) =
42+
confusion(T, assignments(a), assignments(b))
3943

44+
confusion(a::ClusteringResultOrAssignments,
45+
b::ClusteringResultOrAssignments) =
46+
confusion(Int, a, b)

src/randindex.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ Returns a tuple of indices:
2323
> Rand Index.* Psychological Methods, Vol. 9, No. 3: 386-396
2424
"""
2525
function randindex(a, b)
26-
c11, c21, c12, c22 = confusion(a, b) # Table 2 from Steinley 2004
26+
c11, c21, c12, c22 = confusion(Float64, a, b) # Table 2 from Steinley 2004
2727

2828
t = c11 + c12 + c21 + c22 # total number of pairs of entities
2929
A = c11 + c22
@@ -32,7 +32,7 @@ function randindex(a, b)
3232
# expected index
3333
ERI = (c11+c12)*(c11+c21)+(c21+c22)*(c12+c22)
3434
# adjusted Rand - Hubert & Arabie 1985
35-
ARI = D == 0 ? 1.0 : (t*A-ERI)/(t*t-ERI) # (9) from Steinley 2004
35+
ARI = D == 0 ? 1.0 : (t*A - ERI)/(abs2(t) - ERI) # (9) from Steinley 2004
3636

3737
RI = A/t # Rand 1971 # Probability of agreement
3838
MI = D/t # Mirkin 1970 # p(disagreement)

src/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ Base type for the output of clustering algorithm.
99
"""
1010
abstract type ClusteringResult end
1111

12+
# vector of cluster indices for each clustered point
13+
ClusterAssignments = AbstractVector{<:Integer}
14+
15+
ClusteringResultOrAssignments = Union{ClusteringResult, ClusterAssignments}
16+
1217
# generic functions
1318

1419
"""
@@ -48,6 +53,7 @@ Get the vector of cluster indices for each point.
4853
is assigned.
4954
"""
5055
assignments(R::ClusteringResult) = R.assignments
56+
assignments(A::ClusterAssignments) = A
5157

5258

5359
##### convert display symbol to disp level

test/confusion.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,27 @@ using Clustering
1010
@test confusion([0,0,1], [0,0,0]) == [1 0; 2 0]
1111
@test confusion([0,1,1], [0,0,0]) == [1 0; 2 0]
1212
@test confusion([1,1,1], [0,0,0]) == [3 0; 0 0]
13-
13+
1414
@test confusion([0,0,0], [0,0,1]) == [1 2; 0 0]
1515
@test confusion([0,0,1], [0,0,1]) == [1 0; 0 2]
1616
@test confusion([0,1,1], [0,0,1]) == [0 1; 1 1]
1717
@test confusion([1,1,1], [0,0,1]) == [1 2; 0 0]
18-
18+
1919
@test confusion([0,0,0], [0,1,1]) == [1 2; 0 0]
2020
@test confusion([0,0,1], [0,1,1]) == [0 1; 1 1]
2121
@test confusion([0,1,1], [0,1,1]) == [1 0; 0 2]
2222
@test confusion([1,1,1], [0,1,1]) == [1 2; 0 0]
23-
23+
2424
@test confusion([0,0,0], [1,1,1]) == [3 0; 0 0]
2525
@test confusion([0,0,1], [1,1,1]) == [1 0; 2 0]
2626
@test confusion([0,1,1], [1,1,1]) == [1 0; 2 0]
2727
@test confusion([1,1,1], [1,1,1]) == [3 0; 0 0]
28+
29+
end
30+
31+
@testset "specifying element type" begin
32+
@test @inferred(confusion(Int, [1,1,1], [1,1,1])) isa Matrix{Int}
33+
@test @inferred(confusion(Float64, [1,1,1], [1,1,1])) isa Matrix{Float64}
2834
end
2935

3036
@testset "comparing 2 k-means clusterings" begin
@@ -38,6 +44,9 @@ using Clustering
3844
r2 = kmeans(x, k; maxiter=5)
3945
C = confusion(r1, r2)
4046
@test C == [n*(n-1)/2 0; 0 0]
47+
48+
C = confusion(Float64, r1, r2)
49+
@test C == [n*(n-1)/2 0; 0 0]
4150
end
4251

4352
end

test/randindex.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,15 @@ a3 = [3, 3, 3, 2, 2, 2, 1, 1, 1, 1]
3636

3737
@test randindex(ones(Int, 3), ones(Int, 3)) == (1, 1, 0, 1)
3838

39-
a, b = rand(1:5, 10_000), rand(1:5, 10_000)
40-
@test randindex(a, b)[1] < 1.0e-2
39+
@testset "large independent clusterings (#225)" begin
40+
rng = MersenneTwister(123)
41+
42+
n = 10_000_000
43+
k = 5 # number of clusters
44+
a = rand(rng, 1:k, n)
45+
b = rand(rng, 1:k, n)
46+
47+
@test collect(randindex(a, b)) [0.0, ((k-1)^2 + 1)/k^2, 2*(k-1)/k^2, ((k-2)/k)^2] atol=1e-5
48+
end
4149

4250
end

0 commit comments

Comments
 (0)