Skip to content

Commit 39b1fdb

Browse files
authored
Merge pull request #17 from JuliaAI/dbscan
Add interface for DBSCAN
2 parents 867f7f0 + 85cfa5c commit 39b1fdb

File tree

4 files changed

+252
-42
lines changed

4 files changed

+252
-42
lines changed

Project.toml

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,3 @@ Clustering = "0.14"
1313
Distances = "0.9, 0.10"
1414
MLJModelInterface = "1.4"
1515
julia = "1.6"
16-
17-
[extras]
18-
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
19-
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
20-
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
21-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
22-
23-
[targets]
24-
test = ["LinearAlgebra", "MLJBase", "Random", "Test"]

src/MLJClusteringInterface.jl

Lines changed: 170 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using Distances
1616

1717
# ===================================================================
1818
## EXPORTS
19-
export KMeans, KMedoids
19+
export KMeans, KMedoids, DBSCAN
2020

2121
# ===================================================================
2222
## CONSTANTS
@@ -25,19 +25,14 @@ const MMI = MLJModelInterface
2525
const Cl = Clustering
2626
const PKG = "MLJClusteringInterface"
2727

28-
####
29-
#### KMeans
30-
####
28+
29+
# # K_MEANS
3130

3231
@mlj_model mutable struct KMeans <: MMI.Unsupervised
3332
k::Int = 3::(_ ≥ 2)
3433
metric::SemiMetric = SqEuclidean()
3534
end
3635

37-
####
38-
#### KMeans
39-
####
40-
4136
function MMI.fit(model::KMeans, verbosity::Int, X)
4237
# NOTE: using transpose here to get a LinearAlgebra.Transpose object
4338
# which Kmeans can handle.
@@ -66,6 +61,8 @@ function MMI.transform(model::KMeans, fitresult, X)
6661
return MMI.table(X̃, prototype=X)
6762
end
6863

64+
# # K_MEDOIDS
65+
6966
@mlj_model mutable struct KMedoids <: MMI.Unsupervised
7067
k::Int = 3::(_ ≥ 2)
7168
metric::SemiMetric = SqEuclidean()
@@ -100,9 +97,8 @@ function MMI.transform(model::KMedoids, fitresult, X)
10097
return MMI.table(X̃, prototype=X)
10198
end
10299

103-
####
104-
#### Predict methods
105-
####
100+
101+
# # PREDICT FOR K_MEANS AND K_MEDOIDS
106102

107103
function MMI.predict(model::Union{KMeans,KMedoids}, fitresult, Xnew)
108104
locations, cluster_labels = fitresult
@@ -124,12 +120,59 @@ function MMI.predict(model::Union{KMeans,KMedoids}, fitresult, Xnew)
124120
return cluster_labels[pred]
125121
end
126122

127-
####
128-
#### METADATA
129-
####
123+
# # DBSCAN
124+
125+
@mlj_model mutable struct DBSCAN <: MMI.Static
126+
radius::Real = 1.0::(_ > 0)
127+
leafsize::Int = 20::(_ > 0)
128+
min_neighbors::Int = 1::(_ > 0)
129+
min_cluster_size::Int = 1::(_ > 0)
130+
end
131+
132+
# As DBSCAN is `Static`, there is no `fit` to implement.
133+
134+
function MMI.predict(model::DBSCAN, ::Nothing, X)
135+
136+
Xarray = MMI.matrix(X)'
137+
138+
# output of core algorithm:
139+
clusters = Cl.dbscan(
140+
Xarray, model.radius;
141+
leafsize=model.leafsize,
142+
min_neighbors=model.min_neighbors,
143+
min_cluster_size=model.min_cluster_size,
144+
)
145+
nclusters = length(clusters)
146+
147+
# assignments and point types
148+
npoints = size(Xarray, 2)
149+
assignments = zeros(Int, npoints)
150+
raw_point_types = fill('N', npoints)
151+
for (k, cluster) in enumerate(clusters)
152+
for i in cluster.core_indices
153+
assignments[i] = k
154+
raw_point_types[i] = 'C'
155+
end
156+
for i in cluster.boundary_indices
157+
assignments[i] = k
158+
raw_point_types[i] = 'B'
159+
end
160+
end
161+
point_types = MMI.categorical(raw_point_types)
162+
cluster_labels = unique(assignments)
163+
164+
yhat = MMI.categorical(assignments)
165+
report = (; point_types, nclusters, cluster_labels, clusters)
166+
return yhat, report
167+
end
168+
169+
MMI.reporting_operations(::Type{<:DBSCAN}) = (:predict,)
170+
171+
172+
# # METADATA
130173

131174
metadata_pkg.(
132-
(KMeans, KMedoids),
175+
(KMeans, KMedoids, DBSCAN),
133176
name="Clustering",
134177
uuid="aaaa29a8-35af-508c-8bc3-b662a17a0fe5",
135178
url="https://github.com/JuliaStats/Clustering.jl",
@@ -143,7 +186,6 @@ metadata_model(
143186
human_name = "K-means clusterer",
144187
input = MMI.Table(Continuous),
145188
output = MMI.Table(Continuous),
146-
weights = false,
147189
path = "$(PKG).KMeans"
148190
)
149191

@@ -152,9 +194,18 @@ metadata_model(
152194
human_name = "K-medoids clusterer",
153195
input = MMI.Table(Continuous),
154196
output = MMI.Table(Continuous),
155-
weights = false,
156197
path = "$(PKG).KMedoids"
157198
)
199+
200+
metadata_model(
201+
DBSCAN,
202+
human_name = "DBSCAN clusterer (density-based spatial clustering of "*
203+
"applications with noise)",
204+
input = MMI.Table(Continuous),
205+
path = "$(PKG).DBSCAN"
206+
)
207+
208+
158209
"""
159210
$(MMI.doc_header(KMeans))
160211
@@ -323,6 +374,107 @@ See also
323374
"""
324375
KMedoids
325376

377+
"""
378+
$(MMI.doc_header(DBSCAN))
326379
327-
end # module
380+
[DBSCAN](https://en.wikipedia.org/wiki/DBSCAN) is a clustering algorithm that groups
381+
together points that are closely packed together (points with many nearby neighbors),
382+
marking as outliers points that lie alone in low-density regions (whose nearest neighbors
383+
are too far away). More information is available at the [Clustering.jl
384+
documentation](https://juliastats.org/Clustering.jl/stable/index.html). Use `predict` to
385+
get cluster assignments. Point types - core, boundary or noise - are accessed from the
386+
machine report (see below).
387+
388+
This is a static implementation, i.e., it does not generalize to new data instances, and
389+
there is no training data. For clusterers that do generalize, see [`KMeans`](@ref) or
390+
[`KMedoids`](@ref).
391+
392+
In MLJ or MLJBase, create a machine with
393+
394+
mach = machine(model)
395+
396+
# Hyper-parameters
397+
398+
- `radius=1.0`: query radius.
399+
400+
- `leafsize=20`: number of points binned in each leaf node of the nearest neighbor k-d
401+
tree.
402+
403+
- `min_neighbors=1`: minimum number of neighbors required for a point to be a core point.
328404
405+
- `min_cluster_size=1`: minimum number of points in a valid cluster.
406+
407+
408+
# Operations
409+
410+
- `predict(mach, X)`: return cluster label assignments, as an unordered
411+
`CategoricalVector`. Here `X` is any table of input features (eg, a `DataFrame`) whose
412+
columns are of scitype `Continuous`; check column scitypes with `schema(X)`. Note that
413+
points of type `noise` will always get a label of `0`.
414+
415+
416+
# Report
417+
418+
After calling `predict(mach)`, the fields of `report(mach)` are:
419+
420+
- `point_types`: A `CategoricalVector` with the DBSCAN point type classification, one
421+
element per row of `X`. Elements are either `'C'` (core), `'B'` (boundary), or `'N'`
422+
(noise).
423+
424+
- `nclusters`: The number of clusters (excluding the noise "cluster")
425+
426+
- `cluster_labels`: The unique list of cluster labels
427+
428+
- `clusters`: A vector of `Clustering.DbscanCluster` objects from Clustering.jl, which
429+
have these fields:
430+
431+
- `size`: number of points in a cluster (core + boundary)
432+
433+
- `core_indices`: indices of points in the cluster core
434+
435+
- `boundary_indices`: indices of points on the cluster boundary
436+
437+
438+
# Examples
439+
440+
```
441+
using MLJ
442+
443+
X, labels = make_moons(400, noise=0.09, rng=1) # synthetic data with 2 clusters; X
444+
y = map(labels) do label
445+
label == 0 ? "cookie" : "monster"
446+
end;
447+
y = coerce(y, Multiclass);
448+
449+
DBSCAN = @load DBSCAN pkg=Clustering
450+
model = DBSCAN(radius=0.13, min_cluster_size=5)
451+
mach = machine(model)
452+
453+
# compute and output cluster assignments for observations in `X`:
454+
yhat = predict(mach, X)
455+
456+
# get DBSCAN point types:
457+
report(mach).point_types
458+
report(mach).nclusters
459+
460+
# compare cluster labels with actual labels:
461+
compare = zip(yhat, y) |> collect;
462+
compare[1:10] # clusters align with classes
463+
464+
# visualize clusters, noise in red:
465+
points = zip(X.x1, X.x2) |> collect
466+
colors = map(yhat) do i
467+
i == 0 ? :red :
468+
i == 1 ? :blue :
469+
i == 2 ? :green :
470+
i == 3 ? :yellow :
471+
:black
472+
end
473+
using Plots
474+
scatter(points, color=colors)
475+
```
476+
477+
"""
478+
DBSCAN
479+
480+
end # module

test/Project.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[deps]
2+
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
3+
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
4+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5+
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
6+
MLJTestIntegration = "697918b4-fdc1-4f9e-8ff9-929724cee270"
7+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
8+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9+
10+
[compat]
11+
MLJTestIntegration = "0.2.2"

test/runtests.jl

Lines changed: 71 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,18 @@ import Distances
33
import LinearAlgebra: norm
44

55
using MLJBase
6+
using MLJTestIntegration
67
using MLJClusteringInterface
7-
using Random:seed!
8+
using Random: seed!
89
using Test
910

10-
const Dist = Distances
11-
1211
seed!(132442)
1312
X, y = @load_crabs
1413

15-
####
16-
#### KMEANS
17-
####
1814

19-
@testset "Kmeans" begin
15+
# # K_MEANS
16+
17+
@testset "KMeans" begin
2018
barekm = KMeans()
2119
fitresult, cache, report = fit(barekm, 1, X)
2220
R = matrix(transform(barekm, fitresult, X))
@@ -28,25 +26,83 @@ X, y = @load_crabs
2826
p = predict(barekm, fitresult, X)
2927
@test argmin(R[1, :]) == p[1]
3028
@test argmin(R[10, :]) == p[10]
31-
32-
3329
end
3430

35-
####
36-
#### KMEDOIDS
37-
####
3831

39-
@testset "Kmedoids" begin
32+
# # K_MEDOIDS
33+
34+
@testset "KMedoids" begin
4035
barekm = KMedoids()
4136
fitresult, cache, report = fit(barekm, 1, X)
4237
X_array = matrix(X)
4338
R = matrix(transform(barekm, fitresult, X))
44-
@test R[1, 2] Dist.evaluate(
39+
@test R[1, 2] ≈ Distances.evaluate(
4540
barekm.metric, view(X_array, 1, :), view(fitresult[1], :, 2)
4641
)
47-
@test R[10, 3] Dist.evaluate(
42+
@test R[10, 3] ≈ Distances.evaluate(
4843
barekm.metric, view(X_array, 10, :), view(fitresult[1], :, 3)
4944
)
5045
p = predict(barekm, fitresult, X)
5146
@test all(report.assignments .== p)
5247
end
48+
49+
50+
# # DBSCAN
51+
52+
@testset "DBSCAN" begin
53+
54+
# five spot pattern
55+
X = [
56+
0.0 0.0
57+
1.0 0.0
58+
1.0 1.0
59+
0.0 1.0
60+
0.5 0.5
61+
] |> MLJBase.table
62+
63+
# radius < √2 ==> 5 clusters
64+
dbscan = DBSCAN(radius=0.1)
65+
yhat1, report1 = predict(dbscan, nothing, X)
66+
@test report1.nclusters == 5
67+
@test report1.point_types == fill('B', 5)
68+
@test Set(yhat1) == Set(unique(yhat1))
69+
@test Set(report1.cluster_labels) == Set(unique(yhat1))
70+
71+
# DbscanCluster fields:
72+
@test propertynames(report1.clusters[1]) == (:size, :core_indices, :boundary_indices)
73+
74+
# radius > √2 ==> 1 cluster
75+
dbscan = DBSCAN(radius=2+eps())
76+
yhat, report = predict(dbscan, nothing, X)
77+
@test report.nclusters == 1
78+
@test report.point_types == fill('C', 5)
79+
@test length(unique(yhat)) == 1
80+
81+
# radius < √2 && min_cluster_size = 2 ==> all points are noise
82+
dbscan = DBSCAN(radius=0.1, min_cluster_size=2)
83+
yhat, report = predict(dbscan, nothing, X)
84+
@test report.nclusters == 0
85+
@test report.point_types == fill('N', 5)
86+
@test length(unique(yhat)) == 1
87+
88+
# MLJ integration:
89+
model = DBSCAN(radius=0.1)
90+
mach = machine(model) # no training data
91+
yhat = predict(mach, X)
92+
@test yhat == yhat1
93+
@test MLJBase.report(mach).point_types == report1.point_types
94+
@test MLJBase.report(mach).nclusters == report1.nclusters
95+
96+
end
97+
98+
@testset "MLJ interface" begin
99+
models = [KMeans, KMedoids, DBSCAN]
100+
failures, summary = MLJTestIntegration.test(
101+
models,
102+
X;
103+
mod=@__MODULE__,
104+
verbosity=0,
105+
throw=false, # set to true to debug
106+
)
107+
@test isempty(failures)
108+
end

0 commit comments

Comments
 (0)