Skip to content

Commit 0bce6e4

Browse files
Improve DBSCAN performance (#325)
1 parent 48aa36e commit 0bce6e4

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
otp-version: ${{matrix.otp}}
2525
elixir-version: ${{matrix.elixir}}
2626

27-
- uses: actions/cache@v2
27+
- uses: actions/cache@v3
2828
with:
2929
path: |
3030
deps

lib/scholar/cluster/dbscan.ex

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,10 @@ defmodule Scholar.Cluster.DBSCAN do
122122
while {labels = Nx.broadcast(0, {Nx.axis_size(indices, 0)}),
123123
{indices, is_core?, label_num = 1, i = 0}},
124124
i < Nx.axis_size(indices, 0) do
125-
stack = Nx.broadcast(0, {Nx.axis_size(indices, 0) ** 2})
125+
# During cluster expansion, we store points to be visited on
126+
# the stack. Each point can be at the stack at most once, so
127+
# the number of points is the upper bound on stack size.
128+
stack = Nx.broadcast(-1, {Nx.axis_size(indices, 0)})
126129
stack_ptr = 0
127130

128131
if Nx.take(labels, i) != 0 or not Nx.take(is_core?, i) do
@@ -147,7 +150,9 @@ defmodule Scholar.Cluster.DBSCAN do
147150

148151
{stack, stack_ptr, _} =
149152
while {stack, stack_ptr, {mask, j = 0}}, j < Nx.axis_size(mask, 0) do
150-
if Nx.take(mask, j) != 0 do
153+
# Add point to the stack if it's a unlabelled neighbour
154+
# and it is already not on the stack.
155+
if Nx.take(mask, j) != 0 and not Nx.any(stack == j) do
151156
stack =
152157
Nx.indexed_put(
153158
stack,

0 commit comments

Comments
 (0)