Skip to content

Commit 3d64e14

Browse files
authored
Merge pull request #248 from JuliaStats/ast/refactor_dbscan
Refactor dbscan()
2 parents 04b0705 + d63a009 commit 3d64e14

File tree

3 files changed

+216
-236
lines changed

3 files changed

+216
-236
lines changed

src/dbscan.jl

Lines changed: 136 additions & 200 deletions
Original file line numberDiff line numberDiff line change
@@ -1,251 +1,187 @@
11
# DBSCAN Clustering
22
#
3-
# References:
4-
#
5-
# Martin Ester, Hans-peter Kriegel, Jörg S, and Xiaowei Xu
6-
# A density-based algorithm for discovering clusters
7-
# in large spatial databases with noise. 1996.
8-
#
9-
10-
"""
11-
DbscanResult <: ClusteringResult
123

13-
The output of [`dbscan`](@ref) function (distance matrix-based implementation).
14-
15-
# Fields
16-
- `seeds::Vector{Int}`: indices of cluster starting points
17-
- `assignments::Vector{Int}`: vector of clusters indices, where each point was assigned to
18-
- `counts::Vector{Int}`: cluster sizes (number of assigned points)
19-
"""
20-
mutable struct DbscanResult <: ClusteringResult
21-
seeds::Vector{Int} # starting points of clusters, size (k,)
22-
assignments::Vector{Int} # assignments, size (n,)
23-
counts::Vector{Int} # number of points in each cluster, size (k,)
24-
end
254

265
"""
276
DbscanCluster
287
29-
DBSCAN cluster returned by [`dbscan`](@ref) function (point coordinates-based
30-
implementation)
8+
DBSCAN cluster, part of [`DbscanResult`](@ref) returned by [`dbscan`](@ref) function.
319
32-
# Fields
33-
* `size::Int`: number of points in a cluster (core + boundary)
34-
* `core_indices::Vector{Int}`: indices of points in the cluster *core*
35-
* `boundary_indices::Vector{Int}`: indices of points on the cluster *boundary*
10+
## Fields
11+
- `size::Int`: number of points in a cluster (core + boundary)
12+
- `core_indices::Vector{Int}`: indices of points in the cluster *core*, a.k.a. *seeds*
13+
(have at least `min_neighbors` neighbors in the cluster)
14+
- `boundary_indices::Vector{Int}`: indices of the cluster points outside of *core*
3615
"""
3716
struct DbscanCluster
38-
size::Int # number of points in cluster
39-
core_indices::Vector{Int} # core points indices
40-
boundary_indices::Vector{Int} # boundary points indices
17+
size::Int
18+
core_indices::Vector{Int}
19+
boundary_indices::Vector{Int}
4120
end
4221

43-
## main algorithm
44-
4522
"""
46-
dbscan(D::AbstractMatrix, eps::Real, minpts::Int) -> DbscanResult
23+
DbscanResult <: ClusteringResult
4724
48-
Perform DBSCAN algorithm using the distance matrix `D`.
25+
The output of [`dbscan`](@ref) function.
4926
50-
# Arguments
51-
The following options control which points would be considered
52-
*density reachable*:
53-
- `eps::Real`: the radius of a point neighborhood
54-
- `minpts::Int`: the minimum number of neighboring points (including itself)
55-
to qualify a point as a density point.
27+
## Fields
28+
- `clusters::Vector{DbscanCluster}`: clusters, length *K*
29+
- `seeds::Vector{Int}`: indices of the first points of each cluster's *core*, length *K*
30+
- `counts::Vector{Int}`: cluster sizes (number of assigned points), length *K*
31+
- `assignments::Vector{Int}`: vector of clusters indices, where each point was assigned to, length *N*
5632
"""
57-
function dbscan(D::AbstractMatrix{T}, eps::Real, minpts::Int) where T<:Real
58-
# check arguments
59-
n = size(D, 1)
60-
size(D, 2) == n || throw(ArgumentError("D must be a square matrix ($(size(D)) given)."))
61-
n >= 2 || throw(ArgumentError("At least two data points are required ($n given)."))
62-
eps > 0 || throw(ArgumentError("eps must be a positive value ($eps given)."))
63-
minpts >= 1 || throw(ArgumentError("minpts must be positive integer ($minpts given)."))
64-
65-
# invoke core algorithm
66-
_dbscan(D, convert(T, eps), minpts, 1:n)
67-
end
68-
69-
function _dbscan(D::AbstractMatrix{T}, eps::T, minpts::Int, visitseq::AbstractVector{Int}) where T<:Real
70-
n = size(D, 1)
71-
72-
# prepare
73-
seeds = Int[]
74-
counts = Int[]
75-
assignments = zeros(Int, n)
76-
visited = zeros(Bool, n)
77-
k = 0
78-
79-
# main loop
80-
for p in visitseq
81-
if assignments[p] == 0 && !visited[p]
82-
visited[p] = true
83-
nbs = _dbs_region_query(D, p, eps)
84-
if length(nbs) >= minpts
85-
k += 1
86-
cnt = _dbs_expand_cluster!(D, k, p, nbs, eps, minpts, assignments, visited)
87-
push!(seeds, p)
88-
push!(counts, cnt)
89-
end
33+
struct DbscanResult <: ClusteringResult
34+
clusters::Vector{DbscanCluster}
35+
seeds::Vector{Int}
36+
counts::Vector{Int}
37+
assignments::Vector{Int}
38+
39+
function DbscanResult(clusters::AbstractVector{DbscanCluster}, num_points::Integer)
40+
assignments = zeros(Int, num_points)
41+
for (i, clu) in enumerate(clusters)
42+
assignments[clu.core_indices] .= i
43+
assignments[clu.boundary_indices] .= i
9044
end
45+
new(clusters,
46+
[c.core_indices[1] for c in clusters],
47+
[c.size for c in clusters],
48+
assignments)
9149
end
92-
93-
# make output
94-
return DbscanResult(seeds, assignments, counts)
9550
end
9651

97-
## key steps
98-
99-
function _dbs_region_query(D::AbstractMatrix{T}, p::Int, eps::T) where T<:Real
100-
n = size(D,1)
101-
nbs = Int[]
102-
dists = view(D,:,p)
103-
for i = 1:n
104-
@inbounds if dists[i] < eps
105-
push!(nbs, i)
106-
end
107-
end
108-
return nbs::Vector{Int}
109-
end
110-
111-
function _dbs_expand_cluster!(D::AbstractMatrix{T}, # distance matrix
112-
k::Int, # the index of current cluster
113-
p::Int, # the index of seeding point
114-
nbs::Vector{Int}, # eps-neighborhood of p
115-
eps::T, # radius of neighborhood
116-
minpts::Int, # minimum number of neighbors of a density point
117-
assignments::Vector{Int}, # assignment vector
118-
visited::Vector{Bool}) where T<:Real # visited indicators
119-
assignments[p] = k
120-
cnt = 1
121-
while !isempty(nbs)
122-
q = popfirst!(nbs)
123-
if !visited[q]
124-
visited[q] = true
125-
qnbs = _dbs_region_query(D, q, eps)
126-
if length(qnbs) >= minpts
127-
for x in qnbs
128-
if assignments[x] == 0
129-
push!(nbs, x)
130-
end
131-
end
132-
end
133-
end
134-
if assignments[q] == 0
135-
assignments[q] = k
136-
cnt += 1
137-
end
138-
end
139-
return cnt
140-
end
14152

14253
"""
143-
dbscan(points::AbstractMatrix, radius::Real,
144-
[leafsize], [min_neighbors], [min_cluster_size]) -> Vector{DbscanCluster}
54+
dbscan(points::AbstractMatrix, radius::Real;
55+
[metric=Euclidean()],
56+
[min_neighbors=1], [min_cluster_size=1],
57+
[nntree_kwargs...]) -> DbscanResult
14558
146-
Cluster `points` using the DBSCAN (density-based spatial clustering of
147-
applications with noise) algorithm.
59+
Cluster `points` using the DBSCAN (Density-Based Spatial Clustering of
60+
Applications with Noise) algorithm.
14861
149-
# Arguments
150-
- `points`: the ``d×n`` matrix of points. `points[:, j]` is a
151-
``d``-dimensional coordinates of ``j``-th point
152-
- `radius::Real`: query radius
62+
## Arguments
63+
- `points`: when `metric` is specified, the *d×n* matrix, where
64+
each column is a *d*-dimensional coordinate of a point;
65+
when `metric=nothing`, the *n×n* matrix of pairwise distances between the points
66+
- `radius::Real`: neighborhood radius; points within this distance
67+
are considered neighbors
15368
15469
Optional keyword arguments to control the algorithm:
155-
- `leafsize::Int` (defaults to 20): the number of points binned in each
156-
leaf node in the `KDTree`
157-
- `min_neighbors::Int` (defaults to 1): the minimum number of a *core* point
158-
neighbors
159-
- `min_cluster_size::Int` (defaults to 1): the minimum number of points in
160-
a valid cluster
161-
162-
# Example
163-
``` julia
70+
- `metric` (defaults to `Euclidean()`): the points distance metric to use,
71+
`nothing` means `points` is the *n×n* precalculated distance matrix
72+
- `min_neighbors::Integer` (defaults to 1): the minimal number of neighbors
73+
required to assign a point to a cluster "core"
74+
- `min_cluster_size::Integer` (defaults to 1): the minimal number of points in
75+
a cluster; cluster candidates with fewer points are discarded
76+
- `nntree_kwargs...`: parameters (like `leafsize`) for the `KDTree` constructor
77+
78+
## Example
79+
```julia
16480
points = randn(3, 10000)
16581
# DBSCAN clustering, clusters with less than 20 points will be discarded:
166-
clusters = dbscan(points, 0.05, min_neighbors = 3, min_cluster_size = 20)
82+
clustering = dbscan(points, 0.05, min_neighbors = 3, min_cluster_size = 20)
16783
```
84+
85+
## References:
86+
87+
* Martin Ester, Hans-Peter Kriegel, Jörg Sander, and Xiaowei Xu,
88+
*"A density-based algorithm for discovering clusters
89+
in large spatial databases with noise"*, KDD-1996, pp. 226--231.
90+
* Erich Schubert, Jörg Sander, Martin Ester, Hans Peter Kriegel, and
91+
Xiaowei Xu, *"DBSCAN Revisited, Revisited: Why and How You Should
92+
(Still) Use DBSCAN"*, ACM Transactions on Database Systems,
93+
Vol.42(3)3, pp. 1--21, https://doi.org/10.1145/3068335
16894
"""
169-
function dbscan(points::AbstractMatrix, radius::Real; leafsize::Int = 20, kwargs ...)
170-
kdtree = KDTree(points; leafsize=leafsize)
171-
return _dbscan(kdtree, points, radius; kwargs ...)
172-
end
95+
function dbscan(points::AbstractMatrix, radius::Real;
96+
metric = Euclidean(),
97+
min_neighbors::Integer = 1, min_cluster_size::Integer = 1,
98+
nntree_kwargs...)
99+
0 <= radius || throw(ArgumentError("radius $radius must be ≥ 0"))
173100

101+
if metric !== nothing
102+
# points are point coordinates
103+
dim, num_points = size(points)
104+
num_points <= dim && throw(ArgumentError("points has $dim rows and $num_points columns. Must be a D x N matric with D < N"))
105+
kdtree = KDTree(points, metric; nntree_kwargs...)
106+
data = (kdtree, points)
107+
else
108+
# points is a distance matrix
109+
num_points = size(points, 1)
110+
size(points, 2) == num_points || throw(ArgumentError("When metric=nothing, points must be a square distance matrix ($(size(points)) given)."))
111+
num_points >= 2 || throw(ArgumentError("At least two data points are required ($num_points given)."))
112+
data = points
113+
end
114+
clusters = _dbscan(data, num_points, radius, min_neighbors, min_cluster_size)
115+
return DbscanResult(clusters, num_points)
116+
end
174117

175118
# An implementation of DBSCAN algorithm that keeps track of both the core and boundary points
176-
function _dbscan(kdtree::KDTree, points::AbstractMatrix, radius::Real;
177-
min_neighbors::Int = 1, min_cluster_size::Int = 1)
178-
dim, num_points = size(points)
179-
num_points <= dim && throw(ArgumentError("points has $dim rows and $num_points columns. Must be a D x N matric with D < N"))
180-
0 <= radius || throw(ArgumentError("radius $radius must be ≥ 0"))
119+
function _dbscan(data::Union{AbstractMatrix, Tuple{NNTree, AbstractMatrix}},
120+
num_points::Integer, radius::Real,
121+
min_neighbors::Integer, min_cluster_size::Integer)
181122
1 <= min_neighbors || throw(ArgumentError("min_neighbors $min_neighbors must be ≥ 1"))
182123
1 <= min_cluster_size || throw(ArgumentError("min_cluster_size $min_cluster_size must be ≥ 1"))
183124

184125
clusters = Vector{DbscanCluster}()
185-
visited = falses(num_points)
186-
cluster_selection = falses(num_points)
187-
core_selection = falses(num_points)
126+
visited = fill(false, num_points)
127+
cluster_mask = Vector{Bool}(undef, num_points)
128+
core_mask = similar(cluster_mask)
188129
to_explore = Vector{Int}()
189-
adj_list = Vector{Int}()
190-
for i = 1:num_points
130+
neighbors = Vector{Int}()
131+
@inbounds for i = 1:num_points
191132
visited[i] && continue
133+
@assert isempty(to_explore)
192134
push!(to_explore, i) # start a new cluster
193-
fill!(core_selection, false)
194-
fill!(cluster_selection, false)
135+
fill!(core_mask, false)
136+
fill!(cluster_mask, false)
137+
# depth-first search to find all points in the cluster
195138
while !isempty(to_explore)
196-
current_index = popfirst!(to_explore)
197-
visited[current_index] && continue
198-
visited[current_index] = true
199-
append!(adj_list, inrange(kdtree, points[:, current_index], radius))
200-
cluster_selection[adj_list] .= true
201-
# if a point doesn't have enough neighbors it is not a 'core' point and its neighbors are not added to the to_explore list
202-
if (length(adj_list) - 1) < min_neighbors
203-
empty!(adj_list)
204-
continue # query returns the query point as well as the neighbors
139+
point = popfirst!(to_explore)
140+
visited[point] && continue
141+
visited[point] = true
142+
_dbscan_region_query!(neighbors, data, point, radius)
143+
cluster_mask[neighbors] .= true # mark as candidates
144+
145+
# if a point has enough neighbors, it is a 'core' point and its neighbors are added to the to_explore list
146+
if length(neighbors) >= min_neighbors
147+
core_mask[point] = true
148+
for j in neighbors
149+
visited[j] || push!(to_explore, j)
150+
end
205151
end
206-
core_selection[current_index] = true
207-
update_exploration_list!(adj_list, to_explore, visited)
152+
empty!(neighbors)
153+
end
154+
155+
# if the cluster has core and is large enough, it is accepted
156+
if any(core_mask) && (cluster_size = sum(cluster_mask)) >= min_cluster_size
157+
core = Vector{Int}()
158+
boundary = Vector{Int}()
159+
for (i, (is_cluster, is_core)) in enumerate(zip(cluster_mask, core_mask))
160+
@assert is_core && is_cluster || !is_core # core is always in a cluster
161+
is_cluster && push!(ifelse(is_core, core, boundary), i)
162+
end
163+
@assert !isempty(core)
164+
push!(clusters, DbscanCluster(cluster_size, core, boundary))
208165
end
209-
cluster_size = sum(cluster_selection)
210-
min_cluster_size <= cluster_size && accept_cluster!(clusters, core_selection, cluster_selection, cluster_size)
211166
end
212167
return clusters
213168
end
214169

215-
"""
216-
update_exploration_list!(adj_list, exploration_list, visited) -> adj_list
217-
218-
Update the queue for expanding the cluster.
219-
220-
# Arguments
221-
- `adj_list::Vector{Int}`: indices of the neighboring points to move to queue
222-
- `exploration_list::Vector{Int}`: the indices that will be explored in the future
223-
- `visited::BitVector`: a flag indicating whether a point has been explored already
224-
"""
225-
function update_exploration_list!(adj_list::Vector{T}, exploration_list::Vector{T},
226-
visited::BitVector) where T <: Integer
227-
for j in adj_list
228-
visited[j] && continue
229-
push!(exploration_list, j)
170+
# distance matrix-based
171+
function _dbscan_region_query!(neighbors::AbstractVector, dists::AbstractMatrix,
172+
point::Integer, radius::Real)
173+
empty!(neighbors)
174+
for (i, dist) in enumerate(view(dists, :, point))
175+
(dist <= radius) && push!(neighbors, i)
230176
end
231-
empty!(adj_list)
177+
return neighbors
232178
end
233179

234-
"""
235-
accept_cluster!(clusters, core_selection, cluster_selection) -> clusters
236-
237-
Accept cluster and update the clusters list.
238-
239-
# Arguments
240-
- `clusters::Vector{DbscanCluster}`: a list of the accepted clusters
241-
- `core_selection::Vector{Bool}`: selection of the core points of the cluster
242-
- `cluster_selection::Vector{Bool}`: selection of all the cluster points
243-
"""
244-
function accept_cluster!(clusters::Vector{DbscanCluster}, core_selection::BitVector,
245-
cluster_selection::BitVector, cluster_size::Int)
246-
core_idx = findall(core_selection) # index list of the core members
247-
boundary_selection = cluster_selection .& (~).(core_selection) #TODO change to .~ core_selection
248-
# when dropping 0.5
249-
boundary_idx = findall(boundary_selection) # index list of the boundary members
250-
push!(clusters, DbscanCluster(cluster_size, core_idx, boundary_idx))
180+
# NN-tree based
181+
function _dbscan_region_query!(neighbors::AbstractVector,
182+
nntree_and_points::Tuple{NNTree, AbstractMatrix},
183+
point::Integer, radius::Real)
184+
nntree, points = nntree_and_points
185+
empty!(neighbors)
186+
return append!(neighbors, inrange(nntree, view(points, :, point), radius))
251187
end

src/deprecate.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
k2::Int, a2::AbstractVector{Int}) varinfo(a1, a2)
1313
@deprecate varinfo(R::ClusteringResult, k0::Int, a0::AbstractVector{Int}) varinfo(R, a0)
1414

15+
# deprecated as of 0.14.5
16+
@deprecate(dbscan(D::AbstractMatrix{<:Real}, radius::Real, min_neighbors::Integer),
17+
dbscan(D, radius; metric=nothing, min_neighbors=min_neighbors))
18+
1519
# FIXME remove after deprecation period for merge/labels/height/method
1620
Base.propertynames(hclu::Hclust, private::Bool = false) =
1721
(fieldnames(typeof(hclu))...,

0 commit comments

Comments
 (0)