Skip to content

Commit e093e02

Browse files
authored
Merge pull request #250 from JuliaStats/ast/fix_kmedoids
kmedoids(): fix duplicate medoids case
2 parents d4cb6af + 0e80afe commit e093e02

File tree

2 files changed

+77
-63
lines changed

2 files changed

+77
-63
lines changed

src/kmedoids.jl

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -182,59 +182,51 @@ end
182182

183183

184184
# update assignments and related quantities
185-
function _kmed_update_assignments!(dist::AbstractMatrix{T}, # in: (n, n)
185+
# returns the total cost and the number of assignment changes
186+
function _kmed_update_assignments!(dist::AbstractMatrix{<:Real}, # in: (n, n)
186187
medoids::AbstractVector{Int}, # in: (k,)
187188
assignments::Vector{Int}, # out: (n,)
188189
groups::Vector{Vector{Int}}, # out: (k,)
189-
costs::Vector{T}, # out: (n,)
190-
isinit::Bool) where T # in
190+
costs::AbstractVector{<:Real},# out: (n,)
191+
initial::Bool) # in
191192
n = size(dist, 1)
192193
k = length(medoids)
193-
ch = 0
194194

195-
if !isinit
196-
for i = 1:k
197-
empty!(groups[i])
198-
end
199-
end
195+
# reset cluster groups (note: assignments are not touched yet)
196+
initial || foreach(empty!, groups)
200197

201198
tcost = 0.0
199+
ch = 0
202200
for j = 1:n
203-
p = 1
201+
p = 1 # initialize the closest medoid for j
204202
mv = dist[medoids[1], j]
205203

206-
for i = 2:k
207-
v = dist[medoids[i], j]
208-
if v < mv
204+
# find the closest medoid for j
205+
@inbounds for i = 2:k
206+
m = medoids[i]
207+
v = dist[m, j]
208+
# assign if current medoid is closer or if it is j itself
209+
if (v < mv) || (m == j)
210+
(v <= mv) || throw(ArgumentError("sample #$j reassigned from medoid[$p]=#$(medoids[p]) (distance=$mv) to medoid[$i]=#$m (distance=$v); check the distance matrix correctness"))
209211
p = i
210212
mv = v
211213
end
212214
end
213215

214-
if isinit
215-
assignments[j] = p
216-
else
217-
a = assignments[j]
218-
if p != a
219-
ch += 1
220-
end
221-
assignments[j] = p
222-
end
223-
216+
ch += !initial && (p != assignments[j])
217+
assignments[j] = p
224218
costs[j] = mv
225219
tcost += mv
226220
push!(groups[p], j)
227221
end
228222

229-
return (tcost, ch)::Tuple{Float64, Int}
223+
return (tcost, ch)
230224
end
231225

232226

233227
# find medoid for a given group
234-
#
235-
# TODO: faster way without creating temporary arrays
236-
function _find_medoid(dist::AbstractMatrix, grp::Vector{Int})
228+
function _find_medoid(dist::AbstractMatrix, grp::AbstractVector{Int})
237229
@assert !isempty(grp)
238-
p = argmin(sum(dist[grp, grp], dims=2))
239-
return grp[p]::Int
230+
p = argmin(sum(view(dist, grp, grp), dims=2))
231+
return grp[p]
240232
end

test/kmedoids.jl

Lines changed: 56 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,15 @@ include("test_helpers.jl")
99
Random.seed!(34568)
1010
@test_throws ArgumentError kmedoids(randn(2, 3), 1)
1111
@test_throws ArgumentError kmedoids(randn(2, 3), 4)
12-
dist = inv.(max.(pairwise(Euclidean(), randn(2, 3), dims=2), 0.1))
13-
@test kmedoids(dist, 2) isa KmedoidsResult
12+
dist = max.(pairwise(Euclidean(), randn(2, 3), dims=2), 0.1)
13+
@test @inferred(kmedoids(dist, 2)) isa KmedoidsResult
14+
# incorrect distance matrix
15+
invdist = inv.(max.(pairwise(Euclidean(), randn(2, 3), dims=2), 0.1))
16+
@test_throws ArgumentError kmedoids(invdist, 2)
17+
1418
@test_throws ArgumentError kmedoids(dist, 2, display=:mylog)
1519
for disp in keys(Clustering.DisplayLevels)
16-
@test kmedoids(dist, 2, display=disp) isa KmedoidsResult
20+
@test @inferred(kmedoids(dist, 2, display=disp)) isa KmedoidsResult
1721
end
1822
end
1923

@@ -28,7 +32,7 @@ dist = pairwise(SqEuclidean(), X, dims=2)
2832
@assert size(dist) == (n, n)
2933

3034
Random.seed!(34568) # reset seed again to known state
31-
R = kmedoids(dist, k)
35+
R = @inferred(kmedoids(dist, k))
3236
@test isa(R, KmedoidsResult)
3337
@test nclusters(R) == k
3438
@test length(R.medoids) == length(unique(R.medoids))
@@ -48,36 +52,54 @@ R = kmedoids(dist, k)
4852
end
4953
end
5054

51-
# k=1 and k=n cases
52-
x = pairwise(SqEuclidean(), [1 2 3; .1 .2 .3; 4 5.6 7], dims=2)
53-
kmed1 = kmedoids(x, 1)
54-
@test nclusters(kmed1) == 1
55-
@test assignments(kmed1) == [1, 1, 1]
56-
@test kmed1.medoids == [2]
57-
kmed3 = kmedoids(x, 3)
58-
@test nclusters(kmed3) == 3
59-
@test sort(assignments(kmed3)) == [1, 2, 3]
60-
@test sort(kmed3.medoids) == [1, 2, 3]
61-
62-
63-
# this data set has three obvious groups:
64-
# group 1: [1, 3, 4], values: [1, 2, 3]
65-
# group 2: [2, 5, 7], values: [6, 7, 8]
66-
# group 3: [6, 8, 9], values: [21, 20, 22]
67-
#
68-
69-
X = reshape(map(Float64, [1, 6, 2, 3, 7, 21, 8, 20, 22]), 1, 9)
70-
dist = pairwise(SqEuclidean(), X, dims=2)
55+
@testset "Duplicated points (#231)" begin
56+
pts = [0.0 0.0]
57+
dists = pairwise(SqEuclidean(), pts, dims=2)
58+
dupmed = kmedoids(dists, 2)
59+
@test nclusters(dupmed) == 2
60+
@test sort(dupmed.medoids) == [1, 2]
61+
@test sort(dupmed.assignments) == [1, 2]
62+
end
7163

72-
R = kmedoids!(dist, [1, 2, 6])
73-
@test isa(R, KmedoidsResult)
74-
@test nclusters(R) == 3
75-
@test R.medoids == [3, 5, 6]
76-
@test R.assignments == [1, 2, 1, 1, 2, 3, 2, 3, 3]
77-
@test counts(R) == [3, 3, 3]
78-
@test wcounts(R) == counts(R)
79-
@test R.costs ≈ [1, 1, 0, 1, 0, 0, 1, 1, 1]
80-
@test R.totalcost ≈ 6.0
81-
@test R.converged
64+
@testset "Toy example #1" begin
65+
pts = [1 2 3; .1 .2 .3; 4 5.6 7]
66+
# k=1 and k=n cases
67+
dists = pairwise(SqEuclidean(), pts, dims=2)
68+
69+
@testset "k=1" begin
70+
kmed1 = @inferred(kmedoids(dists, 1))
71+
@test nclusters(kmed1) == 1
72+
@test assignments(kmed1) == [1, 1, 1]
73+
@test kmed1.medoids == [2]
74+
end
75+
76+
@testset "k=3" begin
77+
kmed3 = @inferred(kmedoids(dists, 3))
78+
@test nclusters(kmed3) == 3
79+
@test sort(assignments(kmed3)) == [1, 2, 3]
80+
@test sort(kmed3.medoids) == [1, 2, 3]
81+
end
82+
end
83+
84+
@testset "Toy example #2" begin
85+
pts = reshape(map(Float64, [1, 6, 2, 3, 7, 21, 8, 20, 22]), 1, 9)
86+
# this data set has three obvious groups:
87+
# group 1: [1, 3, 4], values: [1, 2, 3]
88+
# group 2: [2, 5, 7], values: [6, 7, 8]
89+
# group 3: [6, 8, 9], values: [21, 20, 22]
90+
91+
dists = pairwise(SqEuclidean(), pts, dims=2)
92+
93+
R = @inferred(kmedoids!(dists, [1, 2, 6]))
94+
@test isa(R, KmedoidsResult)
95+
@test nclusters(R) == 3
96+
@test R.medoids == [3, 5, 6]
97+
@test R.assignments == [1, 2, 1, 1, 2, 3, 2, 3, 3]
98+
@test counts(R) == [3, 3, 3]
99+
@test wcounts(R) == counts(R)
100+
@test R.costs ≈ [1, 1, 0, 1, 0, 0, 1, 1, 1]
101+
@test R.totalcost ≈ 6.0
102+
@test R.converged
103+
end
82104

83105
end

0 commit comments

Comments
 (0)