|
| 1 | +# NOTE: there's a `kmeans!` function that updates centers, maybe a candidate |
| 2 | +# for the `update` machinery. Same for `kmedoids!` |
| 3 | +# NOTE: if the prediction is done on the original array, just the assignment |
| 4 | +# should be returned, unclear what's the best way of doing this. |
| 5 | + |
| 6 | +module MLJClusteringInterface |
| 7 | + |
| 8 | +# =================================================================== |
| 9 | +# IMPORTS |
| 10 | +import Clustering |
| 11 | +import MLJModelInterface |
| 12 | +import MLJModelInterface: Continuous, Count, Finite, MulticlassTable, OrderedFactor, |
| 13 | + @mlj_model, metadata_model, metadata_pkg |
| 14 | + |
| 15 | +using Distances |
| 16 | + |
| 17 | +# =================================================================== |
| 18 | +## EXPORTS |
| 19 | +export KMeans, KMedoids |
| 20 | + |
| 21 | +# =================================================================== |
| 22 | +## CONSTANTS |
| 23 | +# Define constants for easy referencing of packages |
| 24 | +const MMI = MLJModelInterface |
| 25 | +const Cl = Clustering |
| 26 | + |
| 27 | +# Definitions of model descriptions for use in model doc-strings. |
| 28 | +const KMeansDescription =""" |
| 29 | +K-Means algorithm: find K centroids corresponding to K clusters in the data. |
| 30 | +""" |
| 31 | + |
| 32 | +const KMedoidsDescription =""" |
| 33 | +K-Medoids algorithm: find K centroids corresponding to K clusters in the data. |
| 34 | +Unlike K-Means, the centroids are found among data points themselves. |
| 35 | +""" |
| 36 | + |
| 37 | +const KMFields =""" |
| 38 | + ## Keywords |
| 39 | +
|
| 40 | + * `k=3` : number of centroids |
| 41 | + * `metric` : distance metric to use |
| 42 | +""" |
| 43 | + |
| 44 | +#### |
| 45 | +#### KMeans |
| 46 | +#### |
| 47 | +""" |
| 48 | +KMeans(; kwargs...) |
| 49 | +
|
| 50 | +$KMeansDescription |
| 51 | +
|
| 52 | +$KMFields |
| 53 | +
|
| 54 | +See also the |
| 55 | +[package documentation](http://juliastats.github.io/Clustering.jl/latest/kmeans.html). |
| 56 | +""" |
| 57 | + |
| 58 | +@mlj_model mutable struct KMeans <: MMI.Unsupervised |
| 59 | + k::Int = 3::(_ ≥ 2) |
| 60 | + metric::SemiMetric = SqEuclidean() |
| 61 | +end |
| 62 | + |
| 63 | +#### |
| 64 | +#### KMeans |
| 65 | +#### |
| 66 | + |
| 67 | +function MMI.fit(model::KMeans, verbosity::Int, X) |
| 68 | + # NOTE: using transpose here to get a LinearAlgebra.Transpose object |
| 69 | + # which Kmeans can handle. |
| 70 | + Xarray = transpose(MMI.matrix(X)) |
| 71 | + result = Cl.kmeans(Xarray, model.k; distance=model.metric) |
| 72 | + cluster_labels = MMI.categorical(1:model.k) |
| 73 | + fitresult = (result.centers, cluster_labels) # centers (p x k) |
| 74 | + cache = nothing |
| 75 | + report = ( |
| 76 | + assignments=result.assignments, # size n |
| 77 | + cluster_labels=cluster_labels |
| 78 | + ) |
| 79 | + return fitresult, cache, report |
| 80 | +end |
| 81 | + |
| 82 | +MMI.fitted_params(::KMeans, fitresult) = (centers=fitresult[1],) |
| 83 | + |
| 84 | +function MMI.transform(model::KMeans, fitresult, X) |
| 85 | + # pairwise distance from samples to centers |
| 86 | + X̃ = pairwise( |
| 87 | + model.metric, |
| 88 | + transpose(MMI.matrix(X)), |
| 89 | + fitresult[1], |
| 90 | + dims=2 |
| 91 | + ) |
| 92 | + return MMI.table(X̃, prototype=X) |
| 93 | +end |
| 94 | + |
| 95 | +""" |
| 96 | +KMedoids(; kwargs...) |
| 97 | +
|
| 98 | +$KMedoidsDescription |
| 99 | +
|
| 100 | +$KMFields |
| 101 | +
|
| 102 | +See also the |
| 103 | +[package documentation](http://juliastats.github.io/Clustering.jl/latest/kmedoids.html). |
| 104 | +""" |
| 105 | +@mlj_model mutable struct KMedoids <: MMI.Unsupervised |
| 106 | + k::Int = 3::(_ ≥ 2) |
| 107 | + metric::SemiMetric = SqEuclidean() |
| 108 | +end |
| 109 | + |
| 110 | +function MMI.fit(model::KMedoids, verbosity::Int, X) |
| 111 | + # NOTE: using transpose=true will materialize the transpose (~ permutedims), KMedoids |
| 112 | + # does not yet accept LinearAlgebra.Transpose |
| 113 | + Xarray = MMI.matrix(X, transpose=true) |
| 114 | + # cost matrix: all the pairwise distances |
| 115 | + cost_array = pairwise(model.metric, Xarray, dims=2) # n x n |
| 116 | + result = Cl.kmedoids(cost_array, model.k) |
| 117 | + cluster_labels = MMI.categorical(1:model.k) |
| 118 | + fitresult = (view(Xarray, :, result.medoids), cluster_labels) # medoids |
| 119 | + cache = nothing |
| 120 | + report = ( |
| 121 | + assignments=result.assignments, # size n |
| 122 | + cluster_labels=cluster_labels |
| 123 | + ) |
| 124 | + return fitresult, cache, report |
| 125 | +end |
| 126 | + |
| 127 | +MMI.fitted_params(::KMedoids, fitresult) = (medoids=fitresult[1],) |
| 128 | + |
| 129 | +function MMI.transform(model::KMedoids, fitresult, X) |
| 130 | + # pairwise distance from samples to medoids |
| 131 | + X̃ = pairwise( |
| 132 | + model.metric, |
| 133 | + MMI.matrix(X, transpose=true), |
| 134 | + fitresult[1], dims=2 |
| 135 | + ) |
| 136 | + return MMI.table(X̃, prototype=X) |
| 137 | +end |
| 138 | + |
| 139 | +#### |
| 140 | +#### Predict methods |
| 141 | +#### |
| 142 | + |
| 143 | +function MMI.predict(model::Union{KMeans,KMedoids}, fitresult, Xnew) |
| 144 | + locations, cluster_labels = fitresult |
| 145 | + Xarray = MMI.matrix(Xnew) |
| 146 | + (n, p), k = size(Xarray), model.k |
| 147 | + pred = zeros(Int, n) |
| 148 | + |
| 149 | + @inbounds for i in 1:n |
| 150 | + minv = Inf |
| 151 | + @inbounds @simd for j in 1:k |
| 152 | + curv = evaluate( |
| 153 | + model.metric, view(Xarray, i, :), view(locations, :, j) |
| 154 | + ) |
| 155 | + P = curv < minv |
| 156 | + pred[i] = j * P + pred[i] * !P # if P is true --> j |
| 157 | + minv = curv * P + minv * !P # if P is true --> curvalue |
| 158 | + end |
| 159 | + end |
| 160 | + return cluster_labels[pred] |
| 161 | +end |
| 162 | + |
| 163 | +#### |
| 164 | +#### METADATA |
| 165 | +#### |
| 166 | + |
| 167 | +metadata_pkg.( |
| 168 | + (KMeans, KMedoids), |
| 169 | + name="Clustering", |
| 170 | + uuid="aaaa29a8-35af-508c-8bc3-b662a17a0fe5", |
| 171 | + url="https://github.com/JuliaStats/Clustering.jl", |
| 172 | + julia=true, |
| 173 | + license="MIT", |
| 174 | + is_wrapper=false |
| 175 | +) |
| 176 | + |
| 177 | +metadata_model( |
| 178 | + KMeans, |
| 179 | + input = MMI.Table(Continuous), |
| 180 | + output = MMI.Table(Continuous), |
| 181 | + weights = false, |
| 182 | + descr = KMeansDescription |
| 183 | +) |
| 184 | + |
| 185 | +metadata_model( |
| 186 | + KMedoids, |
| 187 | + input = MMI.Table(Continuous), |
| 188 | + output = MMI.Table(Continuous), |
| 189 | + weights = false, |
| 190 | + descr = KMedoidsDescription |
| 191 | +) |
| 192 | + |
| 193 | +end # module |
| 194 | + |
0 commit comments