Skip to content

Commit 2ed4387

Browse files
committed
dbscan(): merge impls, refactor
* merge the two dbscan() implementations
* deprecate dbscan(dist, radius) in favor of dbscan(dist, radius, metric=nothing)
* dbscan(points, ...) returns DbscanResult instead of Vector{DbscanCluster} (breaking change)
1 parent f16f717 commit 2ed4387

File tree

3 files changed

+165
-221
lines changed

3 files changed

+165
-221
lines changed

src/dbscan.jl

Lines changed: 108 additions & 185 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,11 @@
11
# DBSCAN Clustering
22
#
33

4-
"""
5-
DbscanResult <: ClusteringResult
6-
7-
The output of [`dbscan`](@ref) function (distance matrix-based implementation).
8-
9-
## Fields
10-
- `seeds::Vector{Int}`: indices of cluster starting points, length *K*
11-
- `counts::Vector{Int}`: cluster sizes (number of assigned points), length *K*
12-
- `assignments::Vector{Int}`: vector of clusters indices, where each point was assigned to, length *N*
13-
"""
14-
mutable struct DbscanResult <: ClusteringResult
15-
seeds::Vector{Int}
16-
counts::Vector{Int}
17-
assignments::Vector{Int}
18-
end
194

205
"""
216
DbscanCluster
227
23-
DBSCAN cluster returned by [`dbscan`](@ref) function (point coordinates-based
24-
implementation)
8+
DBSCAN cluster, part of [`DbscanResult`](@ref) returned by [`dbscan`](@ref) function.
259
2610
## Fields
2711
- `size::Int`: number of points in a cluster (core + boundary)
@@ -35,131 +19,67 @@ struct DbscanCluster
3519
boundary_indices::Vector{Int}
3620
end
3721

38-
## main algorithm
39-
4022
"""
41-
dbscan(D::AbstractMatrix, eps::Real, minpts::Int) -> DbscanResult
23+
DbscanResult <: ClusteringResult
4224
43-
Perform DBSCAN algorithm using the distance matrix `D`.
25+
The output of [`dbscan`](@ref) function.
4426
45-
# Arguments
46-
The following options control which points would be considered
47-
*density reachable*:
48-
- `eps::Real`: the radius of a point neighborhood
49-
- `minpts::Int`: the minimum number of neighboring points (including itself)
50-
to qualify a point as a density point.
27+
## Fields
28+
- `clusters::Vector{DbscanCluster}`: clusters, length *K*
29+
- `seeds::Vector{Int}`: indices of the first points of each cluster's *core*, length *K*
30+
- `counts::Vector{Int}`: cluster sizes (number of assigned points), length *K*
31+
- `assignments::Vector{Int}`: vector of cluster indices that each point was assigned to (`0` for points not assigned to any cluster), length *N*
5132
"""
52-
function dbscan(D::AbstractMatrix{T}, eps::Real, minpts::Int) where T<:Real
53-
# check arguments
54-
n = size(D, 1)
55-
size(D, 2) == n || throw(ArgumentError("D must be a square matrix ($(size(D)) given)."))
56-
n >= 2 || throw(ArgumentError("At least two data points are required ($n given)."))
57-
eps >= 0 || throw(ArgumentError("eps must be a positive value ($eps given)."))
58-
minpts >= 1 || throw(ArgumentError("minpts must be positive integer ($minpts given)."))
59-
60-
# invoke core algorithm
61-
_dbscan(D, convert(T, eps), minpts, 1:n)
62-
end
63-
64-
function _dbscan(D::AbstractMatrix{T}, eps::T, minpts::Int, visitseq::AbstractVector{Int}) where T<:Real
65-
n = size(D, 1)
66-
67-
# prepare
68-
seeds = Int[]
69-
counts = Int[]
70-
assignments = zeros(Int, n)
71-
visited = zeros(Bool, n)
72-
k = 0
73-
74-
# main loop
75-
for p in visitseq
76-
if assignments[p] == 0 && !visited[p]
77-
visited[p] = true
78-
nbs = _dbs_region_query(D, p, eps)
79-
if length(nbs) >= minpts
80-
k += 1
81-
cnt = _dbs_expand_cluster!(D, k, p, nbs, eps, minpts, assignments, visited)
82-
push!(seeds, p)
83-
push!(counts, cnt)
84-
end
85-
end
86-
end
87-
88-
# make output
89-
return DbscanResult(seeds, counts, assignments)
90-
end
91-
92-
## key steps
33+
struct DbscanResult <: ClusteringResult
34+
clusters::Vector{DbscanCluster}
35+
seeds::Vector{Int}
36+
counts::Vector{Int}
37+
assignments::Vector{Int}
9338

94-
function _dbs_region_query(D::AbstractMatrix{T}, p::Int, eps::T) where T<:Real
95-
n = size(D,1)
96-
nbs = Int[]
97-
dists = view(D,:,p)
98-
for i = 1:n
99-
@inbounds if dists[i] <= eps
100-
push!(nbs, i)
39+
function DbscanResult(clusters::AbstractVector{DbscanCluster}, num_points::Integer)
40+
assignments = zeros(Int, num_points)
41+
for (i, clu) in enumerate(clusters)
42+
assignments[clu.core_indices] .= i
43+
assignments[clu.boundary_indices] .= i
10144
end
45+
new(clusters,
46+
[c.core_indices[1] for c in clusters],
47+
[c.size for c in clusters],
48+
assignments)
10249
end
103-
return nbs::Vector{Int}
10450
end
10551

106-
function _dbs_expand_cluster!(D::AbstractMatrix{T}, # distance matrix
107-
k::Int, # the index of current cluster
108-
p::Int, # the index of seeding point
109-
nbs::Vector{Int}, # eps-neighborhood of p
110-
eps::T, # radius of neighborhood
111-
minpts::Int, # minimum number of neighbors of a density point
112-
assignments::Vector{Int}, # assignment vector
113-
visited::Vector{Bool}) where T<:Real # visited indicators
114-
assignments[p] = k
115-
cnt = 1
116-
while !isempty(nbs)
117-
q = popfirst!(nbs)
118-
if !visited[q]
119-
visited[q] = true
120-
qnbs = _dbs_region_query(D, q, eps)
121-
if length(qnbs) >= minpts
122-
for x in qnbs
123-
if assignments[x] == 0
124-
push!(nbs, x)
125-
end
126-
end
127-
end
128-
end
129-
if assignments[q] == 0
130-
assignments[q] = k
131-
cnt += 1
132-
end
133-
end
134-
return cnt
135-
end
13652

13753
"""
13854
dbscan(points::AbstractMatrix, radius::Real;
139-
[metric], [min_neighbors], [min_cluster_size],
140-
[nntree_kwargs...]) -> Vector{DbscanCluster}
55+
[metric=Euclidean()],
56+
[min_neighbors=1], [min_cluster_size=1],
57+
[nntree_kwargs...]) -> DbscanResult
14158
142-
Cluster `points` using the DBSCAN (density-based spatial clustering of
143-
applications with noise) algorithm.
59+
Cluster `points` using the DBSCAN (Density-Based Spatial Clustering of
60+
Applications with Noise) algorithm.
14461
145-
# Arguments
146-
- `points`: the ``d×n`` matrix of points. `points[:, j]` is a
147-
``d``-dimensional coordinates of ``j``-th point
148-
- `radius::Real`: query radius
62+
## Arguments
63+
- `points`: when `metric` is specified, the *d×n* matrix, where
64+
each column is a *d*-dimensional coordinate of a point;
65+
when `metric=nothing`, the *n×n* matrix of pairwise distances between the points
66+
- `radius::Real`: neighborhood radius; points within this distance
67+
are considered neighbors
14968
15069
Optional keyword arguments to control the algorithm:
151-
- `metric` (defaults to `Euclidean`): the points distance metric to use
152-
- `min_neighbors::Integer` (defaults to 1): the minimum number of a *core* point
153-
neighbors
154-
- `min_cluster_size::Integer` (defaults to 1): the minimum number of points in
155-
a valid cluster
156-
- `nntree_kwargs`: parameters (like `leafsize`) for the `KDTree` contructor
70+
- `metric` (defaults to `Euclidean()`): the points distance metric to use,
71+
`nothing` means `points` is the *n×n* precalculated distance matrix
72+
- `min_neighbors::Integer` (defaults to 1): the minimal number of neighbors
73+
required to assign a point to a cluster "core"
74+
- `min_cluster_size::Integer` (defaults to 1): the minimal number of points in
75+
a cluster; cluster candidates with fewer points are discarded
76+
- `nntree_kwargs...`: parameters (like `leafsize`) for the `KDTree` constructor
15777
15878
## Example
15979
```julia
16080
points = randn(3, 10000)
16181
# DBSCAN clustering, clusters with fewer than 20 points will be discarded:
162-
clusters = dbscan(points, 0.05, min_neighbors = 3, min_cluster_size = 20)
82+
clustering = dbscan(points, 0.05, min_neighbors = 3, min_cluster_size = 20)
16383
```
16484
16585
## References:
@@ -176,89 +96,92 @@ function dbscan(points::AbstractMatrix, radius::Real;
17696
metric = Euclidean(),
17797
min_neighbors::Integer = 1, min_cluster_size::Integer = 1,
17898
nntree_kwargs...)
179-
kdtree = KDTree(points, metric; nntree_kwargs...)
180-
return _dbscan(kdtree, points, radius;
181-
min_neighbors=min_neighbors,
182-
min_cluster_size=min_cluster_size)
183-
end
99+
0 <= radius || throw(ArgumentError("radius $radius must be ≥ 0"))
184100

101+
if metric !== nothing
102+
# points are point coordinates
103+
dim, num_points = size(points)
104+
num_points <= dim && throw(ArgumentError("points has $dim rows and $num_points columns. Must be a D x N matric with D < N"))
105+
kdtree = KDTree(points, metric; nntree_kwargs...)
106+
data = (kdtree, points)
107+
else
108+
# points is a distance matrix
109+
num_points = size(points, 1)
110+
size(points, 2) == num_points || throw(ArgumentError("When metric=nothing, points must be a square distance matrix ($(size(points)) given)."))
111+
num_points >= 2 || throw(ArgumentError("At least two data points are required ($num_points given)."))
112+
data = points
113+
end
114+
clusters = _dbscan(data, num_points, radius, min_neighbors, min_cluster_size)
115+
return DbscanResult(clusters, num_points)
116+
end
185117

186118
# An implementation of DBSCAN algorithm that keeps track of both the core and boundary points
187-
function _dbscan(kdtree::KDTree, points::AbstractMatrix, radius::Real;
188-
min_neighbors::Integer = 1, min_cluster_size::Integer = 1)
189-
dim, num_points = size(points)
190-
num_points <= dim && throw(ArgumentError("points has $dim rows and $num_points columns. Must be a D x N matric with D < N"))
191-
0 <= radius || throw(ArgumentError("radius $radius must be ≥ 0"))
119+
function _dbscan(data::Union{AbstractMatrix, Tuple{NNTree, AbstractMatrix}},
120+
num_points::Integer, radius::Real,
121+
min_neighbors::Integer, min_cluster_size::Integer)
192122
1 <= min_neighbors || throw(ArgumentError("min_neighbors $min_neighbors must be ≥ 1"))
193123
1 <= min_cluster_size || throw(ArgumentError("min_cluster_size $min_cluster_size must be ≥ 1"))
194124

195125
clusters = Vector{DbscanCluster}()
196-
visited = falses(num_points)
197-
cluster_selection = falses(num_points)
198-
core_selection = falses(num_points)
126+
visited = fill(false, num_points)
127+
cluster_mask = Vector{Bool}(undef, num_points)
128+
core_mask = similar(cluster_mask)
199129
to_explore = Vector{Int}()
200-
adj_list = Vector{Int}()
201-
for i = 1:num_points
130+
neighbors = Vector{Int}()
131+
@inbounds for i = 1:num_points
202132
visited[i] && continue
133+
@assert isempty(to_explore)
203134
push!(to_explore, i) # start a new cluster
204-
fill!(core_selection, false)
205-
fill!(cluster_selection, false)
135+
fill!(core_mask, false)
136+
fill!(cluster_mask, false)
137+
# breadth-first search (to_explore is a FIFO queue: push! at the back, popfirst! from the front) to find all points in the cluster
206138
while !isempty(to_explore)
207-
current_index = popfirst!(to_explore)
208-
visited[current_index] && continue
209-
visited[current_index] = true
210-
append!(adj_list, inrange(kdtree, points[:, current_index], radius))
211-
cluster_selection[adj_list] .= true
212-
# if a point doesn't have enough neighbors it is not a 'core' point and its neighbors are not added to the to_explore list
213-
if length(adj_list) < min_neighbors
214-
empty!(adj_list)
215-
continue # query returns the query point as well as the neighbors
139+
point = popfirst!(to_explore)
140+
visited[point] && continue
141+
visited[point] = true
142+
_dbscan_region_query!(neighbors, data, point, radius)
143+
cluster_mask[neighbors] .= true # mark as candidates
144+
145+
# if a point has enough neighbors, it is a 'core' point and its neighbors are added to the to_explore list
146+
if length(neighbors) >= min_neighbors
147+
core_mask[point] = true
148+
for j in neighbors
149+
visited[j] || push!(to_explore, j)
150+
end
216151
end
217-
core_selection[current_index] = true
218-
update_exploration_list!(adj_list, to_explore, visited)
152+
empty!(neighbors)
219153
end
220-
if any(core_selection) &&
221-
(cluster_size = sum(cluster_selection)) >= min_cluster_size
222-
accept_cluster!(clusters, core_selection, cluster_selection, cluster_size)
154+
155+
# if the cluster has a core and is large enough, it is accepted
156+
if any(core_mask) && (cluster_size = sum(cluster_mask)) >= min_cluster_size
157+
core = Vector{Int}()
158+
boundary = Vector{Int}()
159+
for (i, (is_cluster, is_core)) in enumerate(zip(cluster_mask, core_mask))
160+
@assert is_core && is_cluster || !is_core # core is always in a cluster
161+
is_cluster && push!(ifelse(is_core, core, boundary), i)
162+
end
163+
@assert !isempty(core)
164+
push!(clusters, DbscanCluster(cluster_size, core, boundary))
223165
end
224166
end
225167
return clusters
226168
end
227169

228-
"""
229-
update_exploration_list!(adj_list, exploration_list, visited) -> adj_list
230-
231-
Update the queue for expanding the cluster.
232-
233-
# Arguments
234-
- `adj_list::Vector{Int}`: indices of the neighboring points to move to queue
235-
- `exploration_list::Vector{Int}`: the indices that will be explored in the future
236-
- `visited::BitVector`: a flag indicating whether a point has been explored already
237-
"""
238-
function update_exploration_list!(adj_list::Vector{T}, exploration_list::Vector{T},
239-
visited::BitVector) where T <: Integer
240-
for j in adj_list
241-
visited[j] && continue
242-
push!(exploration_list, j)
170+
# distance matrix-based
171+
function _dbscan_region_query!(neighbors::AbstractVector, dists::AbstractMatrix,
172+
point::Integer, radius::Real)
173+
empty!(neighbors)
174+
for (i, dist) in enumerate(view(dists, :, point))
175+
(dist <= radius) && push!(neighbors, i)
243176
end
244-
empty!(adj_list)
177+
return neighbors
245178
end
246179

247-
"""
248-
accept_cluster!(clusters, core_selection, cluster_selection) -> clusters
249-
250-
Accept cluster and update the clusters list.
251-
252-
# Arguments
253-
- `clusters::Vector{DbscanCluster}`: a list of the accepted clusters
254-
- `core_selection::Vector{Bool}`: selection of the core points of the cluster
255-
- `cluster_selection::Vector{Bool}`: selection of all the cluster points
256-
"""
257-
function accept_cluster!(clusters::Vector{DbscanCluster}, core_selection::BitVector,
258-
cluster_selection::BitVector, cluster_size::Int)
259-
core_idx = findall(core_selection) # index list of the core members
260-
boundary_selection = cluster_selection .& (~).(core_selection) #TODO change to .~ core_selection
261-
# when dropping 0.5
262-
boundary_idx = findall(boundary_selection) # index list of the boundary members
263-
push!(clusters, DbscanCluster(cluster_size, core_idx, boundary_idx))
180+
# NN-tree based
181+
function _dbscan_region_query!(neighbors::AbstractVector,
182+
nntree_and_points::Tuple{NNTree, AbstractMatrix},
183+
point::Integer, radius::Real)
184+
nntree, points = nntree_and_points
185+
empty!(neighbors)
186+
return append!(neighbors, inrange(nntree, view(points, :, point), radius))
264187
end

src/deprecate.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
k2::Int, a2::AbstractVector{Int}) varinfo(a1, a2)
1313
@deprecate varinfo(R::ClusteringResult, k0::Int, a0::AbstractVector{Int}) varinfo(R, a0)
1414

15+
# deprecated as of 0.14.5
16+
@deprecate(dbscan(D::AbstractMatrix{<:Real}, radius::Real, min_neighbors::Integer),
17+
dbscan(D, radius; metric=nothing, min_neighbors=min_neighbors))
18+
1519
# FIXME remove after deprecation period for merge/labels/height/method
1620
Base.propertynames(hclu::Hclust, private::Bool = false) =
1721
(fieldnames(typeof(hclu))...,

0 commit comments

Comments
 (0)