diff --git a/src/confusion.jl b/src/confusion.jl
index 0aabb97d..89cc42d7 100644
--- a/src/confusion.jl
+++ b/src/confusion.jl
@@ -1,11 +1,14 @@
 """
-    confusion(a::Union{ClusteringResult, AbstractVector},
-              b::Union{ClusteringResult, AbstractVector}) -> Matrix{Int}
+    confusion([T = Int],
+              a::Union{ClusteringResult, AbstractVector},
+              b::Union{ClusteringResult, AbstractVector}) -> Matrix{T}
 
-Return 2×2 confusion matrix `C` that represents partition co-occurrence or
-similarity matrix between two clusterings by considering all pairs of samples
-and counting pairs that are assigned into the same or into different clusters
-under the true and predicted clusterings.
+Calculate the *confusion matrix* of the two clusterings.
+
+Returns the 2×2 confusion matrix `C` of type `T` (`Int` by default) that
+represents partition co-occurrence or similarity matrix between two clusterings
+`a` and `b` by considering all pairs of samples and counting pairs that are
+assigned into the same or into different clusters.
 
 Considering a pair of samples that is in the same group as a **positive pair**,
 and a pair is in the different group as a **negative pair**, then the count of
@@ -17,23 +20,29 @@ true negatives is `C₂₂`:
 |Positive|C₁₁|C₁₂|
 |Negative|C₂₁|C₂₂|
+
+Throws `OverflowError` if an intermediate sum of squares is found to be negative.
 """
-function confusion(a::AbstractVector{<:Integer}, b::AbstractVector{<:Integer})
-    c = counts(a, b)
+function confusion(::Type{T}, a::AbstractVector{<:Integer}, b::AbstractVector{<:Integer}) where T<:Union{Integer, AbstractFloat}
+    cc = counts(a, b)
+    c = eltype(cc) === T ? cc : convert(Matrix{T}, cc)
 
     n = sum(c)
-    nis = sum(abs2, sum(c, dims=2))        # sum of squares of sums of rows
-    njs = sum(abs2, sum(c, dims=1))        # sum of squares of sums of columns
+    nis = sum(abs2, sum!(zeros(T, (size(c, 1), 1)), c)) # sum of squares of sums of rows
+    (nis < 0) && throw(OverflowError("sum of squares of sums of rows overflowed"))
+    njs = sum(abs2, sum!(zeros(T, (1, size(c, 2))), c)) # sum of squares of sums of columns
+    (njs < 0) && throw(OverflowError("sum of squares of sums of columns overflowed"))
 
-    t2 = sum(abs2, c)                      # sum over rows & columns of nij^2
+    t2 = sum(abs2, c) # sum over rows & columns of nij^2
+    (t2 < 0) && throw(OverflowError("sum of squares of matrix elements overflowed"))
     t3 = nis + njs
 
     C = [(t2 - n)÷2 (nis - t2)÷2; (njs - t2)÷2 (t2 + n^2 - t3)÷2]
     return C
 end
 
-confusion(a::ClusteringResult, b::ClusteringResult) =
-    confusion(assignments(a), assignments(b))
-confusion(a::AbstractVector{<:Integer}, b::ClusteringResult) =
-    confusion(a, assignments(b))
-confusion(a::ClusteringResult, b::AbstractVector{<:Integer}) =
-    confusion(assignments(a), b)
+confusion(::Type{T}, a::ClusteringResultOrAssignments,
+          b::ClusteringResultOrAssignments) where {T<:Union{Integer, AbstractFloat}} =
+    confusion(T, assignments(a), assignments(b))
+confusion(a::ClusteringResultOrAssignments,
+          b::ClusteringResultOrAssignments) =
+    confusion(Int, a, b)
diff --git a/src/randindex.jl b/src/randindex.jl
index 3cc146f2..08d41874 100644
--- a/src/randindex.jl
+++ b/src/randindex.jl
@@ -23,7 +23,7 @@ Returns a tuple of indices:
 > Rand Index.* Psychological Methods, Vol. 9, No. 3: 386-396
 """
 function randindex(a, b)
-    c11, c21, c12, c22 = confusion(a, b) # Table 2 from Steinley 2004
+    c11, c21, c12, c22 = confusion(Float64, a, b) # Table 2 from Steinley 2004
 
     t = c11 + c12 + c21 + c22   # total number of pairs of entities
     A = c11 + c22
@@ -32,7 +32,7 @@ function randindex(a, b)
     # expected index
     ERI = (c11+c12)*(c11+c21)+(c21+c22)*(c12+c22)
     # adjusted Rand - Hubert & Arabie 1985
-    ARI = D == 0 ? 1.0 : (t*A-ERI)/(t*t-ERI) # (9) from Steinley 2004
+    ARI = D == 0 ? 1.0 : (t*A - ERI)/(abs2(t) - ERI) # (9) from Steinley 2004
 
     RI = A/t            # Rand 1971 # Probability of agreement
     MI = D/t            # Mirkin 1970 # p(disagreement)
diff --git a/src/utils.jl b/src/utils.jl
index b832e86b..a01db9dc 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -9,6 +9,11 @@ Base type for the output of clustering algorithm.
 """
 abstract type ClusteringResult end
 
+# vector of cluster indices for each clustered point
+const ClusterAssignments = AbstractVector{<:Integer}
+
+const ClusteringResultOrAssignments = Union{ClusteringResult, ClusterAssignments}
+
 # generic functions
 
 """
@@ -48,6 +53,7 @@ Get the vector of cluster indices for each point.
 is assigned.
 """
 assignments(R::ClusteringResult) = R.assignments
+assignments(A::ClusterAssignments) = A
 
 ##### convert display symbol to disp level
 
diff --git a/test/confusion.jl b/test/confusion.jl
index 24c931ac..c456472f 100644
--- a/test/confusion.jl
+++ b/test/confusion.jl
@@ -10,21 +10,27 @@ using Clustering
         @test confusion([0,0,1], [0,0,0]) == [1 0; 2 0]
         @test confusion([0,1,1], [0,0,0]) == [1 0; 2 0]
         @test confusion([1,1,1], [0,0,0]) == [3 0; 0 0]
-        
+
         @test confusion([0,0,0], [0,0,1]) == [1 2; 0 0]
         @test confusion([0,0,1], [0,0,1]) == [1 0; 0 2]
         @test confusion([0,1,1], [0,0,1]) == [0 1; 1 1]
         @test confusion([1,1,1], [0,0,1]) == [1 2; 0 0]
-        
+
         @test confusion([0,0,0], [0,1,1]) == [1 2; 0 0]
         @test confusion([0,0,1], [0,1,1]) == [0 1; 1 1]
         @test confusion([0,1,1], [0,1,1]) == [1 0; 0 2]
         @test confusion([1,1,1], [0,1,1]) == [1 2; 0 0]
-        
+
         @test confusion([0,0,0], [1,1,1]) == [3 0; 0 0]
         @test confusion([0,0,1], [1,1,1]) == [1 0; 2 0]
         @test confusion([0,1,1], [1,1,1]) == [1 0; 2 0]
         @test confusion([1,1,1], [1,1,1]) == [3 0; 0 0]
+
+    end
+
+    @testset "specifying element type" begin
+        @test @inferred(confusion(Int, [1,1,1], [1,1,1])) isa Matrix{Int}
+        @test @inferred(confusion(Float64, [1,1,1], [1,1,1])) isa Matrix{Float64}
     end
 
     @testset "comparing 2 k-means clusterings" begin
@@ -38,6 +44,9 @@ using Clustering
         r2 = kmeans(x, k; maxiter=5)
         C = confusion(r1, r2)
         @test C == [n*(n-1)/2 0; 0 0]
+
+        C = confusion(Float64, r1, r2)
+        @test C == [n*(n-1)/2 0; 0 0]
     end
 
 end
diff --git a/test/randindex.jl b/test/randindex.jl
index c6ad3de4..f6ee629d 100644
--- a/test/randindex.jl
+++ b/test/randindex.jl
@@ -36,7 +36,15 @@ a3 = [3, 3, 3, 2, 2, 2, 1, 1, 1, 1]
 
 @test randindex(ones(Int, 3), ones(Int, 3)) == (1, 1, 0, 1)
 
-a, b = rand(1:5, 10_000), rand(1:5, 10_000)
-@test randindex(a, b)[1] < 1.0e-2
+@testset "large independent clusterings (#225)" begin
+    rng = MersenneTwister(123)
+
+    n = 10_000_000
+    k = 5 # number of clusters
+    a = rand(rng, 1:k, n)
+    b = rand(rng, 1:k, n)
+
+    @test collect(randindex(a, b)) ≈ [0.0, ((k-1)^2 + 1)/k^2, 2*(k-1)/k^2, ((k-2)/k)^2] atol=1e-5
+end
 
 end