@@ -16,7 +16,7 @@ using Distances
16
16
17
17
# ===================================================================
18
18
# # EXPORTS
19
- export KMeans, KMedoids, DBSCAN
19
+ export KMeans, KMedoids, DBSCAN, HierarchicalClustering
20
20
21
21
# ===================================================================
22
22
# # CONSTANTS
@@ -31,6 +31,7 @@ const PKG = "MLJClusteringInterface"
31
31
# Hyper-parameter container for the K-means clusterer.
@mlj_model mutable struct KMeans <: MMI.Unsupervised
    # number of clusters; defaults to 3 and is constrained to `k ≥ 2`
    k::Int = 3::(_ ≥ 2)
    # (semi)metric used to compare points; defaults to squared Euclidean
    metric::SemiMetric = SqEuclidean()
    # seeding specification, left untyped; presumably forwarded to the Clustering.jl
    # backend (`:kmpp` looks like k-means++ seeding) — TODO confirm against `fit`
    init = :kmpp
end
35
36
36
37
function MMI. fit (model:: KMeans , verbosity:: Int , X)
@@ -169,10 +170,51 @@ end
169
170
# Declares `predict` as the only reporting operation for `DBSCAN`.
# NOTE(review): presumably `predict` for DBSCAN returns an `(output, report)` pair —
# its definition is not visible here, confirm.
function MMI.reporting_operations(::Type{<:DBSCAN})
    return (:predict,)
end
170
171
171
172
173
+ # # HierarchicalClustering
174
# Hyper-parameter container for the static hierarchical clusterer (subtype of
# `MMI.Static`).
@mlj_model mutable struct HierarchicalClustering <: MMI.Static
    # linkage method; restricted to the symbols listed in the constraint
    linkage::Symbol = :single :: (_ ∈ (:single, :average, :complete, :ward, :ward_presquared))
    # (semi)metric used to build the pairwise distance matrix
    metric::SemiMetric = SqEuclidean()
    # ordering of the dendrogram branches; restricted to the listed symbols
    branchorder::Symbol = :r :: (_ ∈ (:r, :barjoseph, :optimal))
    # height at which the dendrogram is cut; `nothing` means "no height limit"
    h::Union{Nothing,Float64} = nothing
    # number of clusters; must be positive, mirroring the bound `KMeans` puts on its `k`,
    # so that a meaningless value is caught at construction instead of inside `predict`
    k::Int = 3::(_ ≥ 1)
end
181
+ """
182
+ struct DendrogramCutter{T}
183
+ dendrogram::T
184
+ end
185
+
186
+ Callable object to cut a dendrogram.
187
+ """
188
+ struct DendrogramCutter{T}
189
+ dendrogram:: T
190
+ end
191
+ """
192
+ (cutter::DendrogramCutter)(; h = nothing, k = 3)
193
+
194
+ Cuts the dendrogram at height `h` or, if `height == nothing`, such that `k` clusters are obtained.
195
+ """
196
+ function (cutter:: DendrogramCutter )(; h = nothing , k = 3 )
197
+ MMI. categorical (Cl. cutree (cutter. dendrogram, k = k, h = h))
198
+ end
199
# Compact display: print a fixed label instead of dumping the wrapped dendrogram.
Base.show(io::IO, ::DendrogramCutter) = print(io, " Dendrogram Cutter.")
202
+
203
# Static `predict`: there is no fitted state (second argument is `nothing`), so the
# dendrogram is built from `X` on every call. Returns the cluster assignments together
# with a named tuple `(cutter, dendrogram)`.
function MMI.predict(model::HierarchicalClustering, ::Nothing, X)
    features = MMI.matrix(X)
    # pairwise distances between rows (`dims = 1`), giving an n x n matrix
    distances = pairwise(model.metric, features; dims = 1)
    tree = Cl.hclust(
        distances;
        linkage = model.linkage,
        branchorder = model.branchorder,
    )
    cut = DendrogramCutter(tree)
    labels = cut(; h = model.h, k = model.k)
    report = (cutter = cut, dendrogram = tree)
    return labels, report
end
211
+
212
# `predict` for `HierarchicalClustering` returns an `(assignments, report)` pair, so it
# is listed as the model's only reporting operation.
function MMI.reporting_operations(::Type{<:HierarchicalClustering})
    return (:predict,)
end
213
+
172
214
# # METADATA
173
215
174
216
metadata_pkg .(
175
- (KMeans, KMedoids, DBSCAN),
217
+ (KMeans, KMedoids, DBSCAN, HierarchicalClustering ),
176
218
name= " Clustering" ,
177
219
uuid= " aaaa29a8-35af-508c-8bc3-b662a17a0fe5" ,
178
220
url= " https://github.com/JuliaStats/Clustering.jl" ,
@@ -205,6 +247,12 @@ metadata_model(
205
247
path = " $(PKG) .DBSCAN"
206
248
)
207
249
250
# Register the model traits for `HierarchicalClustering`. `path` must be the exact
# dotted load path `<package>.<type>` with no embedded whitespace, and `human_name`
# carries no leading blank.
metadata_model(
    HierarchicalClustering,
    human_name = "hierarchical clusterer",
    input = MMI.Table(Continuous),
    path = "$(PKG).HierarchicalClustering"
)
208
256
209
257
"""
210
258
$(MMI. doc_header (KMeans))
@@ -477,4 +525,78 @@ scatter(points, color=colors)
477
525
"""
478
526
DBSCAN
479
527
528
+ """
529
+ $(MMI. doc_header (HierarchicalClustering))
530
+
531
+ [Hierarchical Clustering](https://en.wikipedia.org/wiki/Hierarchical_clustering) is a
532
+ clustering algorithm that organizes the data in a dendrogram based on distances between
533
+ groups of points and computes cluster assignments by cutting the dendrogram at a given
534
+ height. More information is available at the [Clustering.jl
535
+ documentation](https://juliastats.org/Clustering.jl/stable/index.html). Use `predict` to
536
+ get cluster assignments. The dendrogram and the dendrogram cutter are accessed from the
537
+ machine report (see below).
538
+
539
+ This is a static implementation, i.e., it does not generalize to new data instances, and
540
+ there is no training data. For clusterers that do generalize, see [`KMeans`](@ref) or
541
+ [`KMedoids`](@ref).
542
+
543
+ In MLJ or MLJBase, create a machine with
544
+
545
+ mach = machine(model)
546
+
547
+ # Hyper-parameters
548
+
549
+ - `linkage = :single`: linkage method (:single, :average, :complete, :ward, :ward_presquared)
550
+
551
+ - `metric = SqEuclidean`: metric (see `Distances.jl` for available metrics)
552
+
553
+ - `branchorder = :r`: branchorder (:r, :barjoseph, :optimal)
554
+
555
+ - `h = nothing`: height at which the dendrogram is cut
556
+
557
+ - `k = 3`: number of clusters.
558
+
559
+ If both `k` and `h` are specified, it is guaranteed that the number of clusters is not less than `k` and their height is not above `h`.
560
+
561
+
562
+ # Operations
563
+
564
+ - `predict(mach, X)`: return cluster label assignments, as an unordered
565
+ `CategoricalVector`. Here `X` is any table of input features (eg, a `DataFrame`) whose
566
+ columns are of scitype `Continuous`; check column scitypes with `schema(X)`.
567
+
568
+
569
+ # Report
570
+
571
+ After calling `predict(mach, X)`, the fields of `report(mach)` are:
572
+
573
+ - `dendrogram`: the dendrogram that was computed when calling `predict`.
574
+
575
+ - `cutter`: a dendrogram cutter that can be called with a height `h` or a number of clusters `k`, to obtain a new assignment of the data points to clusters (see example below).
576
+
577
+ # Examples
578
+
579
+ ```
580
+ using MLJ
581
+
582
+ X, labels = make_moons(400, noise=0.09, rng=1) # synthetic data with 2 clusters
583
+
584
+ HierarchicalClustering = @load HierarchicalClustering pkg=Clustering
585
+ model = HierarchicalClustering(linkage = :complete)
586
+ mach = machine(model)
587
+
588
+ # compute and output cluster assignments for observations in `X`:
589
+ yhat = predict(mach, X)
590
+
591
+ # plot dendrogram:
592
+ using StatsPlots
593
+ plot(report(mach).dendrogram)
594
+
595
+ # make new predictions by cutting the dendrogram at another height
596
+ report(mach).cutter(h = 2.5)
597
+ ```
598
+
599
+ """
600
+ HierarchicalClustering
601
+
480
602
end # module
0 commit comments