@@ -117,75 +117,91 @@ defmodule Scholar.Cluster.DBSCAN do
     }
   end
 
-  defnp dbscan_inner(is_core?, indices) do
-    {labels, _} =
-      while {labels = Nx.broadcast(0, {Nx.axis_size(indices, 0)}),
-             {indices, is_core?, label_num = 1, i = 0}},
-            i < Nx.axis_size(indices, 0) do
-        # During cluster expansion, we store points to be visited on
-        # the stack. Each point can be on the stack at most once, so
-        # the number of points is an upper bound on the stack size.
-        stack = Nx.broadcast(-1, {Nx.axis_size(indices, 0)})
-        stack_ptr = 0
-
-        if Nx.take(labels, i) != 0 or not Nx.take(is_core?, i) do
-          {labels, {indices, is_core?, label_num, i + 1}}
-        else
-          {labels, _} =
-            while {labels, {k = i, label_num, indices, is_core?, stack, stack_ptr}},
-                  stack_ptr >= 0 do
-              {labels, stack, stack_ptr} =
-                if Nx.take(labels, k) == 0 do
-                  labels =
-                    Nx.indexed_put(
-                      labels,
-                      Nx.new_axis(Nx.new_axis(k, 0), 0),
-                      Nx.new_axis(label_num, 0)
-                    )
-
-                  {stack, stack_ptr} =
-                    if Nx.take(is_core?, k) do
-                      neighb = Nx.take(indices, k)
-                      mask = neighb * (labels == 0)
-
-                      {stack, stack_ptr, _} =
-                        while {stack, stack_ptr, {mask, j = 0}}, j < Nx.axis_size(mask, 0) do
-                          # Add a point to the stack if it is an unlabelled
-                          # neighbour and is not already on the stack.
-                          if Nx.take(mask, j) != 0 and not Nx.any(stack == j) do
-                            stack =
-                              Nx.indexed_put(
-                                stack,
-                                Nx.new_axis(Nx.new_axis(stack_ptr, 0), 0),
-                                Nx.new_axis(j, 0)
-                              )
-
-                            {stack, stack_ptr + 1, {mask, j + 1}}
-                          else
-                            {stack, stack_ptr, {mask, j + 1}}
-                          end
-                        end
-
-                      {stack, stack_ptr}
-                    else
-                      {stack, stack_ptr}
-                    end
-
-                  {labels, stack, stack_ptr}
-                else
-                  {labels, stack, stack_ptr}
-                end
-
-              k = if stack_ptr > 0, do: Nx.take(stack, stack_ptr - 1), else: -1
-              stack_ptr = stack_ptr - 1
-              {labels, {k, label_num, indices, is_core?, stack, stack_ptr}}
-            end
-
-          {labels, {indices, is_core?, label_num + 1, i + 1}}
-        end
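+  # Note on inputs (a sketch under assumed conventions, not shown in
+  # this diff): `neighbors` is the boolean eps-radius adjacency matrix
+  # (including self) and `is_core?` marks samples with at least
+  # min_samples neighbors, e.g.
+  #
+  #   neighbors = Nx.less_equal(pairwise_dist, eps)
+  #   is_core? = Nx.sum(neighbors, axes: [1]) >= min_samples
+  #
+  # where `pairwise_dist`, `eps` and `min_samples` are hypothetical
+  # names for illustration.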
+  defnp dbscan_inner(is_core?, neighbors) do
+    # We implement the clustering via label propagation.
+    #
+    # Algorithm:
+    #
+    # 1. Initialize each core sample with a unique label (its index);
+    #    non-core samples get a "dummy label", which is greater than
+    #    all the others.
+    # 2. Then, iteratively, update each sample with the minimum label
+    #    among all of its core neighbors.
+    # 3. Connected core samples (and all their neighbors) converge
+    #    to the same minimum label.
+    # 4. Isolated non-core samples are left with the dummy label,
+    #    which we map to -1 at the end.
+    #
+    # This converges in O(D) iterations, where D is the diameter of
+    # the largest cluster.
+    #
+    # This approach is more parallelization-friendly than a sequential
+    # sample-by-sample DFS traversal.
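+    #
+    # As an illustrative sketch (hypothetical data, not part of the
+    # implementation): take 4 samples with is_core? = [1, 1, 1, 0] and
+    # neighbor sets {0, 1}, {0, 1, 2}, {1, 2, 3}, {2, 3}. Labels start
+    # as [0, 1, 2, 4] (4 is the dummy label) and propagate as
+    # [0, 0, 1, 2] -> [0, 0, 0, 1] -> [0, 0, 0, 0]: a single cluster
+    # where non-core sample 3 is absorbed via core sample 2.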
+
+    num_samples = Nx.axis_size(is_core?, 0)
+    dummy_label = num_samples
+
+    labels = Nx.select(is_core?, Nx.iota({num_samples}), dummy_label)
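+    # For the sketch above, labels starts as [0, 1, 2, 4].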
+
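+    # Zero out columns of non-core samples, so that labels propagate
+    # only from core samples.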
+    core_neighbors = Nx.new_axis(is_core?, 0) and neighbors
+
+    # We create a tensor where, for each sample (axis 0), we have the
+    # indices of its core neighbors (axis 1), with the remaining spots
+    # filled with the sample's own index.
+    core_neighbor_indices =
+      Nx.select(
+        core_neighbors,
+        # neighbor index
+        Nx.iota({num_samples, num_samples}, axis: 1),
+        # self index
+        Nx.iota({num_samples, num_samples}, axis: 0)
+      )
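+    # For the sketch above, core_neighbor_indices is:
+    #
+    #   [[0, 1, 0, 0],
+    #    [0, 1, 2, 1],
+    #    [2, 1, 2, 2],
+    #    [3, 3, 2, 3]]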
+
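+    # Iterate until a fixpoint is reached. Each row of
+    # core_neighbor_indices contains the sample's own index, so labels
+    # never increase and the loop is guaranteed to terminate.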
+    {labels, _, _} =
+      while {labels, core_neighbor_indices, finished? = Nx.tensor(false)}, not finished? do
+        core_neighbor_labels = Nx.take(labels, core_neighbor_indices)
+        updated_labels = Nx.reduce_min(core_neighbor_labels, axes: [1])
+        finished? = Nx.all(labels == updated_labels)
+        {updated_labels, core_neighbor_indices, finished?}
       end
 
-    # we need to subtract 1 from labels because we started from label_num=1, which simplifies operations
-    labels - 1
+    # Normalize labels to be consecutive.
+    normalized_labels = normalize_labels(labels)
+
+    # Noise samples don't get any label from core samples, so they keep
+    # the dummy label, which we replace with -1.
+    Nx.select(labels == dummy_label, -1, normalized_labels)
+  end
+
+  # Normalizes non-consecutive labels into consecutive labels.
+  #
+  # For example, [1, 4, 2, 2, 1, 4] -> [0, 2, 1, 1, 0, 2].
+  defnp normalize_labels(labels) do
+    sort_indices = Nx.argsort(labels)
+    unsort_indices = inverse_permutation(sort_indices)
+
+    sorted = Nx.take_along_axis(labels, sort_indices)
+
+    # Create a mask with 1 at every position where a new value appears,
+    # then use cumulative sum, so that each group gets the same value.
+    normalized_sorted =
+      Nx.concatenate([
+        Nx.tensor([0]),
+        Nx.not_equal(sorted[0..-2//1], sorted[1..-1//1])
+      ])
+      |> Nx.cumulative_sum()
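+    # For the example above: sorted = [1, 1, 2, 2, 4, 4], the mask is
+    # [0, 0, 1, 0, 1, 0] and its cumulative sum is [0, 0, 1, 1, 2, 2];
+    # unsorting below yields [0, 2, 1, 1, 0, 2].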
+
+    Nx.take_along_axis(normalized_sorted, unsort_indices)
+  end
+
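+  # Computes the inverse of a permutation: the value at index i is the
+  # position at which i appears in `indices`. For example (sketch),
+  # applied to [2, 0, 1] it returns [1, 2, 0].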
+  defnp inverse_permutation(indices) do
+    shape = Nx.shape(indices)
+    type = Nx.type(indices)
+
+    Nx.indexed_put(
+      Nx.broadcast(Nx.tensor(0, type: type), shape),
+      Nx.new_axis(indices, -1),
+      Nx.iota(shape, type: type)
+    )
   end
 end