Skip to content

Commit d9a3a16

Browse files
Merge pull request #80 from francois-drielsma/develop
Cluster label adaptation speed-up
2 parents 0ff213f + bea63fb commit d9a3a16

File tree

8 files changed

+482
-326
lines changed

8 files changed

+482
-326
lines changed

bin/larcv_check_valid.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,13 @@ def main(source, source_list, output):
4444
keys_list, unique_counts = [], []
4545
for file_path in tqdm(source):
4646
# Count the number of entries in each tree
47-
f = TFile(file_path)
47+
try:
48+
f = TFile(file_path)
49+
except OSError:
50+
keys_list.append([])
51+
unique_counts.append([])
52+
continue
53+
4854
keys = [key.GetName() for key in f.GetListOfKeys()]
4955
trees = [f.Get(key) for key in keys]
5056
num_entries = [tree.GetEntries() for tree in trees]

spine/data/out/particle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ class TruthParticle(Particle, ParticleBase, TruthBase):
485485
orig_interaction_id: int = -1
486486
orig_parent_id: int = -1
487487
orig_group_id: int = -1
488-
orig_children_id: np.ndarray = -1
488+
orig_children_id: np.ndarray = None
489489
children_counts: np.ndarray = None
490490
reco_length: float = -1.
491491
reco_start_dir: np.ndarray = None

spine/model/full_chain.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
from spine.utils.calib import CalibrationManager
2424
from spine.utils.logger import logger
2525
from spine.utils.ppn import get_particle_points
26-
from spine.utils.ghost import (
27-
compute_rescaled_charge_batch, adapt_labels_batch)
26+
from spine.utils.ghost import compute_rescaled_charge_batch
27+
from spine.utils.cluster.label import ClusterLabelAdapter
2828
from spine.utils.gnn.cluster import (
2929
form_clusters_batch, get_cluster_label_batch)
3030
from spine.utils.gnn.evaluation import primary_assignment_batch
@@ -173,8 +173,7 @@ def __init__(self, chain, uresnet_deghost=None, uresnet=None,
173173
self.uresnet_ppn = UResNetPPN(**uresnet_ppn)
174174

175175
# Initialize the relabeling process (adapt to the semantic predictions)
176-
# TODO: make this a class which holds onto these parameters?
177-
self.adapt_params = adapt_labels if adapt_labels is not None else {}
176+
self.label_adapter = ClusterLabelAdapter(**(adapt_labels or {}))
178177

179178
# Initialize the dense clustering model
180179
self.fragment_shapes = []
@@ -495,9 +494,8 @@ def run_segmentation_ppn(self, data, seg_label=None, clust_label=None):
495494
if seg_label is not None and clust_label is not None:
496495
seg_pred = self.result['seg_pred']
497496
ghost_pred = self.result.get('ghost_pred', None)
498-
clust_label = adapt_labels_batch(
499-
clust_label, seg_label, seg_pred, ghost_pred,
500-
**self.adapt_params)
497+
clust_label = self.label_adapter(
498+
clust_label, seg_label, seg_pred, ghost_pred)
501499

502500
self.result['clust_label_adapt'] = clust_label
503501

spine/utils/cluster/label.py

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
"""Class which adapts clustering labels given upstream semantic predictions."""
2+
3+
import numpy as np
4+
import torch
5+
from torch_cluster import knn
6+
from scipy.spatial.distance import cdist
7+
8+
from spine.data import TensorBatch
9+
10+
from spine.utils.gnn.cluster import form_clusters, break_clusters
11+
from spine.utils.globals import (
12+
COORD_COLS, VALUE_COL, CLUST_COL, SHAPE_COL, SHOWR_SHP, TRACK_SHP,
13+
MICHL_SHP, DELTA_SHP, GHOST_SHP)
14+
15+
__all__ = ['ClusterLabelAdapter']
16+
17+
18+
class ClusterLabelAdapter:
    """Adapts the cluster labels to account for the predicted semantics.

    Points wrongly predicted get the cluster label of the closest touching
    cluster, if there is one. Points that are predicted as ghosts get invalid
    (-1) cluster labels everywhere.

    Instances that have been broken up by the deghosting or semantic
    segmentation process get assigned distinct cluster labels for each
    effective fragment, provided they appear in the `break_classes` list.

    Notes
    -----
    This class supports both Numpy arrays and Torch tensors.

    It uses the GPU implementation from `torch_cluster.knn` to speed up the
    label adaptation computation (instead of cdist).
    """

    def __init__(self, break_eps=1.1, break_metric='chebyshev',
                 break_classes=None):
        """Initialize the adapter class.

        Parameters
        ----------
        break_eps : float, default 1.1
            Distance scale used in the break up procedure
        break_metric : str, default 'chebyshev'
            Distance metric used in the break up procedure
        break_classes : List[int], default
            [SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP]
            Classes to run DBSCAN on to break up
        """
        # Store relevant parameters. Use `None` as the default for the class
        # list to avoid sharing one mutable default across instances.
        self.break_eps = break_eps
        self.break_metric = break_metric
        if break_classes is None:
            break_classes = [SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP]
        self.break_classes = break_classes

        # Attributes used to fetch the correct functions; set per-call from
        # the input data in `__call__`
        self.torch, self.dtype, self.device = None, None, None

    def __call__(self, clust_label, seg_label, seg_pred, ghost_pred=None):
        """Adapts the cluster labels for one entry or a batch of entries.

        Parameters
        ----------
        clust_label : Union[TensorBatch, np.ndarray, torch.Tensor]
            (N, N_l) Cluster label tensor
        seg_label : Union[TensorBatch, np.ndarray, torch.Tensor]
            (M, 5) Segmentation label tensor
        seg_pred : Union[TensorBatch, np.ndarray, torch.Tensor]
            (M/N_deghost) Segmentation predictions for each voxel
        ghost_pred : Union[TensorBatch, np.ndarray, torch.Tensor], optional
            (M) Ghost predictions for each voxel

        Returns
        -------
        Union[TensorBatch, np.ndarray, torch.Tensor]
            (N_deghost, N_l) Adapted cluster label tensor
        """
        # Set the data type/device based on the input
        ref_tensor = clust_label
        if isinstance(ref_tensor, TensorBatch):
            ref_tensor = ref_tensor.tensor
        self.torch = isinstance(ref_tensor, torch.Tensor)

        self.dtype = clust_label.dtype
        if self.torch:
            self.device = clust_label.device

        # Dispatch depending on the data type
        if isinstance(clust_label, TensorBatch):
            # If it is batch data, process each entry separately and stitch
            # the results back into one batch.
            # NOTE(review): this branch allocates a torch tensor, so it
            # assumes batched inputs are torch-backed — confirm upstream.
            shape = (seg_pred.shape[0], clust_label.shape[1])
            adapted = torch.empty(
                    shape, dtype=clust_label.dtype, device=clust_label.device)
            for b in range(clust_label.batch_size):
                lower, upper = seg_pred.edges[b], seg_pred.edges[b + 1]
                ghost_pred_b = ghost_pred[b] if ghost_pred is not None else None
                adapted[lower:upper] = self._process(
                        clust_label[b], seg_label[b], seg_pred[b], ghost_pred_b)

            return TensorBatch(adapted, seg_pred.counts)

        # Otherwise, call the main process function directly
        return self._process(clust_label, seg_label, seg_pred, ghost_pred)

    def _process(self, clust_label, seg_label, seg_pred, ghost_pred=None):
        """Adapts the cluster labels for a single entry.

        Parameters
        ----------
        clust_label : Union[np.ndarray, torch.Tensor]
            (N, N_l) Cluster label tensor
        seg_label : Union[np.ndarray, torch.Tensor]
            (M, 5) Segmentation label tensor
        seg_pred : Union[np.ndarray, torch.Tensor]
            (M/N_deghost) Segmentation predictions for each voxel
        ghost_pred : Union[np.ndarray, torch.Tensor], optional
            (M) Ghost predictions for each voxel

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            (N_deghost, N_l) Adapted cluster label tensor
        """
        # If there are no points in this event, nothing to do
        coords = seg_label[:, :VALUE_COL]
        num_cols = clust_label.shape[1]
        if not len(coords):
            return self._ones((0, num_cols))

        # If there are no points after deghosting, nothing to do
        if ghost_pred is not None:
            deghost_index = self._where(ghost_pred == 0)[0]
            if not len(deghost_index):
                return self._ones((0, num_cols))

        # If there are no label points in this event, return dummy labels
        # (coordinates with every label column set to -1)
        if not len(clust_label):
            if ghost_pred is None:
                shape = (len(coords), num_cols)
                dummy_labels = -self._ones(shape)
                dummy_labels[:, :VALUE_COL] = coords

            else:
                shape = (len(deghost_index), num_cols)
                dummy_labels = -self._ones(shape)
                dummy_labels[:, :VALUE_COL] = coords[deghost_index]

            return dummy_labels

        # Build a tensor of predicted segmentation that includes ghost points,
        # so that `seg_pred` indexes line up with `seg_label`
        seg_label = self._to_long(seg_label[:, SHAPE_COL])
        if ghost_pred is not None and (len(ghost_pred) != len(seg_pred)):
            seg_pred_long = self._to_long(GHOST_SHP*self._ones(len(coords)))
            seg_pred_long[deghost_index] = seg_pred
            seg_pred = seg_pred_long

        # Prepare new labels (invalid everywhere but the coordinates)
        new_label = -self._ones((len(coords), num_cols))
        new_label[:, :VALUE_COL] = coords

        # Check if the segment labels and predictions are compatible. If they
        # are compatible, store the cluster labels as is. Track points do not
        # mix with other classes, but EM classes are allowed to.
        compat_mat = self._eye(GHOST_SHP + 1)
        compat_mat[([SHOWR_SHP, SHOWR_SHP, MICHL_SHP, DELTA_SHP],
                    [MICHL_SHP, DELTA_SHP, SHOWR_SHP, SHOWR_SHP])] = True

        true_deghost = seg_label < GHOST_SHP
        seg_mismatch = ~compat_mat[(seg_pred, seg_label)]
        new_label[true_deghost] = clust_label
        new_label[true_deghost & seg_mismatch, VALUE_COL:] = -self._ones(1)

        # For mismatched predictions, attempt to find a touching instance of
        # the same class to assign it sensible cluster labels.
        for s in self._unique(seg_pred):
            # Skip predicted ghosts (they keep their invalid labels)
            if s == GHOST_SHP:
                continue

            # Restrict to points in this class that have incompatible segment
            # labels. Track points do not mix, EM points are allowed to.
            bad_index = self._where(
                    (seg_pred == s) & (~true_deghost | seg_mismatch))[0]
            if len(bad_index) == 0:
                continue

            # Find points in clust_label that have compatible segment labels
            seg_clust_mask = compat_mat[s][self._to_long(clust_label[:, SHAPE_COL])]
            X_true = clust_label[seg_clust_mask]
            if len(X_true) == 0:
                continue

            # Iteratively propagate labels from touching true voxels to the
            # unlabeled predicted points, until no new point gets tagged
            X_pred = coords[bad_index]
            tagged_voxels_count = 1
            while tagged_voxels_count > 0 and len(X_pred) > 0:
                # Find the nearest neighbor to each predicted point
                closest_ids = self._compute_neighbor(X_pred, X_true)

                # Compute Chebyshev distance between predicted and closest true
                distances = self._compute_distances(X_pred, X_true[closest_ids])

                # Label unlabeled voxels that touch a compatible true voxel
                # (1.1 ~ one voxel pitch in Chebyshev distance, i.e. adjacent)
                select_mask = distances < 1.1
                select_index = self._where(select_mask)[0]
                tagged_voxels_count = len(select_index)
                if tagged_voxels_count > 0:
                    # Use the label of the touching true voxel
                    additional_clust_label = self._cat(
                            [X_pred[select_index],
                             X_true[closest_ids[select_index], VALUE_COL:]], 1)
                    new_label[bad_index[select_index]] = additional_clust_label

                    # Update the mask to not include the new assigned points
                    leftover_index = self._where(~select_mask)[0]
                    bad_index = bad_index[leftover_index]

                    # The new true available points are the ones we just
                    # added. The new pred points are those not yet labeled.
                    X_true = additional_clust_label
                    X_pred = X_pred[leftover_index]

        # Remove predicted ghost points from the labels, set the shape
        # column of the label to the segmentation predictions.
        if ghost_pred is not None:
            new_label = new_label[deghost_index]
            new_label[:, SHAPE_COL] = seg_pred[deghost_index]
        else:
            new_label[:, SHAPE_COL] = seg_pred

        # Build a list of cluster indexes to break (DBSCAN runs on CPU, so
        # fetch a numpy view of the labels first)
        new_label_np = new_label
        if torch.is_tensor(new_label):
            new_label_np = new_label.detach().cpu().numpy()

        clusts = []
        labels = new_label_np[:, CLUST_COL]
        shapes = new_label_np[:, SHAPE_COL]
        for break_class in self.break_classes:
            index_s = np.where(shapes == break_class)[0]
            labels_s = labels[index_s]
            for c in np.unique(labels_s):
                # If the cluster ID is invalid, skip
                if c < 0:
                    continue

                # Append cluster
                clusts.append(index_s[labels_s == c])

        # Now if an instance was broken up, assign it different cluster IDs
        new_label[:, CLUST_COL] = break_clusters(
                new_label, clusts, self.break_eps, self.break_metric)

        return new_label

    def _where(self, x):
        """Dispatch `where` to the appropriate backend."""
        if self.torch:
            return torch.where(x)
        else:
            return np.where(x)

    def _cat(self, x, axis):
        """Dispatch concatenation to the appropriate backend."""
        if self.torch:
            return torch.cat(x, axis)
        else:
            return np.concatenate(x, axis)

    def _ones(self, x):
        """Build a ones array/tensor matching the input dtype (and device)."""
        if self.torch:
            return torch.ones(x, dtype=self.dtype, device=self.device)
        else:
            # Match the torch branch: honor the input dtype
            return np.ones(x, dtype=self.dtype)

    def _eye(self, x):
        """Build a boolean identity matrix in the appropriate backend."""
        if self.torch:
            return torch.eye(x, dtype=torch.bool, device=self.device)
        else:
            return np.eye(x, dtype=bool)

    def _unique(self, x):
        """Return the unique values of the input as 64-bit integers."""
        if self.torch:
            return torch.unique(x).long()
        else:
            return np.unique(x).astype(np.int64)

    def _to_long(self, x):
        """Cast the input to 64-bit integers (usable as indexes)."""
        if self.torch:
            return x.long()
        else:
            # Fixed: bare `int64` was an undefined name in the numpy branch
            return x.astype(np.int64)

    def _compute_neighbor(self, x, y):
        """For each point in `x`, return the index of its closest point in `y`."""
        if self.torch:
            # knn(ref, query, k=1)[1] gives, for each query point, the index
            # of its nearest reference point
            return knn(y[:, COORD_COLS], x[:, COORD_COLS], 1)[1]
        else:
            return cdist(x[:, COORD_COLS], y[:, COORD_COLS]).argmin(axis=1)

    def _compute_distances(self, x, y):
        """Row-wise Chebyshev (max-abs) distance between `x` and `y`."""
        if self.torch:
            return torch.amax(torch.abs(y[:, COORD_COLS] - x[:, COORD_COLS]), dim=1)
        else:
            return np.amax(np.abs(x[:, COORD_COLS] - y[:, COORD_COLS]), axis=1)

0 commit comments

Comments
 (0)