@@ -122,7 +122,10 @@ defmodule Scholar.Cluster.DBSCAN do
122122 while { labels = Nx . broadcast ( 0 , { Nx . axis_size ( indices , 0 ) } ) ,
123123 { indices , is_core? , label_num = 1 , i = 0 } } ,
124124 i < Nx . axis_size ( indices , 0 ) do
125- stack = Nx . broadcast ( 0 , { Nx . axis_size ( indices , 0 ) ** 2 } )
125+ # During cluster expansion, we store points to be visited on
126+ # the stack. Each point can be at the stack at most once, so
127+ # the number of points is the upper bound on stack size.
128+ stack = Nx . broadcast ( - 1 , { Nx . axis_size ( indices , 0 ) } )
126129 stack_ptr = 0
127130
128131 if Nx . take ( labels , i ) != 0 or not Nx . take ( is_core? , i ) do
@@ -147,7 +150,9 @@ defmodule Scholar.Cluster.DBSCAN do
147150
148151 { stack , stack_ptr , _ } =
149152 while { stack , stack_ptr , { mask , j = 0 } } , j < Nx . axis_size ( mask , 0 ) do
150- if Nx . take ( mask , j ) != 0 do
153+ # Add point to the stack if it's a unlabelled neighbour
154+ # and it is already not on the stack.
155+ if Nx . take ( mask , j ) != 0 and not Nx . any ( stack == j ) do
151156 stack =
152157 Nx . indexed_put (
153158 stack ,
0 commit comments