Commit a910032

authored
apply code review
1 parent 34dc9ca commit a910032

File tree

1 file changed: +34 -31 lines changed

src/MLJClusteringInterface.jl

Lines changed: 34 additions & 31 deletions
@@ -25,6 +25,7 @@ const MMI = MLJModelInterface
 const Cl = Clustering
 
 
+
 const PKG = "MLJClusteringInterface"
 
 ####
@@ -155,18 +156,12 @@ metadata_model(
     weights = false,
     path = "$(PKG).KMedoids"
 )
-
 """
 $(MMI.doc_header(KMeans))
 
 
-`KMeans` is a classical method for clustering or vector quantization. It produces a fixed
-number of clusters, each associated with a center (also known as a prototype), and each
-data point is assigned to the cluster with the nearest center. It works best with
-euclidean distance measures; for non-euclidean measures use [`KMedoids`](@ref).
-
-From a mathematical standpoint, K-means is a coordinate descent algorithm that solves the
-following optimization problem: minimize ∑_{i=1}^n ‖x_i − μ_{z_i}‖² w.r.t. (μ, z).
+`KMeans`: The K-Means algorithm finds K centroids corresponding to K clusters in the
+data. The clusters are assumed to be elliptical, and a euclidean distance metric should be used.
 
 # Training data
 
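The optimization problem in the removed docstring lines, minimize ∑_{i=1}^n ‖x_i − μ_{z_i}‖² with respect to the centers μ and the assignments z, is solved by alternating two coordinate-descent steps. A minimal sketch in plain Python (illustrative only; the package itself delegates to Clustering.jl, and the `kmeans` helper below is hypothetical, not the MLJ model):

```python
import random

def kmeans(points, k, iters=20, seed=0):
    """Minimal Lloyd's algorithm: alternate assignment and mean-update steps."""
    rng = random.Random(seed)
    centers = rng.sample(points, k)
    for _ in range(iters):
        # Assignment step: attach each point to its nearest center.
        clusters = [[] for _ in range(k)]
        for p in points:
            j = min(range(k),
                    key=lambda c: sum((a - b) ** 2 for a, b in zip(p, centers[c])))
            clusters[j].append(p)
        # Update step: move each center to the mean of its cluster.
        for j, members in enumerate(clusters):
            if members:
                centers[j] = tuple(sum(coords) / len(members)
                                   for coords in zip(*members))
    return centers

pts = [(0.0, 0.0), (0.1, 0.2), (5.0, 5.0), (5.1, 4.9)]
centers = sorted(kmeans(pts, 2))  # converges to the two well-separated cluster means
```

Each sweep only ever lowers the objective, which is why the algorithm converges (possibly to a local minimum, hence the random restart `seed`).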
@@ -177,7 +172,7 @@ In MLJ or MLJBase, bind an instance `model` to data with
 Where
 
 - `X`: is any table of input features (eg, a `DataFrame`) whose columns
-  are of scitype `Continuous`; check the column scitypes with `schema(X)`
+  are of scitype `Continuous`; check the scitype with `schema(X)`
 
 - `y`: is the target, which can be any `AbstractVector` whose element
   scitype is `Count`; check the scitype with `schema(y)`
@@ -187,15 +182,15 @@ Train the machine using `fit!(mach, rows=...)`.
 # Hyper-parameters
 
 - `k=3`: The number of centroids to use in clustering.
-- `metric::Distances.SqEuclidean`: The metric used to calculate the clustering distance
-  matrix. Must be a subtype of `Distances.SemiMetric` from Distances.jl.
+- `metric::SemiMetric=SqEuclidean`: The metric used to calculate the clustering distance
+  matrix
 
 # Operations
 
-- `predict(mach, Xnew)`: return learned cluster labels for a new
-  table of inputs `Xnew` having the same scitype as `X` above.
+- `predict(mach, Xnew)`: return predictions of the target given new
+  features `Xnew` having the same scitype as `X` above.
 - `transform(mach, Xnew)`: instead return the mean pairwise distances from
-  new samples to the cluster centers.
+  new samples to the cluster centers
 
 # Fitted parameters
 
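As the `transform` bullet above notes, transforming returns distances from new samples to the cluster centers, so passing the centers themselves back in yields a zero for each center paired with itself, which is exactly what the example's `@test center_dists[i][i] == 0.0` lines check. A plain-Python sketch of that semantics (the `transform` function below is a hypothetical stand-in using squared euclidean distance, not the MLJ operation):

```python
def transform(samples, centers):
    """Distance from each sample to each cluster center (squared euclidean)."""
    return [[sum((s - c) ** 2 for s, c in zip(x, mu)) for mu in centers]
            for x in samples]

centers = [(0.0, 0.0), (3.0, 4.0)]
dists = transform(centers, centers)  # feed the centers back in as samples
# Each center is at distance zero from itself: dists[i][i] == 0.0.
```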
@@ -214,17 +209,21 @@ The fields of `report(mach)` are:
 
 ```
 using MLJ
+using Distances
 using Test
 KMeans = @load KMeans pkg=Clustering
 
 X, y = @load_iris
 model = KMeans(k=3)
 mach = machine(model, X) |> fit!
 
-preds = predict(mach, X)
-@test preds == report(mach).assignments
+yhat = predict(mach, X)
+@test yhat == report(mach).assignments
 
-center_dists = transform(mach, MLJ.table(fitted_params(mach).centers'))
+compare = zip(yhat, y) |> collect;
+compare[1:8] # clusters align with classes
+
+center_dists = transform(mach, fitted_params(mach).centers')
 
 @test center_dists[1][1] == 0.0
 @test center_dists[2][2] == 0.0
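The `compare = zip(yhat, y) |> collect` lines in the example only eyeball how clusters line up with the iris classes. The agreement can also be summarized as a single number; below is a plain-Python sketch of cluster purity (the `purity` helper is hypothetical, not part of MLJ):

```python
from collections import Counter

def purity(assignments, labels):
    """Fraction of points whose label is the majority label of their cluster."""
    clusters = {}
    for cluster, label in zip(assignments, labels):
        clusters.setdefault(cluster, []).append(label)
    majority = sum(Counter(ys).most_common(1)[0][1] for ys in clusters.values())
    return majority / len(labels)

score = purity([1, 1, 2, 2], ["setosa", "setosa", "virginica", "versicolor"])
# 3 of the 4 points carry their cluster's majority label, so score == 0.75.
```

Purity is label-permutation invariant, which matters because cluster indices carry no meaning on their own.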
@@ -238,11 +237,9 @@ KMeans
 """
 $(MMI.doc_header(KMedoids))
 
-`KMedoids`: K-medoids is a clustering algorithm that works by finding k data points (called
-medoids) such that the total distance between each data point and the closest medoid is
-minimal. The function implements a K-means-style algorithm instead of PAM (Partitioning
-Around Medoids). The K-means-style algorithm converges in fewer iterations, but was shown
-to produce worse (10-20% higher total cost) results; see [Schubert & Rousseeuw (2019)](https://juliastats.org/Clustering.jl/latest/kmedoids.html#kmedoid_refs-1).
+`KMedoids`: The K-Medoids algorithm finds K centroids corresponding to K clusters in the
+data. Unlike K-Means, the centroids are found among the data points themselves. Clusters
+are not assumed to be elliptical, and a non-euclidean distance metric can be used.
 
 # Training data
 
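The contrast stated above, that a K-Medoids centroid must be an actual data point while a K-Means center is a mean, fits in a few lines. A plain-Python sketch (the `medoid` helper is hypothetical; 1-D data and absolute distance for simplicity):

```python
def medoid(cluster):
    """The medoid: the cluster member minimizing total distance to all members."""
    return min(cluster, key=lambda m: sum(abs(m - x) for x in cluster))

data = [1.0, 2.0, 10.0]
m = medoid(data)              # 2.0, one of the data points
mean = sum(data) / len(data)  # ~4.33, not a data point; pulled toward the outlier
```

Because the medoid update only evaluates pairwise distances, any dissimilarity works, which is why K-Medoids pairs naturally with non-euclidean metrics.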
@@ -253,7 +250,7 @@ In MLJ or MLJBase, bind an instance `model` to data with
 Where
 
 - `X`: is any table of input features (eg, a `DataFrame`) whose columns
-  are of scitype `Continuous`; check the column scitypes with `schema(X)`
+  are of scitype `Continuous`; check the scitype with `schema(X)`
 
 - `y`: is the target, which can be any `AbstractVector` whose element
   scitype is `Count`; check the scitype with `schema(y)`
@@ -263,15 +260,15 @@ Train the machine using `fit!(mach, rows=...)`.
 # Hyper-parameters
 
 - `k=3`: The number of centroids to use in clustering.
-- `metric::Distances.SqEuclidean`: The metric used to calculate the clustering distance
-  matrix. Must be a subtype of `Distances.SemiMetric` from Distances.jl.
+- `metric::SemiMetric=SqEuclidean`: The metric used to calculate the clustering distance
+  matrix
 
 # Operations
 
-- `predict(mach, Xnew)`: return learned cluster labels for a new
-  table of inputs `Xnew` having the same scitype as `X` above.
+- `predict(mach, Xnew)`: return predictions of the target given new
+  features `Xnew` having the same scitype as `X` above.
 - `transform(mach, Xnew)`: instead return the mean pairwise distances from
-  new samples to the cluster centers.
+  new samples to the cluster centers
 
 # Fitted parameters
 
@@ -291,20 +288,25 @@ The fields of `report(mach)` are:
 ```
 using MLJ
 using Test
-KMeans = @load KMedoids pkg=Clustering
+KMedoids = @load KMedoids pkg=Clustering
 
 X, y = @load_iris
 model = KMedoids(k=3)
 mach = machine(model, X) |> fit!
 
-preds = predict(mach, X)
-@test preds == report(mach).assignments
+yhat = predict(mach, X)
+@test yhat == report(mach).assignments
+
+compare = zip(yhat, y) |> collect;
+compare[1:8] # clusters align with classes
 
 center_dists = transform(mach, fitted_params(mach).medoids')
 
 @test center_dists[1][1] == 0.0
 @test center_dists[2][2] == 0.0
 @test center_dists[3][3] == 0.0
+
+# we can also
 ```
 
 See also
@@ -314,3 +316,4 @@ KMedoids
 
 
 end # module
+