
Commit 9f5b883

docs: rm target y; tweak examples

1 parent c0ae414

File tree

1 file changed: +54 -45 lines changed


src/MLJClusteringInterface.jl

Lines changed: 54 additions & 45 deletions
@@ -23,9 +23,6 @@ export KMeans, KMedoids
 # Define constants for easy referencing of packages
 const MMI = MLJModelInterface
 const Cl = Clustering
-
-
-
 const PKG = "MLJClusteringInterface"
 
 ####
@@ -143,6 +140,7 @@ metadata_pkg.(
 
 metadata_model(
     KMeans,
+    human_name = "K-means clusterer",
    input = MMI.Table(Continuous),
    output = MMI.Table(Continuous),
    weights = false,
@@ -151,6 +149,7 @@ metadata_model(
 
 metadata_model(
     KMedoids,
+    human_name = "K-medoids clusterer",
    input = MMI.Table(Continuous),
    output = MMI.Table(Continuous),
    weights = false,
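
The two `human_name` entries added above populate an MLJModelInterface model trait via `metadata_model`. As a quick sanity check, a minimal sketch (not part of this commit; assumes MLJ and this interface package are installed) of reading the trait back:

```julia
using MLJ
import MLJModelInterface as MMI

# Load the model type registered by this package.
KMeans = @load KMeans pkg=Clustering

# `human_name` is the trait set by `metadata_model` in this commit.
MMI.human_name(KMeans)  # expected: "K-means clusterer"
```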
@@ -159,38 +158,49 @@ metadata_model(
 """
 $(MMI.doc_header(KMeans))
 
+[K-means](http://en.wikipedia.org/wiki/K_means) is a classical method for
+clustering or vector quantization. It produces a fixed number of clusters,
+each associated with a *center* (also known as a *prototype*), and each data
+point is assigned to a cluster with the nearest center.
+
+From a mathematical standpoint, K-means is a coordinate descent
+algorithm that solves the following optimization problem:
+
+```math
+\\text{minimize} \\ \\sum_{i=1}^n \\| \\mathbf{x}_i - \\boldsymbol{\\mu}_{z_i} \\|^2 \\ \\text{w.r.t.} \\ (\\boldsymbol{\\mu}, z)
+```
+
+Here, ``\\boldsymbol{\\mu}_k`` is the center of the ``k``-th cluster, and
+``z_i`` is the index of the cluster for the ``i``-th point ``\\mathbf{x}_i``.
 
-`KMeans`: The K-Means algorithm finds K centroids corresponding to K clusters in
-the data. The clusters are assumed to be elliptical, should be used with a euclidean distance metric
 
 # Training data
 
 In MLJ or MLJBase, bind an instance `model` to data with
 
-    mach = machine(model, X, y)
-
-Where
+    mach = machine(model, X)
 
-- `X`: is any table of input features (eg, a `DataFrame`) whose columns
-  are of scitype `Continuous`; check the scitype with `schema(X)`
+Here:
 
-- `y`: is the target, which can be any `AbstractVector` whose element
-  scitype is `Count`; check the scitype with `schema(y)`
+- `X` is any table of input features (eg, a `DataFrame`) whose columns
+  are of scitype `Continuous`; check column scitypes with `schema(X)`.
 
 Train the machine using `fit!(mach, rows=...)`.
 
 # Hyper-parameters
 
 - `k=3`: The number of centroids to use in clustering.
-- `metric::SemiMetric=SqEuclidean`: The metric used to calculate the clustering distance
-  matrix
+
+- `metric::SemiMetric=Distances.SqEuclidean`: The metric used to calculate the
+  clustering. Must have type `PreMetric` from Distances.jl.
+
 
 # Operations
 
-- `predict(mach, Xnew)`: return predictions of the target given new
+- `predict(mach, Xnew)`: return cluster label assignments, given new
   features `Xnew` having the same scitype as `X` above.
+
 - `transform(mach, Xnew)`: instead return the mean pairwise distances from
-  new samples to the cluster centers
+  new samples to the cluster centers.
 
 # Fitted parameters
 
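
The objective in the new docstring's `math` block is simple to evaluate directly. An illustrative sketch (not part of the commit; the function name and argument layout are our own) of the quantity being minimized:

```julia
# Evaluate the K-means objective from the docstring: the sum of squared
# distances from each point to the center of its assigned cluster.
# X is a d×n matrix of points, centers a d×k matrix (the docstring's μ),
# and z a length-n vector of cluster indices.
function kmeans_objective(X::AbstractMatrix, centers::AbstractMatrix,
                          z::AbstractVector{<:Integer})
    return sum(sum(abs2, X[:, i] .- centers[:, z[i]]) for i in axes(X, 2))
end
```

Coordinate descent alternates the two moves that each lower this value: reassigning each `z[i]` to the nearest center, then recomputing each center as the mean of its assigned points.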
@@ -203,72 +213,72 @@ The fields of `fitted_params(mach)` are:
 The fields of `report(mach)` are:
 
 - `assignments`: The cluster assignments of each point in the training data.
+
 - `cluster_labels`: The labels assigned to each cluster.
 
 # Examples
 
 ```
 using MLJ
-using Distances
-using Test
 KMeans = @load KMeans pkg=Clustering
 
-X, y = @load_iris
+table = load_iris()
+y, X = unpack(table, ==(:target), rng=123)
 model = KMeans(k=3)
 mach = machine(model, X) |> fit!
 
 yhat = predict(mach, X)
-@test yhat == report(mach).assignments
+@assert yhat == report(mach).assignments
 
 compare = zip(yhat, y) |> collect;
 compare[1:8] # clusters align with classes
 
 center_dists = transform(mach, fitted_params(mach).centers')
 
-@test center_dists[1][1] == 0.0
-@test center_dists[2][2] == 0.0
-@test center_dists[3][3] == 0.0
+@assert center_dists[1][1] == 0.0
+@assert center_dists[2][2] == 0.0
+@assert center_dists[3][3] == 0.0
 ```
 
 See also
 [`KMedoids`](@ref)
 """
 KMeans
+
 """
 $(MMI.doc_header(KMedoids))
 
-`KMedoids`: The K-Medoids algorithm finds K centroids corresponding to K clusters in the
-data. Unlike K-Means, the centroids are found among data points themselves. Clusters
-are not assumed to be elliptical. Should be used with a non-euclidean distance metric
+[K-medoids](http://en.wikipedia.org/wiki/K-medoids) is a clustering algorithm that works by
+finding ``k`` data points (called *medoids*) such that the total distance between each data
+point and the closest *medoid* is minimal.
 
 # Training data
 
 In MLJ or MLJBase, bind an instance `model` to data with
 
-    mach = machine(model, X, y)
-
-Where
+    mach = machine(model, X)
 
-- `X`: is any table of input features (eg, a `DataFrame`) whose columns
-  are of scitype `Continuous`; check the scitype with `schema(X)`
+Here:
 
-- `y`: is the target, which can be any `AbstractVector` whose element
-  scitype is `Count`; check the scitype with `schema(y)`
+- `X` is any table of input features (eg, a `DataFrame`) whose columns
+  are of scitype `Continuous`; check column scitypes with `schema(X)`
 
 Train the machine using `fit!(mach, rows=...)`.
 
 # Hyper-parameters
 
 - `k=3`: The number of centroids to use in clustering.
-- `metric::SemiMetric=SqEuclidean`: The metric used to calculate the clustering distance
-  matrix
+
+- `metric::SemiMetric=Distances.SqEuclidean`: The metric used to calculate the
+  clustering. Must have type `PreMetric` from Distances.jl.
 
 # Operations
 
-- `predict(mach, Xnew)`: return predictions of the target given new
+- `predict(mach, Xnew)`: return cluster label assignments, given new
   features `Xnew` having the same scitype as `X` above.
+
 - `transform(mach, Xnew)`: instead return the mean pairwise distances from
-  new samples to the cluster centers
+  new samples to the cluster centers.
 
 # Fitted parameters
 
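
Both revised docstrings now point the `metric` hyper-parameter at Distances.jl. A hedged usage sketch (reusing the iris setup from the examples; exactly which premetrics are accepted is determined by Clustering.jl, so treat this as illustrative):

```julia
using MLJ
import Distances

KMeans = @load KMeans pkg=Clustering

table = load_iris()
y, X = unpack(table, ==(:target), rng=123)

# Swap the squared-Euclidean default for plain Euclidean distance.
model = KMeans(k=3, metric=Distances.Euclidean())
mach = machine(model, X) |> fit!
yhat = predict(mach, X)
```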
@@ -281,32 +291,31 @@ The fields of `fitted_params(mach)` are:
 The fields of `report(mach)` are:
 
 - `assignments`: The cluster assignments of each point in the training data.
+
 - `cluster_labels`: The labels assigned to each cluster.
 
 # Examples
 
 ```
 using MLJ
-using Test
 KMedoids = @load KMedoids pkg=Clustering
 
-X, y = @load_iris
+table = load_iris()
+y, X = unpack(table, ==(:target), rng=123)
 model = KMedoids(k=3)
 mach = machine(model, X) |> fit!
 
 yhat = predict(mach, X)
-@test yhat == report(mach).assignments
+@assert yhat == report(mach).assignments
 
 compare = zip(yhat, y) |> collect;
 compare[1:8] # clusters align with classes
 
 center_dists = transform(mach, fitted_params(mach).medoids')
 
-@test center_dists[1][1] == 0.0
-@test center_dists[2][2] == 0.0
-@test center_dists[3][3] == 0.0
-
-# we can also
+@assert center_dists[1][1] == 0.0
+@assert center_dists[2][2] == 0.0
+@assert center_dists[3][3] == 0.0
 ```
 
 See also
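
The `compare[1:8]` lines in both examples only eyeball cluster–class alignment. A small base-Julia sketch (not part of the commit) that tabulates the full agreement instead, using `yhat` and `y` from either example above:

```julia
# Count how often each (cluster label, true class) pair occurs.
counts = Dict{Tuple{Any,Any},Int}()
for pair in zip(yhat, y)
    counts[pair] = get(counts, pair, 0) + 1
end
counts  # one dominant class per cluster indicates good alignment
```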
