Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 24 additions & 17 deletions src/confusion.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""
confusion(a::Union{ClusteringResult, AbstractVector},
b::Union{ClusteringResult, AbstractVector}) -> Matrix{Int}
confusion([T = Int],
a::Union{ClusteringResult, AbstractVector},
b::Union{ClusteringResult, AbstractVector}) -> Matrix{T}

Return 2×2 confusion matrix `C` that represents partition co-occurrence or
similarity matrix between two clusterings by considering all pairs of samples
and counting pairs that are assigned into the same or into different clusters
under the true and predicted clusterings.
Calculate the *confusion matrix* of the two clusterings.

Returns the 2×2 confusion matrix `C` with elements of type `T` (`Int` by
default) that represents the partition co-occurrence (similarity) matrix between
the two clusterings `a` and `b`. It is built by considering all pairs of samples
and counting the pairs that are assigned into the same or into different clusters.

Considering a pair of samples that is in the same group as a **positive pair**,
and a pair that is in different groups as a **negative pair**, then the count of
Expand All @@ -17,23 +20,27 @@ true negatives is `C₂₂`:
|Positive|C₁₁|C₁₂|
|Negative|C₂₁|C₂₂|
"""
function confusion(::Type{T}, a::AbstractVector{<:Integer},
                   b::AbstractVector{<:Integer}) where T<:Union{Integer, AbstractFloat}
    # Contingency table of the two assignment vectors, converted to the requested
    # element type `T` so that all further arithmetic is carried out in `T`.
    # NOTE(review): `counts` is presumably `StatsBase.counts` -- confirm against the
    # module's imports.
    cc = counts(a, b)
    c = eltype(cc) === T ? cc : convert(Matrix{T}, cc)

    n = sum(c)

    # A sum of squares can never be negative, so a negative value means the
    # integer arithmetic wrapped around. The check is best-effort (a wrapped value
    # may still be positive), but when it fires the error must actually be
    # *thrown*: merely constructing the exception would silently discard it.
    nis = sum(abs2, sum!(zeros(T, (size(c, 1), 1)), c)) # sum of squares of row sums
    (nis < 0) && throw(OverflowError("sum of squares of sums of rows overflowed"))
    njs = sum(abs2, sum!(zeros(T, (1, size(c, 2))), c)) # sum of squares of column sums
    (njs < 0) && throw(OverflowError("sum of squares of sums of columns overflowed"))

    t2 = sum(abs2, c) # sum over rows & columns of nij^2
    (t2 < 0) && throw(OverflowError("sum of squares of matrix elements overflowed"))
    t3 = nis + njs

    # Pair counts laid out as [C₁₁ C₁₂; C₂₁ C₂₂].
    C = [(t2 - n)÷2 (nis - t2)÷2; (njs - t2)÷2 (t2 + n^2 - t3)÷2]
    return C
end

# Two-argument methods that accept clustering results: each one extracts the
# assignment vector(s) via `assignments` and forwards to `confusion` again.
function confusion(a::ClusteringResult, b::ClusteringResult)
    return confusion(assignments(a), assignments(b))
end

function confusion(a::AbstractVector{<:Integer}, b::ClusteringResult)
    return confusion(a, assignments(b))
end

function confusion(a::ClusteringResult, b::AbstractVector{<:Integer})
    return confusion(assignments(a), b)
end
# Unwrap clustering results into their assignment vectors and forward to the
# method that works on plain integer vectors.
#
# `T` is restricted to the element types the vector method supports. With an
# unconstrained first argument, an unsupported `T` combined with two plain
# vectors would dispatch back to this very method (`assignments` returns a
# vector unchanged) and recurse until the stack overflows; with the constraint
# such a call fails immediately with a `MethodError`.
function confusion(::Type{T}, a::ClusteringResultOrAssignments,
                   b::ClusteringResultOrAssignments) where T<:Union{Integer, AbstractFloat}
    return confusion(T, assignments(a), assignments(b))
end

# Convenience form without an explicit element type: the result uses `Int`.
function confusion(a::ClusteringResultOrAssignments,
                   b::ClusteringResultOrAssignments)
    return confusion(Int, a, b)
end
4 changes: 2 additions & 2 deletions src/randindex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Returns a tuple of indices:
> Rand Index.* Psychological Methods, Vol. 9, No. 3: 386-396
"""
function randindex(a, b)
c11, c21, c12, c22 = confusion(a, b) # Table 2 from Steinley 2004
c11, c21, c12, c22 = confusion(Float64, a, b) # Table 2 from Steinley 2004

t = c11 + c12 + c21 + c22 # total number of pairs of entities
A = c11 + c22
Expand All @@ -32,7 +32,7 @@ function randindex(a, b)
# expected index
ERI = (c11+c12)*(c11+c21)+(c21+c22)*(c12+c22)
# adjusted Rand - Hubert & Arabie 1985
ARI = D == 0 ? 1.0 : (t*A-ERI)/(t*t-ERI) # (9) from Steinley 2004
ARI = D == 0 ? 1.0 : (t*A - ERI)/(abs2(t) - ERI) # (9) from Steinley 2004

RI = A/t # Rand 1971 # Probability of agreement
MI = D/t # Mirkin 1970 # p(disagreement)
Expand Down
6 changes: 6 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ Base type for the output of clustering algorithm.
"""
abstract type ClusteringResult end

# Vector of cluster indices for each clustered point.
# Declared `const` so the alias cannot be rebound and stays concretely typed.
const ClusterAssignments = AbstractVector{<:Integer}

# Either a clustering result object or a bare vector of cluster indices;
# lets a single method signature accept both kinds of input.
const ClusteringResultOrAssignments = Union{ClusteringResult, ClusterAssignments}

# generic functions

"""
Expand Down Expand Up @@ -48,6 +53,7 @@ Get the vector of cluster indices for each point.
is assigned.
"""
function assignments(R::ClusteringResult)
    return R.assignments
end

# A plain vector of cluster indices is returned unchanged.
function assignments(A::ClusterAssignments)
    return A
end


##### convert display symbol to disp level
Expand Down
15 changes: 12 additions & 3 deletions test/confusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,27 @@ using Clustering
@test confusion([0,0,1], [0,0,0]) == [1 0; 2 0]
@test confusion([0,1,1], [0,0,0]) == [1 0; 2 0]
@test confusion([1,1,1], [0,0,0]) == [3 0; 0 0]

@test confusion([0,0,0], [0,0,1]) == [1 2; 0 0]
@test confusion([0,0,1], [0,0,1]) == [1 0; 0 2]
@test confusion([0,1,1], [0,0,1]) == [0 1; 1 1]
@test confusion([1,1,1], [0,0,1]) == [1 2; 0 0]

@test confusion([0,0,0], [0,1,1]) == [1 2; 0 0]
@test confusion([0,0,1], [0,1,1]) == [0 1; 1 1]
@test confusion([0,1,1], [0,1,1]) == [1 0; 0 2]
@test confusion([1,1,1], [0,1,1]) == [1 2; 0 0]

@test confusion([0,0,0], [1,1,1]) == [3 0; 0 0]
@test confusion([0,0,1], [1,1,1]) == [1 0; 2 0]
@test confusion([0,1,1], [1,1,1]) == [1 0; 2 0]
@test confusion([1,1,1], [1,1,1]) == [3 0; 0 0]

end

@testset "specifying element type" begin
    # The requested element type must be honoured and the call must be inferrable.
    for T in (Int, Float64)
        @test @inferred(confusion(T, [1,1,1], [1,1,1])) isa Matrix{T}
    end
end

@testset "comparing 2 k-means clusterings" begin
Expand All @@ -38,6 +44,9 @@ using Clustering
r2 = kmeans(x, k; maxiter=5)
C = confusion(r1, r2)
@test C == [n*(n-1)/2 0; 0 0]

C = confusion(Float64, r1, r2)
@test C == [n*(n-1)/2 0; 0 0]
end

end
Expand Down
12 changes: 10 additions & 2 deletions test/randindex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,15 @@ a3 = [3, 3, 3, 2, 2, 2, 1, 1, 1, 1]

@test randindex(ones(Int, 3), ones(Int, 3)) == (1, 1, 0, 1)

a, b = rand(1:5, 10_000), rand(1:5, 10_000)
@test randindex(a, b)[1] < 1.0e-2
@testset "large independent clusterings (#225)" begin
    rng = MersenneTwister(123)

    npoints = 10_000_000
    nclusters = 5
    labels_a = rand(rng, 1:nclusters, npoints)
    labels_b = rand(rng, 1:nclusters, npoints)

    # Reference values used for two independent, uniformly drawn labelings,
    # in the order of the tuple returned by `randindex`.
    expected = [0.0,
                ((nclusters-1)^2 + 1)/nclusters^2,
                2*(nclusters-1)/nclusters^2,
                ((nclusters-2)/nclusters)^2]

    @test collect(randindex(labels_a, labels_b)) ≈ expected atol=1e-5
end

end