Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Clustering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ module Clustering
kmpp, kmpp_by_costs,

# kmeans
kmeans, kmeans!, KmeansResult,
kmeans, kmeans!, KmeansResult, get_cluster_assignments,

# kmedoids
kmedoids, kmedoids!, KmedoidsResult,
Expand Down
37 changes: 37 additions & 0 deletions src/kmeans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -390,4 +390,41 @@ function repick_unused_centers(X::AbstractMatrix{<:Real}, # in: the data matrix
colwise!(ds, distance, v, X)
tcosts = min(tcosts, ds)
end

"""
    get_cluster_assignments(X, R::KmeansResult, distance=SqEuclidean()) -> Vector{Int}

Assign each of the ``n`` points (the columns of `X`) to one of the `k` clusters,
using the learned prototypes (cluster centers) from the input `KmeansResult`.

Note: This method is useful for clustering new data with an already fitted model.

# Arguments
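- `X`: Input data to be clustered.
- `R`: Fitted k-means result.
- `distance`: Semimetric used to compare points with cluster centers
  (defaults to `SqEuclidean()`).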
"""
function get_cluster_assignments(
    X::AbstractMatrix{<:Real},
    R::KmeansResult,
    distance::SemiMetric=SqEuclidean())
    # Assign every column of `X` to the index of the closest column of
    # `R.centers` under `distance`. Returns a `Vector{Int}` with one entry
    # per column of `X`.

    centers = R.centers

    # Fail early with a clear message instead of erroring inside the
    # threaded loop when the feature dimensions disagree.
    size(X, 1) == size(centers, 1) || throw(DimensionMismatch(
        "X has $(size(X, 1)) rows, but the cluster centers have $(size(centers, 1))"))

    # Use a floating-point sentinel so that integer-valued `X` does not mix
    # an `Int` sentinel with floating-point distances.
    maxdist = typemax(float(promote_type(eltype(X), eltype(centers))))

    assignments = zeros(Int, size(X, 2))

    # Each iteration reads shared data and writes only its own slot of
    # `assignments`, so the iterations are independent.
    Threads.@threads for n in axes(X, 2)
        x = view(X, :, n)
        best_dist = maxdist
        # Stays 0 only if no distance compares below the sentinel (e.g. NaN).
        best_k = 0

        for k in axes(centers, 2)
            d = distance(x, view(centers, :, k))
            # Strict `<` keeps the first (lowest-index) center on ties.
            if d < best_dist
                best_dist = d
                best_k = k
            end
        end
        assignments[n] = best_k
    end

    return assignments
end

end
7 changes: 7 additions & 0 deletions test/kmeans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,4 +204,11 @@ end
end
end

@testset "get cluster assigments" begin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also add a testset to test/utils.jl (a new file that should be included from runtests.jl before all others) that checks that assign_clusters(.., R) throws a "not implemented" exception for any ClusteringResult object other than KmeansResult, e.g. for KmedoidsResult.

Copy link
Author

@davidbp davidbp Apr 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a test covering the case where assign_clusters has no implementation for a non-k-means ClusteringResult.

# Re-assigning the training data with the fitted model must reproduce the
# assignments stored in the k-means result.
X = rand(5, 100)
R = kmeans(X, 10; maxiter=200)
@test get_cluster_assignments(X, R) == R.assignments
end

end