Skip to content

Commit a9e19bf

Browse files
authored
Merge pull request #9 from jbrea/master
add Hierarchical Clustering & some docstring fixes
2 parents 6f60a35 + 9f08382 commit a9e19bf

File tree

2 files changed

+142
-3
lines changed

2 files changed

+142
-3
lines changed

src/MLJClusteringInterface.jl

Lines changed: 124 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using Distances
1616

1717
# ===================================================================
1818
## EXPORTS
19-
export KMeans, KMedoids, DBSCAN
19+
export KMeans, KMedoids, DBSCAN, HierarchicalClustering
2020

2121
# ===================================================================
2222
## CONSTANTS
@@ -31,6 +31,7 @@ const PKG = "MLJClusteringInterface"
3131
@mlj_model mutable struct KMeans <: MMI.Unsupervised
3232
k::Int = 3::(_ ≥ 2)
3333
metric::SemiMetric = SqEuclidean()
34+
init = :kmpp
3435
end
3536

3637
function MMI.fit(model::KMeans, verbosity::Int, X)
@@ -169,10 +170,51 @@ end
169170
MMI.reporting_operations(::Type{<:DBSCAN}) = (:predict,)
170171

171172

173+
# # HierarchicalClustering
174+
@mlj_model mutable struct HierarchicalClustering <: MMI.Static
175+
linkage::Symbol = :single :: (_ ∈ (:single, :average, :complete, :ward, :ward_presquared))
176+
metric::SemiMetric = SqEuclidean()
177+
branchorder::Symbol = :r :: (_ ∈ (:r, :barjoseph, :optimal))
178+
h::Union{Nothing,Float64} = nothing
179+
k::Int = 3
180+
end
181+
"""
182+
struct DendrogramCutter{T}
183+
dendrogram::T
184+
end
185+
186+
Callable object to cut a dendrogram.
187+
"""
188+
struct DendrogramCutter{T}
189+
dendrogram::T
190+
end
191+
"""
192+
(cutter::DendrogramCutter)(; h = nothing, k = 3)
193+
194+
Cuts the dendrogram at height `h` or, if `height == nothing`, such that `k` clusters are obtained.
195+
"""
196+
function (cutter::DendrogramCutter)(; h = nothing, k = 3)
197+
MMI.categorical(Cl.cutree(cutter.dendrogram, k = k, h = h))
198+
end
199+
function Base.show(io::IO, ::DendrogramCutter)
200+
print(io, "Dendrogram Cutter.")
201+
end
202+
203+
function MMI.predict(model::HierarchicalClustering, ::Nothing, X)
204+
Xarray = MMI.matrix(X)
205+
d = pairwise(model.metric, Xarray, dims = 1) # n x n
206+
dendrogram = Cl.hclust(d, linkage = model.linkage, branchorder = model.branchorder)
207+
cutter = DendrogramCutter(dendrogram)
208+
yhat = cutter(h = model.h, k = model.k)
209+
return yhat, (; cutter, dendrogram)
210+
end
211+
212+
MMI.reporting_operations(::Type{<:HierarchicalClustering}) = (:predict,)
213+
172214
# # METADATA
173215

174216
metadata_pkg.(
175-
(KMeans, KMedoids, DBSCAN),
217+
(KMeans, KMedoids, DBSCAN, HierarchicalClustering),
176218
name="Clustering",
177219
uuid="aaaa29a8-35af-508c-8bc3-b662a17a0fe5",
178220
url="https://github.com/JuliaStats/Clustering.jl",
@@ -205,6 +247,12 @@ metadata_model(
205247
path = "$(PKG).DBSCAN"
206248
)
207249

250+
metadata_model(
251+
HierarchicalClustering,
252+
human_name = "hierarchical clusterer",
253+
input = MMI.Table(Continuous),
254+
path = "$(PKG).HierarchicalClustering"
255+
)
208256

209257
"""
210258
$(MMI.doc_header(KMeans))
@@ -477,4 +525,78 @@ scatter(points, color=colors)
477525
"""
478526
DBSCAN
479527

528+
"""
529+
$(MMI.doc_header(HierarchicalClustering))
530+
531+
[Hierarchical Clustering](https://en.wikipedia.org/wiki/Hierarchical_clustering) is a
532+
clustering algorithm that organizes the data in a dendrogram based on distances between
533+
groups of points and computes cluster assignments by cutting the dendrogram at a given
534+
height. More information is available at the [Clustering.jl
535+
documentation](https://juliastats.org/Clustering.jl/stable/index.html). Use `predict` to
536+
get cluster assignments. The dendrogram and the dendrogram cutter are accessed from the
537+
machine report (see below).
538+
539+
This is a static implementation, i.e., it does not generalize to new data instances, and
540+
there is no training data. For clusterers that do generalize, see [`KMeans`](@ref) or
541+
[`KMedoids`](@ref).
542+
543+
In MLJ or MLJBase, create a machine with
544+
545+
mach = machine(model)
546+
547+
# Hyper-parameters
548+
549+
- `linkage = :single`: linkage method (:single, :average, :complete, :ward, :ward_presquared)
550+
551+
- `metric = SqEuclidean`: metric (see `Distances.jl` for available metrics)
552+
553+
- `branchorder = :r`: branchorder (:r, :barjoseph, :optimal)
554+
555+
- `h = nothing`: height at which the dendrogram is cut
556+
557+
- `k = 3`: number of clusters.
558+
559+
If both `k` and `h` are specified, it is guaranteed that the number of clusters is not less than `k` and their height is not above `h`.
560+
561+
562+
# Operations
563+
564+
- `predict(mach, X)`: return cluster label assignments, as an unordered
565+
`CategoricalVector`. Here `X` is any table of input features (eg, a `DataFrame`) whose
566+
columns are of scitype `Continuous`; check column scitypes with `schema(X)`.
567+
568+
569+
# Report
570+
571+
After calling `predict(mach)`, the fields of `report(mach)` are:
572+
573+
- `dendrogram`: the dendrogram that was computed when calling `predict`.
574+
575+
- `cutter`: a dendrogram cutter that can be called with a height `h` or a number of clusters `k`, to obtain a new assignment of the data points to clusters (see example below).
576+
577+
# Examples
578+
579+
```
580+
using MLJ
581+
582+
X, labels = make_moons(400, noise=0.09, rng=1) # synthetic data with 2 clusters; X
583+
584+
HierarchicalClustering = @load HierarchicalClustering pkg=Clustering
585+
model = HierarchicalClustering(linkage = :complete)
586+
mach = machine(model)
587+
588+
# compute and output cluster assignments for observations in `X`:
589+
yhat = predict(mach, X)
590+
591+
# plot dendrogram:
592+
using StatsPlots
593+
plot(report(mach).dendrogram)
594+
595+
# make new predictions by cutting the dendrogram at another height
596+
report(mach).cutter(h = 2.5)
597+
```
598+
599+
"""
600+
HierarchicalClustering
601+
480602
end # module

test/runtests.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,25 @@ end
9595

9696
end
9797

98+
# # HierarchicalClustering
99+
100+
@testset "HierarchicalClustering" begin
101+
h = Inf; k = 1; linkage = :complete; bo = :optimal;
102+
metric = Distances.Euclidean()
103+
mach = machine(HierarchicalClustering(h = h, k = k, metric = metric,
104+
linkage = linkage, branchorder = bo))
105+
yhat = predict(mach, X)
106+
@test length(union(yhat)) == 1 # uses h = Inf
107+
cutter = report(mach).cutter
108+
@test length(union(cutter(k = 4))) == 4 # uses k = 4
109+
dendro = Clustering.hclust(Distances.pairwise(metric, hcat(X...), dims = 1),
110+
linkage = linkage, branchorder = bo)
111+
@test cutter(k = 2) == Clustering.cutree(dendro, k = 2)
112+
@test report(mach).dendrogram.heights == dendro.heights
113+
end
114+
98115
@testset "MLJ interface" begin
99-
models = [KMeans, KMedoids, DBSCAN]
116+
models = [KMeans, KMedoids, DBSCAN, HierarchicalClustering]
100117
failures, summary = MLJTestIntegration.test(
101118
models,
102119
X;

0 commit comments

Comments
 (0)