Skip to content

Commit 7b9214f

Browse files
Big speed-up in DBSCAN-based cluster breaking using parallel + no GIL
1 parent f6bafd9 commit 7b9214f

File tree

2 files changed

+26
-27
lines changed

2 files changed

+26
-27
lines changed

spine/utils/cluster/label.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def __init__(self, break_eps=1.1, break_metric='chebyshev',
4747
Distance scale used in the break up procedure
4848
break_metric : str, default 'chebyshev'
4949
Distance metric used in the break up procedure
50-
break_classes : List[int], default
50+
break_classes : List[int], default
5151
[SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP]
5252
Classes to run DBSCAN on to break up
5353
"""
@@ -141,12 +141,12 @@ def _process(self, clust_label, seg_label, seg_pred, ghost_pred=None):
141141
if not len(clust_label):
142142
if ghost_pred is None:
143143
shape = (len(coords), num_cols)
144-
dummy_labels = -1 * self._ones(shape)
144+
dummy_labels = -self._ones(shape)
145145
dummy_labels[:, :VALUE_COL] = coords
146146

147147
else:
148148
shape = (len(deghost_index), num_cols)
149-
dummy_labels = -1 * self._ones(shape)
149+
dummy_labels = -self._ones(shape)
150150
dummy_labels[:, :VALUE_COL] = coords[deghost_index]
151151

152152
return dummy_labels
@@ -159,7 +159,7 @@ def _process(self, clust_label, seg_label, seg_pred, ghost_pred=None):
159159
seg_pred = seg_pred_long
160160

161161
# Prepare new labels
162-
new_label = -1. * self._ones((len(coords), num_cols))
162+
new_label = -self._ones((len(coords), num_cols))
163163
new_label[:, :VALUE_COL] = coords
164164

165165
# Check if the segment labels and predictions are compatible. If they are
@@ -172,7 +172,7 @@ def _process(self, clust_label, seg_label, seg_pred, ghost_pred=None):
172172
true_deghost = seg_label < GHOST_SHP
173173
seg_mismatch = ~compat_mat[(seg_pred, seg_label)]
174174
new_label[true_deghost] = clust_label
175-
new_label[seg_mismatch & true_deghost, VALUE_COL:] = -self._ones(1)
175+
new_label[true_deghost & seg_mismatch, VALUE_COL:] = -self._ones(1)
176176

177177
# For mismatched predictions, attempt to find a touching instance of the
178178
# same class to assign it sensible cluster labels.
@@ -182,7 +182,7 @@ def _process(self, clust_label, seg_label, seg_pred, ghost_pred=None):
182182
continue
183183

184184
# Restrict to points in this class that have incompatible segment
185-
# labels. Track points do not mix, EM points are allowed to.
185+
# labels. Track points do not mix, EM points are allowed to.
186186
bad_index = self._where(
187187
(seg_pred == s) & (~true_deghost | seg_mismatch))[0]
188188
if len(bad_index) == 0:
@@ -211,7 +211,7 @@ def _process(self, clust_label, seg_label, seg_pred, ghost_pred=None):
211211
if tagged_voxels_count > 0:
212212
# Use the label of the touching true voxel
213213
additional_clust_label = self._cat(
214-
[X_pred[select_index],
214+
[X_pred[select_index],
215215
X_true[closest_ids[select_index], VALUE_COL:]], 1)
216216
new_label[bad_index[select_index]] = additional_clust_label
217217

spine/utils/gnn/cluster.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -364,39 +364,38 @@ def break_clusters(data, clusts, eps, metric):
364364
if not len(clusts):
365365
return np.copy(data[:, CLUST_COL])
366366

367-
return _break_clusters(data, clusts, eps, metric)
367+
# Break labels
368+
break_labels = _break_clusters(data, clusts, eps, metric)
368369

369-
@nb.njit(cache=True)
370+
# Offset individual broken labels to prevent overlap
371+
labels = np.copy(data[:, CLUST_COL])
372+
offset = np.max(labels) + 1
373+
for k, clust in enumerate(clusts):
374+
# Update IDs, offset
375+
ids = break_labels[clust]
376+
labels[clust] = offset + ids
377+
offset += len(np.unique(ids))
378+
379+
return labels
380+
381+
@nb.njit(cache=True, parallel=True, nogil=True)
370382
def _break_clusters(data: nb.float64[:,:],
371383
clusts: nb.types.List(nb.int64[:]),
372384
eps: nb.float64,
373385
metric: str) -> nb.float64[:]:
374-
# Get the relevant data products
375-
points = data[:, COORD_COLS]
376-
labels = data[:, CLUST_COL]
377-
378386
# Loop over clusters to break, run DBSCAN
379-
break_ids = np.full_like(labels, -1)
380-
ids = np.arange(len(clusts)).astype(np.int64)
381-
for k in range(len(clusts)):
387+
break_labels = np.full(len(data), -1, dtype=data.dtype)
388+
points = data[:, COORD_COLS]
389+
for k in nb.prange(len(clusts)):
382390
# Restrict the points to those in the cluster
383-
clust = clusts[ids[k]]
391+
clust = clusts[k]
384392
points_c = points[clust]
385393

386394
# Run DBSCAN on the cluster, update labels
387395
clust_ids = nbl.dbscan(points_c, eps=eps, metric=metric)
388396

389397
# Store the breaking IDs
390-
break_ids[clust] = clust_ids
391-
392-
# Update the break IDs to ensure no overlap (has to be sequential)
393-
break_labels = np.copy(labels)
394-
offset = np.max(labels) + 1
395-
for k, clust in enumerate(clusts):
396-
# Update IDs, offset
397-
ids = break_ids[clust]
398-
break_labels[clust] = offset + ids
399-
offset += len(np.unique(ids))
398+
break_labels[clust] = clust_ids
400399

401400
return break_labels
402401

0 commit comments

Comments
 (0)