Skip to content

Commit 63d9aa7

Browse files
Rewrite DBSCAN from DFS to iterative propagation (#326)
1 parent 14e3914 commit 63d9aa7

File tree

1 file changed

+84
-68
lines changed

1 file changed

+84
-68
lines changed

lib/scholar/cluster/dbscan.ex

Lines changed: 84 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -117,75 +117,91 @@ defmodule Scholar.Cluster.DBSCAN do
117117
}
118118
end
119119

120-
defnp dbscan_inner(is_core?, indices) do
121-
{labels, _} =
122-
while {labels = Nx.broadcast(0, {Nx.axis_size(indices, 0)}),
123-
{indices, is_core?, label_num = 1, i = 0}},
124-
i < Nx.axis_size(indices, 0) do
125-
# During cluster expansion, we store points to be visited on
126-
# the stack. Each point can be at the stack at most once, so
127-
# the number of points is the upper bound on stack size.
128-
stack = Nx.broadcast(-1, {Nx.axis_size(indices, 0)})
129-
stack_ptr = 0
130-
131-
if Nx.take(labels, i) != 0 or not Nx.take(is_core?, i) do
132-
{labels, {indices, is_core?, label_num, i + 1}}
133-
else
134-
{labels, _} =
135-
while {labels, {k = i, label_num, indices, is_core?, stack, stack_ptr}},
136-
stack_ptr >= 0 do
137-
{labels, stack, stack_ptr} =
138-
if Nx.take(labels, k) == 0 do
139-
labels =
140-
Nx.indexed_put(
141-
labels,
142-
Nx.new_axis(Nx.new_axis(k, 0), 0),
143-
Nx.new_axis(label_num, 0)
144-
)
145-
146-
{stack, stack_ptr} =
147-
if Nx.take(is_core?, k) do
148-
neighb = Nx.take(indices, k)
149-
mask = neighb * (labels == 0)
150-
151-
{stack, stack_ptr, _} =
152-
while {stack, stack_ptr, {mask, j = 0}}, j < Nx.axis_size(mask, 0) do
153-
# Add point to the stack if it's an unlabelled neighbour
154-
# and it is already not on the stack.
155-
if Nx.take(mask, j) != 0 and not Nx.any(stack == j) do
156-
stack =
157-
Nx.indexed_put(
158-
stack,
159-
Nx.new_axis(Nx.new_axis(stack_ptr, 0), 0),
160-
Nx.new_axis(j, 0)
161-
)
162-
163-
{stack, stack_ptr + 1, {mask, j + 1}}
164-
else
165-
{stack, stack_ptr, {mask, j + 1}}
166-
end
167-
end
168-
169-
{stack, stack_ptr}
170-
else
171-
{stack, stack_ptr}
172-
end
173-
174-
{labels, stack, stack_ptr}
175-
else
176-
{labels, stack, stack_ptr}
177-
end
178-
179-
k = if stack_ptr > 0, do: Nx.take(stack, stack_ptr - 1), else: -1
180-
stack_ptr = stack_ptr - 1
181-
{labels, {k, label_num, indices, is_core?, stack, stack_ptr}}
182-
end
183-
184-
{labels, {indices, is_core?, label_num + 1, i + 1}}
185-
end
120+
# Assigns a cluster label to every sample via iterative label propagation.
#
# `is_core?`  - boolean vector, one entry per sample, marking core samples.
# `neighbors` - assumed to be a {num_samples, num_samples} boolean adjacency
#               matrix of each sample's neighborhood — TODO confirm against
#               the caller.
#
# Returns a vector of consecutive cluster labels with -1 for noise samples.
defnp dbscan_inner(is_core?, neighbors) do
  # We implement the clustering via label propagation.
  #
  # Algorithm:
  #
  # 1. Initialize each core sample with a unique label (its index),
  #    non-core samples get a "dummy label", which is bigger than
  #    all others.
  # 2. Then iteratively, we update each sample with minimum label
  #    from all of its core neighbors.
  # 3. Connected core samples (and all their neighbors) converge
  #    to the same minimum label.
  # 4. Isolated non-core samples are left with "dummy label", which
  #    we map to -1 at the end.
  #
  # This converges in O(D) iterations where D is the diameter of
  # the largest cluster.
  #
  # This approach is more parallelization-friendly than a sequential
  # sample-by-sample DFS traversal.

  num_samples = Nx.axis_size(is_core?, 0)
  # Strictly greater than any real label, since labels are sample
  # indices in 0..num_samples-1.
  dummy_label = num_samples

  labels = Nx.select(is_core?, Nx.iota({num_samples}), dummy_label)

  # core_neighbors[i][j] is true iff j is a neighbor of i AND j is core.
  core_neighbors = Nx.new_axis(is_core?, 0) and neighbors

  # We create a tensor where for each sample (0-axis) we have indices
  # of its core neighbors (1-axis) and remaining spots filled with
  # its own index.
  core_neighbor_indices =
    Nx.select(
      core_neighbors,
      # neighbor index
      Nx.iota({num_samples, num_samples}, axis: 1),
      # self index
      Nx.iota({num_samples, num_samples}, axis: 0)
    )

  # Fixed-point iteration: pull the minimum label over each sample's
  # core neighbors (self included via the fill above) until no label
  # changes anymore.
  {labels, _, _} =
    while {labels, core_neighbor_indices, finished? = Nx.tensor(false)}, not finished? do
      core_neighbor_labels = Nx.take(labels, core_neighbor_indices)
      updated_labels = Nx.reduce_min(core_neighbor_labels, axes: [1])
      finished? = Nx.all(labels == updated_labels)
      {updated_labels, core_neighbor_indices, finished?}
    end

  # Normalize labels to be consecutive.
  normalized_labels = normalize_labels(labels)

  # Noisy samples don't get any label from core samples, so they keep
  # the dummy label, which we replace with -1.
  Nx.select(labels == dummy_label, -1, normalized_labels)
end
175+
176+
# Remaps arbitrary label values onto consecutive integers starting at 0,
# preserving which samples share a label.
#
# For example [1, 4, 2, 2, 1, 4] -> [0, 2, 1, 1, 0, 2].
defnp normalize_labels(labels) do
  order = Nx.argsort(labels)
  inverse_order = inverse_permutation(order)

  sorted_labels = Nx.take_along_axis(labels, order)

  # Mark every position where the sorted sequence switches to a new
  # value, then take the cumulative sum of the marks so that each run
  # of equal labels collapses to a single consecutive id.
  boundary_marks =
    Nx.concatenate([
      Nx.tensor([0]),
      Nx.not_equal(sorted_labels[0..-2//1], sorted_labels[1..-1//1])
    ])

  boundary_marks
  |> Nx.cumulative_sum()
  |> Nx.take_along_axis(inverse_order)
end
196+
197+
# Computes the inverse of the permutation given by `indices`.
#
# If `indices` sends position i to value indices[i], the result sends
# it back, i.e. result[indices[i]] == i.
defnp inverse_permutation(indices) do
  out_shape = Nx.shape(indices)
  out_type = Nx.type(indices)

  zeros = Nx.broadcast(Nx.tensor(0, type: out_type), out_shape)
  positions = Nx.iota(out_shape, type: out_type)

  # Scatter each position into the slot named by its permuted index.
  Nx.indexed_put(zeros, Nx.new_axis(indices, -1), positions)
end
191207
end

0 commit comments

Comments
 (0)