diff --git a/bin/larcv_check_valid.py b/bin/larcv_check_valid.py index d41f4540..2905bfe5 100644 --- a/bin/larcv_check_valid.py +++ b/bin/larcv_check_valid.py @@ -44,7 +44,13 @@ def main(source, source_list, output): keys_list, unique_counts = [], [] for file_path in tqdm(source): # Count the number of entries in each tree - f = TFile(file_path) + try: + f = TFile(file_path) + except OSError: + keys_list.append([]) + unique_counts.append([]) + continue + keys = [key.GetName() for key in f.GetListOfKeys()] trees = [f.Get(key) for key in keys] num_entries = [tree.GetEntries() for tree in trees] diff --git a/spine/data/out/particle.py b/spine/data/out/particle.py index 3abddb27..4169697d 100644 --- a/spine/data/out/particle.py +++ b/spine/data/out/particle.py @@ -485,7 +485,7 @@ class TruthParticle(Particle, ParticleBase, TruthBase): orig_interaction_id: int = -1 orig_parent_id: int = -1 orig_group_id: int = -1 - orig_children_id: np.ndarray = -1 + orig_children_id: np.ndarray = None children_counts: np.ndarray = None reco_length: float = -1. 
reco_start_dir: np.ndarray = None diff --git a/spine/model/full_chain.py b/spine/model/full_chain.py index 04668502..8790dbfc 100644 --- a/spine/model/full_chain.py +++ b/spine/model/full_chain.py @@ -23,8 +23,8 @@ from spine.utils.calib import CalibrationManager from spine.utils.logger import logger from spine.utils.ppn import get_particle_points -from spine.utils.ghost import ( - compute_rescaled_charge_batch, adapt_labels_batch) +from spine.utils.ghost import compute_rescaled_charge_batch +from spine.utils.cluster.label import ClusterLabelAdapter from spine.utils.gnn.cluster import ( form_clusters_batch, get_cluster_label_batch) from spine.utils.gnn.evaluation import primary_assignment_batch @@ -173,8 +173,7 @@ def __init__(self, chain, uresnet_deghost=None, uresnet=None, self.uresnet_ppn = UResNetPPN(**uresnet_ppn) # Initialize the relabeling process (adapt to the semantic predictions) - # TODO: make this a class which holds onto these parameters? - self.adapt_params = adapt_labels if adapt_labels is not None else {} + self.label_adapter = ClusterLabelAdapter(**(adapt_labels or {})) # Initialize the dense clustering model self.fragment_shapes = [] @@ -495,9 +494,8 @@ def run_segmentation_ppn(self, data, seg_label=None, clust_label=None): if seg_label is not None and clust_label is not None: seg_pred = self.result['seg_pred'] ghost_pred = self.result.get('ghost_pred', None) - clust_label = adapt_labels_batch( - clust_label, seg_label, seg_pred, ghost_pred, - **self.adapt_params) + clust_label = self.label_adapter( + clust_label, seg_label, seg_pred, ghost_pred) self.result['clust_label_adapt'] = clust_label diff --git a/spine/utils/cluster/label.py b/spine/utils/cluster/label.py new file mode 100644 index 00000000..bfded18a --- /dev/null +++ b/spine/utils/cluster/label.py @@ -0,0 +1,306 @@ +"""Class which adapts clustering labels given upstream semantic predictions.""" + +import numpy as np +import torch +from torch_cluster import knn +from 
scipy.spatial.distance import cdist + + from spine.data import TensorBatch + + from spine.utils.gnn.cluster import form_clusters, break_clusters + from spine.utils.globals import ( + COORD_COLS, VALUE_COL, CLUST_COL, SHAPE_COL, SHOWR_SHP, TRACK_SHP, + MICHL_SHP, DELTA_SHP, GHOST_SHP) + + __all__ = ['ClusterLabelAdapter'] + + + class ClusterLabelAdapter: + """Adapts the cluster labels to account for the predicted semantics. + + Points wrongly predicted get the cluster label of the closest touching + cluster, if there is one. Points that are predicted as ghosts get invalid + (-1) cluster labels everywhere. + + Instances that have been broken up by the deghosting or semantic + segmentation process get assigned distinct cluster labels for each + effective fragment, provided they appear in the `break_classes` list. + + Notes + ----- + This class supports both Numpy arrays and Torch tensors. + + It uses the GPU implementation from `torch_cluster.knn` to speed up the + label adaptation computation (instead of cdist). + + """ + + def __init__(self, break_eps=1.1, break_metric='chebyshev', + break_classes=[SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP]): + """Initialize the adapter class. + + Parameters + ---------- + dtype : str, default 'torch' + Type of data to be processed through the label adapter + break_eps : float, default 1.1 + Distance scale used in the break up procedure + break_metric : str, default 'chebyshev' + Distance metric used in the break up procedure + break_classes : List[int], default + [SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP] + Classes to run DBSCAN on to break up + """ + # Store relevant parameters + self.break_eps = break_eps + self.break_metric = break_metric + self.break_classes = break_classes + + # Attributes used to fetch the correct functions + self.torch, self.dtype, self.device = None, None, None + + def __call__(self, clust_label, seg_label, seg_pred, ghost_pred=None): + """Adapts the cluster labels for one entry or a batch of entries. 
+ + Parameters + ---------- + clust_label : Union[TensorBatch, np.ndarray, torch.Tensor] + (N, N_l) Cluster label tensor + seg_label : Union[TensorBatch, np.ndarray, torch.Tensor] + (M, 5) Segmentation label tensor + seg_pred : Union[TensorBatch, np.ndarray, torch.Tensor] + (M/N_deghost) Segmentation predictions for each voxel + ghost_pred : Union[TensorBatch, np.ndarray, torch.Tensor], optional + (M) Ghost predictions for each voxel + + Returns + ------- + Union[TensorBatch, np.ndarray, torch.Tensor] + (N_deghost, N_l) Adapted cluster label tensor + """ + # Set the data type/device based on the input + ref_tensor = clust_label + if isinstance(ref_tensor, TensorBatch): + ref_tensor = ref_tensor.tensor + self.torch = isinstance(ref_tensor, torch.Tensor) + + self.dtype = clust_label.dtype + if self.torch: + self.device = clust_label.device + + # Dispatch depending on the data type + if isinstance(clust_label, TensorBatch): + # If it is batch data, call the main process function of each entry + shape = (seg_pred.shape[0], clust_label.shape[1]) + clust_label_adapted = torch.empty( + shape, dtype=clust_label.dtype, device=clust_label.device) + for b in range(clust_label.batch_size): + lower, upper = seg_pred.edges[b], seg_pred.edges[b+1] + ghost_pred_b = ghost_pred[b] if ghost_pred is not None else None + clust_label_adapted[lower:upper] = self._process( + clust_label[b], seg_label[b], seg_pred[b], ghost_pred_b) + + return TensorBatch(clust_label_adapted, seg_pred.counts) + + else: + # Otherwise, call the main process function directly + return self._process(clust_label, seg_label, seg_pred, ghost_pred) + + def _process(self, clust_label, seg_label, seg_pred, ghost_pred=None): + """Adapts the cluster labels for one entry or a batch of entries. 
+ + Parameters + ---------- + clust_label : Union[np.ndarray, torch.Tensor] + (N, N_l) Cluster label tensor + seg_label : Union[np.ndarray, torch.Tensor] + (M, 5) Segmentation label tensor + seg_pred : Union[np.ndarray, torch.Tensor] + (M/N_deghost) Segmentation predictions for each voxel + ghost_pred : Union[np.ndarray, torch.Tensor], optional + (M) Ghost predictions for each voxel + + Returns + ------- + Union[np.ndarray, torch.Tensor] + (N_deghost, N_l) Adapted cluster label tensor + """ + # If there are no points in this event, nothing to do + coords = seg_label[:, :VALUE_COL] + num_cols = clust_label.shape[1] + if not len(coords): + return self._ones((0, num_cols)) + + # If there are no points after deghosting, nothing to do + if ghost_pred is not None: + deghost_index = self._where(ghost_pred == 0)[0] + if not len(deghost_index): + return self._ones((0, num_cols)) + + # If there are no label points in this event, return dummy labels + if not len(clust_label): + if ghost_pred is None: + shape = (len(coords), num_cols) + dummy_labels = -self._ones(shape) + dummy_labels[:, :VALUE_COL] = coords + + else: + shape = (len(deghost_index), num_cols) + dummy_labels = -self._ones(shape) + dummy_labels[:, :VALUE_COL] = coords[deghost_index] + + return dummy_labels + + # Build a tensor of predicted segmentation that includes ghost points + seg_label = self._to_long(seg_label[:, SHAPE_COL]) + if ghost_pred is not None and (len(ghost_pred) != len(seg_pred)): + seg_pred_long = self._to_long(GHOST_SHP*self._ones(len(coords))) + seg_pred_long[deghost_index] = seg_pred + seg_pred = seg_pred_long + + # Prepare new labels + new_label = -self._ones((len(coords), num_cols)) + new_label[:, :VALUE_COL] = coords + + # Check if the segment labels and predictions are compatible. If they are + # compatible, store the cluster labels as is. Track points do not mix + # with other classes, but EM classes are allowed to. 
+ compat_mat = self._eye(GHOST_SHP + 1) + compat_mat[([SHOWR_SHP, SHOWR_SHP, MICHL_SHP, DELTA_SHP], + [MICHL_SHP, DELTA_SHP, SHOWR_SHP, SHOWR_SHP])] = True + + true_deghost = seg_label < GHOST_SHP + seg_mismatch = ~compat_mat[(seg_pred, seg_label)] + new_label[true_deghost] = clust_label + new_label[true_deghost & seg_mismatch, VALUE_COL:] = -self._ones(1) + + # For mismatched predictions, attempt to find a touching instance of the + # same class to assign it sensible cluster labels. + for s in self._unique(seg_pred): + # Skip predicted ghosts (they keep their invalid labels) + if s == GHOST_SHP: + continue + + # Restrict to points in this class that have incompatible segment + # labels. Track points do not mix, EM points are allowed to. + bad_index = self._where( + (seg_pred == s) & (~true_deghost | seg_mismatch))[0] + if len(bad_index) == 0: + continue + + # Find points in clust_label that have compatible segment labels + seg_clust_mask = compat_mat[s][self._to_long(clust_label[:, SHAPE_COL])] + X_true = clust_label[seg_clust_mask] + if len(X_true) == 0: + continue + + # Loop over the set of unlabeled predicted points + X_pred = coords[bad_index] + tagged_voxels_count = 1 + while tagged_voxels_count > 0 and len(X_pred) > 0: + # Find the nearest neighbor to each predicted point + closest_ids = self._compute_neighbor(X_pred, X_true) + + # Compute Chebyshev distance between predicted and closest true. 
+ distances = self._compute_distances(X_pred, X_true[closest_ids]) + + # Label unlabeled voxels that touch a compatible true voxel + select_mask = distances < 1.1 + select_index = self._where(select_mask)[0] + tagged_voxels_count = len(select_index) + if tagged_voxels_count > 0: + # Use the label of the touching true voxel + additional_clust_label = self._cat( + [X_pred[select_index], + X_true[closest_ids[select_index], VALUE_COL:]], 1) + new_label[bad_index[select_index]] = additional_clust_label + + # Update the mask to not include the new assigned points + leftover_index = self._where(~select_mask)[0] + bad_index = bad_index[leftover_index] + + # The new true available points are the ones we just added. + # The new pred points are those not yet labeled + X_true = additional_clust_label + X_pred = X_pred[leftover_index] + + # Remove predicted ghost points from the labels, set the shape + # column of the label to the segmentation predictions. + if ghost_pred is not None: + new_label = new_label[deghost_index] + new_label[:, SHAPE_COL] = seg_pred[deghost_index] + else: + new_label[:, SHAPE_COL] = seg_pred + + # Build a list of cluster indexes to break + new_label_np = new_label + if torch.is_tensor(new_label): + new_label_np = new_label.detach().cpu().numpy() + + clusts = [] + labels = new_label_np[:, CLUST_COL] + shapes = new_label_np[:, SHAPE_COL] + for break_class in self.break_classes: + index_s = np.where(shapes == break_class)[0] + labels_s = labels[index_s] + for c in np.unique(labels_s): + # If the cluster ID is invalid, skip + if c < 0: + continue + + # Append cluster + clusts.append(index_s[labels_s == c]) + + # Now if an instance was broken up, assign it different cluster IDs + new_label[:, CLUST_COL] = break_clusters( + new_label, clusts, self.break_eps, self.break_metric) + + return new_label + + def _where(self, x): + if self.torch: + return torch.where(x) + else: + return np.where(x) + + def _cat(self, x, axis): + if self.torch: + return torch.cat(x, 
axis) + else: + return np.concatenate(x, axis) + + def _ones(self, x): + if self.torch: + return torch.ones(x, dtype=self.dtype, device=self.device) + else: + return np.ones(x, dtype=self.dtype) + + def _eye(self, x): + if self.torch: + return torch.eye(x, dtype=torch.bool, device=self.device) + else: + return np.eye(x, dtype=bool) + + def _unique(self, x): + if self.torch: + return torch.unique(x).long() + else: + return np.unique(x).astype(np.int64) + + def _to_long(self, x): + if self.torch: + return x.long() + else: + return x.astype(np.int64) + + def _compute_neighbor(self, x, y): + if self.torch: + return knn(y[:, COORD_COLS], x[:, COORD_COLS], 1)[1] + else: + return cdist(x[:, COORD_COLS], y[:, COORD_COLS]).argmin(axis=1) + + def _compute_distances(self, x, y): + if self.torch: + return torch.amax(torch.abs(y[:, COORD_COLS] - x[:, COORD_COLS]), dim=1) + else: + return np.amax(np.abs(x[:, COORD_COLS] - y[:, COORD_COLS]), axis=1) diff --git a/spine/utils/ghost.py b/spine/utils/ghost.py index 23f4f837..12795ac5 100644 --- a/spine/utils/ghost.py +++ b/spine/utils/ghost.py @@ -2,14 +2,10 @@ import numpy as np import torch -from scipy.spatial.distance import cdist from spine.data import TensorBatch -from spine.utils.numba_local import dbscan -from .globals import ( - COORD_COLS, VALUE_COL, CLUST_COL, SHAPE_COL, SHOWR_SHP, TRACK_SHP, - MICHL_SHP, DELTA_SHP, GHOST_SHP) +from .globals import SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP def compute_rescaled_charge_batch(data, collection_only=False, collection_id=2): @@ -38,47 +34,6 @@ def compute_rescaled_charge_batch(data, collection_only=False, collection_id=2): return charges -def adapt_labels_batch(clust_label, seg_label, seg_pred, ghost_pred=None, - break_classes=[SHOWR_SHP,TRACK_SHP,MICHL_SHP,DELTA_SHP], - break_eps=1.1, break_metric='chebyshev'): - """Batched version of :func:`adapt_labels`. 
- - Parameters - ---------- - clust_label : TensorBatch - (N, N_l) Cluster label tensor - seg_label : TensorBatch - (M, 5) Segmentation label tensor - seg_pred : TensorBatch - (M/N_deghost) Segmentation predictions for each voxel - ghost_pred : TensorBatch, optional - (M) Ghost predictions for each voxel - break_classes : List[int], default - [SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP] - Classes to run DBSCAN on to break up - break_eps : float, default 1.1 - Distance scale used in the break up procedure - break_metric : str, default 'chebyshev' - Distance metric used in the break up produce - - Returns - ------- - TensorBatch - (N_deghost, N_l) Adapted cluster label tensor - """ - shape = (seg_pred.shape[0], clust_label.shape[1]) - clust_label_adapted = torch.empty( - shape, dtype=clust_label.dtype, device=clust_label.device) - for b in range(clust_label.batch_size): - lower, upper = seg_pred.edges[b], seg_pred.edges[b+1] - ghost_pred_b = ghost_pred[b] if ghost_pred is not None else None - clust_label_adapted[lower:upper] = adapt_labels( - clust_label[b], seg_label[b], seg_pred[b], - ghost_pred_b, break_classes, break_eps, break_metric) - - return TensorBatch(clust_label_adapted, seg_pred.counts) - - def compute_rescaled_charge(data, collection_only=False, collection_id=2): """Computes rescaled charge after deghosting. @@ -132,211 +87,3 @@ def compute_rescaled_charge(data, collection_only=False, collection_id=2): charges = hit_charges[:, collection_id]/multiplicity[:, collection_id] return charges - - -def adapt_labels(clust_label, seg_label, seg_pred, ghost_pred=None, - break_classes=[SHOWR_SHP,TRACK_SHP,MICHL_SHP,DELTA_SHP], - break_eps=1.1, break_metric='chebyshev'): - """Adapts the cluster labels to account for the predicted semantics. - - Points wrongly predicted get the cluster label of the closest touching - cluster, if there is one. Points that are predicted as ghosts get invalid - (-1) cluster labels everywhere. 
- - Instances that have been broken up by the deghosting process get assigned - distinct cluster labels for each effective fragment. - - Notes - ----- - This function should work on Numpy arrays or Torch tensors. - - Uses GPU version from `torch_cluster.knn` to speed up the label adaptation - computation. - - Parameters - ---------- - clust_label : Union[np.ndarray, torch.Tensor] - (N, N_l) Cluster label tensor - seg_label : List[Union[np.ndarray, torch.Tensor]] - (M, 5) Segmentation label tensor - seg_pred : Union[np.ndarray, torch.Tensor] - (M/N_deghost) Segmentation predictions for each voxel - ghost_pred : Union[np.ndarray, torch.Tensor], optional - (M) Ghost predictions for each voxel - break_classes : List[int], default - [SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP] - Classes to run DBSCAN on to break up - break_eps : float, default 1.1 - Distance scale used in the break up procedure - break_metric : str, default 'chebyshev' - Distance metric used in the break up produce - - Returns - ------- - Union[np.ndarray, torch.Tensor] - (N_deghost, N_l) Adapted cluster label tensor - """ - # Define operations on the basis of the input type - if torch.is_tensor(seg_label): - from torch_cluster import knn # TODO: not PEP8 compliant, refactor - dtype, device = clust_label.dtype, clust_label.device - where, cat, argmax = torch.where, torch.cat, torch.amax - ones = lambda x: torch.ones(x, dtype=dtype, device=device) - eye = lambda x: torch.eye(x, dtype=torch.bool, device=device) - unique = lambda x: torch.unique(x).long() - to_long = lambda x: x.long() - to_bool = lambda x: x.bool() - compute_neighbor = lambda x, y: knn( - y[:, COORD_COLS], x[:, COORD_COLS], 1)[1] - compute_distances = lambda x, y: torch.amax( - torch.abs(y[:, COORD_COLS] - x[:, COORD_COLS]), dim=1) - - else: - where, cat, argmax = np.where, np.concatenate, np.argmax - ones = lambda x: np.ones(x, dtype=clust_label.dtype) - eye = lambda x: np.eye(x, dtype=bool) - unique = lambda x: 
np.unique(x).astype(np.int64) - to_long = lambda x: x.astype(np.int64) - to_bool = lambda x: x.astype(bool) - compute_neighbor = lambda x, y: cdist( - x[:, COORD_COLS], y[:, COORD_COLS]).argmin(axis=1) - compute_distances = lambda x, y: np.amax( - np.abs(x[:, COORD_COLS] - y[:, COORD_COLS]), axis=1) - - # If there are no points in this event, nothing to do - coords = seg_label[:, :VALUE_COL] - num_cols = clust_label.shape[1] - if not len(coords): - return ones((0, num_cols)) - - # If there are no points after deghosting, nothing to do - if ghost_pred is not None: - deghost_index = where(ghost_pred == 0)[0] - if not len(deghost_index): - return ones((0, num_cols)) - - # If there are no label poins in this event, return dummy labels - if not len(clust_label): - if ghost_pred is None: - shape = (len(coords), num_cols) - dummy_labels = -1 * ones(shape) - dummy_labels[:, :VALUE_COL] = coords - - else: - shape = (len(deghost_index), num_cols) - dummy_labels = -1 * ones(shape) - dummy_labels[:, :VALUE_COL] = coords[deghost_index] - - return dummy_labels - - # Build a tensor of predicted segmentation that includes ghost points - seg_label = to_long(seg_label[:, SHAPE_COL]) - if ghost_pred is not None and (len(ghost_pred) != len(seg_pred)): - seg_pred_long = to_long(GHOST_SHP*ones(len(coords))) - seg_pred_long[deghost_index] = seg_pred - seg_pred = seg_pred_long - - # Prepare new labels - new_label = -1. * ones((len(coords), num_cols)) - new_label[:, :VALUE_COL] = coords - - # Check if the segment labels and predictions are compatible. If they are - # compatible, store the cluster labels as is. Track points do not mix - # with other classes, but EM classes are allowed to. 
- compat_mat = eye(GHOST_SHP + 1) - compat_mat[([SHOWR_SHP, SHOWR_SHP, MICHL_SHP, DELTA_SHP], - [MICHL_SHP, DELTA_SHP, SHOWR_SHP, SHOWR_SHP])] = True - - true_deghost = seg_label < GHOST_SHP - seg_mismatch = ~compat_mat[(seg_pred, seg_label)] - new_label[true_deghost] = clust_label - new_label[true_deghost & seg_mismatch, VALUE_COL:] = -1. - - # For mismatched predictions, attempt to find a touching instance of the - # same class to assign it sensible cluster labels. - for s in unique(seg_pred): - # Skip predicted ghosts (they keep their invalid labels) - if s == GHOST_SHP: - continue - - # Restrict to points in this class that have incompatible segment - # labels. Track points do not mix, EM points are allowed to. - bad_index = where((seg_pred == s) & (~true_deghost | seg_mismatch))[0] - if len(bad_index) == 0: - continue - - # Find points in clust_label that have compatible segment labels - seg_clust_mask = compat_mat[s][to_long(clust_label[:, SHAPE_COL])] - X_true = clust_label[seg_clust_mask] - if len(X_true) == 0: - continue - - # Loop over the set of unlabeled predicted points - X_pred = coords[bad_index] - tagged_voxels_count = 1 - while tagged_voxels_count > 0 and len(X_pred) > 0: - # Find the nearest neighbor to each predicted point - closest_ids = compute_neighbor(X_pred, X_true) - - # Compute Chebyshev distance between predicted and closest true. 
- distances = compute_distances(X_pred, X_true[closest_ids]) - - # Label unlabeled voxels that touch a compatible true voxel - select_mask = distances <= 1 - select_index = where(select_mask)[0] - tagged_voxels_count = len(select_index) - if tagged_voxels_count > 0: - # Use the label of the touching true voxel - additional_clust_label = cat( - [X_pred[select_index], - X_true[closest_ids[select_index], VALUE_COL:]], 1) - new_label[bad_index[select_index]] = additional_clust_label - - # Update the mask to not include the new assigned points - leftover_index = where(~select_mask)[0] - bad_index = bad_index[leftover_index] - - # The new true available points are the ones we just added. - # The new pred points are those not yet labeled - X_true = additional_clust_label - X_pred = X_pred[leftover_index] - - # Remove predicted ghost points from the labels, set the shape - # column of the label to the segmentation predictions. - if ghost_pred is not None: - new_label = new_label[deghost_index] - new_label[:, SHAPE_COL] = seg_pred[deghost_index] - else: - new_label[:, SHAPE_COL] = seg_pred - - # Now if an instance was broken up, assign it different cluster IDs - cluster_count = int(clust_label[:, CLUST_COL].max()) + 1 - for break_class in break_classes: - # Restrict to the set of labels associated with this class - break_index = where(new_label[:, SHAPE_COL] == break_class)[0] - restricted_label = new_label[break_index] - restricted_coordinates = restricted_label[:, COORD_COLS] - - # Loop over true cluster instances in the new label tensor, break - for c in unique(restricted_label[:, CLUST_COL]): - # Skip invalid cluster ID - if c < 0: - continue - - # Restrict tensor to a specific cluster, get voxel coordinates - cluster_index = where(restricted_label[:, CLUST_COL] == c)[0] - coordinates = restricted_coordinates[cluster_index] - if torch.is_tensor(coordinates): - coordinates = coordinates.detach().cpu().numpy() - - # Run DBSCAN on the cluster, update labels - break_labels 
= dbscan( - coordinates, eps=break_eps, metric=break_metric) - break_labels += cluster_count - if torch.is_tensor(new_label): - break_labels = torch.tensor(break_labels, - dtype=new_label.dtype, device=new_label.device) - new_label[break_index[cluster_index], CLUST_COL] = break_labels - cluster_count = int(break_labels.max()) + 1 - - return new_label diff --git a/spine/utils/gnn/cluster.py b/spine/utils/gnn/cluster.py index 2d958470..05f3bd54 100644 --- a/spine/utils/gnn/cluster.py +++ b/spine/utils/gnn/cluster.py @@ -340,6 +340,66 @@ def form_clusters(data, min_size=-1, column=CLUST_COL, shapes=None): return clusts, counts +@numbafy(cast_args=['data'], list_args=['clusts'], + keep_torch=True, ref_arg='data') +def break_clusters(data, clusts, eps, metric): + """Runs DBSCAN on each invididual cluster to segment them further if needed. + + Parameters + ---------- + data : np.ndarray + Cluster label data tensor + clusts : List[np.ndarray] + (C) List of cluster indexes + eps : float + DBSCAN clustering distance scale + metric : str + DBSCAN clustering distance metric + + Returns + ------- + np.ndarray + New array of broken cluster labels + """ + if not len(clusts): + return np.copy(data[:, CLUST_COL]) + + # Break labels + break_labels = _break_clusters(data, clusts, eps, metric) + + # Offset individual broken labels to prevent overlap + labels = np.copy(data[:, CLUST_COL]) + offset = np.max(labels) + 1 + for k, clust in enumerate(clusts): + # Update IDs, offset + ids = break_labels[clust] + labels[clust] = offset + ids + offset += len(np.unique(ids)) + + return labels + +@nb.njit(cache=True, parallel=True, nogil=True) +def _break_clusters(data: nb.float64[:,:], + clusts: nb.types.List(nb.int64[:]), + eps: nb.float64, + metric: str) -> nb.float64[:]: + # Loop over clusters to break, run DBSCAN + break_labels = np.full(len(data), -1, dtype=data.dtype) + points = data[:, COORD_COLS] + for k in nb.prange(len(clusts)): + # Restrict the points to those in the cluster + clust 
= clusts[k] + points_c = points[clust] + + # Run DBSCAN on the cluster, update labels + clust_ids = nbl.dbscan(points_c, eps=eps, metric=metric) + + # Store the breaking IDs + break_labels[clust] = clust_ids + + return break_labels + + @numbafy(cast_args=['data'], list_args=['clusts']) def get_cluster_label(data, clusts, column=CLUST_COL): """Returns the majority label of each cluster, specified by the @@ -486,7 +546,7 @@ def _get_cluster_primary_label(data: nb.float64[:,:], if len(primary_mask): # Only use the primary component to label the cluster v, cts = nbl.unique(data[clusts[i][primary_mask], column]) - else: + else: # If there is no primary contribution, use the whole cluster v, cts = nbl.unique(data[clusts[i], column]) labels[i] = v[np.argmax(cts)] @@ -851,7 +911,7 @@ def _get_cluster_features_extended(data: nb.float64[:,:], keep_torch=True, ref_arg='data') def get_cluster_points_label(data, coord_label, clusts, random_order=True): """Gets label points for each cluster. - + Returns start point of primary shower fragment twice if shower, delta or Michel and both end points of tracks if track. @@ -908,7 +968,7 @@ def _get_cluster_points_label(data: nb.float64[:,:], return points -@numbafy(cast_args=['data', 'starts'], list_args=['clusts'], +@numbafy(cast_args=['data', 'starts'], list_args=['clusts'], keep_torch=True, ref_arg='data') def get_cluster_directions(data, starts, clusts, max_dist=-1, optimize=False): """Estimates the direction of each cluster. 
@@ -1224,14 +1284,14 @@ def cluster_dedx_dir(voxels: nb.float64[:,:], vectors_to_axis = voxels_sp - np.outer(voxels_proj, start_dir) spreads = np.sqrt(np.sum(vectors_to_axis**2, axis=1)) spread = np.sum(spreads)/len(index) - + return dE/dx, dE, dx, spread, len(index) @numbafy(cast_args=['data'], list_args=['clusts'], keep_torch=True, ref_arg='data') def get_cluster_start_points(data, clusts): - """Estimates the start point of clusters based on their PCA and the + """Estimates the start point of clusters based on their PCA and the local curvature at each of the PCA extrema. Parameters diff --git a/spine/utils/gnn/evaluation.py b/spine/utils/gnn/evaluation.py index 9856da55..d76db8fd 100644 --- a/spine/utils/gnn/evaluation.py +++ b/spine/utils/gnn/evaluation.py @@ -287,9 +287,9 @@ def primary_assignment_batch(node_pred, group_ids=None): def edge_assignment(edge_index, group_ids): - """Determines which edges are turned on based on the group ID of the - clusters they are connecting. - + """Determines which edges are turned on based on the group ID of the + clusters they are connecting. + Parameters ---------- edge_index: np.ndarray @@ -302,7 +302,7 @@ def edge_assignment(edge_index, group_ids): np.ndarray: (E) Array specifying on/off edges """ - # Set the edge as true if it connects two nodes that belong to the same + # Set the edge as true if it connects two nodes that belong to the same # entry (free; no edges between entries) and the same group mask = (group_ids[edge_index[:, 0]] == group_ids[edge_index[:, 1]]) @@ -332,13 +332,13 @@ def edge_assignment_from_graph(edge_index, true_edge_index, part_ids): # Compare with the reference sparse incidence matrix compare_index = lambda x, y: (x.T == y[..., None]).all(axis=1).any(axis=1) - + return compare_index(edge_index_part, true_edge_index) def edge_assignment_forest(edge_index, edge_pred, group_ids): """Determines which edges must be turned on based on to form a - minimum-spanning tree (MST) for each node group. 
+ minimum-spanning tree (MST) for each node group. For each group, find the most likely spanning tree, label the edges in the tree as 1. For all other edges, apply loss only if in separate groups. If @@ -389,43 +389,6 @@ def edge_assignment_forest(edge_index, edge_pred, group_ids): return edge_assn, edge_valid -@nb.njit(cache=True) -def union_find(edge_index: nb.int64[:,:], - num_nodes: nb.int64) -> ( - nb.int64[:], nb.types.DictType(nb.int64, nb.int64[:])): - """Implementation of the Union-Find algorithm. - - This algorithm forms group based on the connectivity of its consistuents. - If two entities are connected, they belong to the same group. - - Parameters - ---------- - edge_index : np.ndarray - (E, 2) Sparse incidence matrix - num_nodes : int - Number of nodes in the graph, C - - Returns - ------- - np.ndarray - (C) Node group IDs - Dict[int, np.ndarray] - Dictionary which maps group IDs onto constituent cluster IDs - """ - # Find the group_ids by merging groups when they are connected - group_ids = np.arange(n, dtype=np.int64) - for e in edge_index: - if group_ids[e[0]] != group_ids[e[1]]: - group_ids[group_ids == group_ids[e[1]]] = group_ids[e[0]] - - # Build group dictionary - groups = nb.typed.Dict.empty(nb.int64, int_array) - for g in np.unique(group_ids): - groups[g] = np.where(group_ids == g)[0] - - return group_ids, groups - - @nb.njit(cache=True) def node_assignment(edge_index: nb.int64[:,:], edge_pred: nb.int64[:,:], @@ -451,7 +414,7 @@ def node_assignment(edge_index: nb.int64[:,:], # Loop over on edges, reset the group IDs of connected node on_edges = edge_index[np.where(edge_pred[:, 1] > edge_pred[:, 0])[0]] - return union_find(on_edges, num_nodes)[0] + return nbl.union_find(on_edges, num_nodes, return_inverse=True)[0] @nb.njit(cache=True) @@ -460,7 +423,7 @@ def node_assignment_bipartite(edge_index: nb.int64[:,:], primaries: nb.int64[:], num_nodes: nb.int64) -> nb.int64[:]: """Assigns each node to a group represented by a primary node. 
- + This function loops over secondaries and associates it to the primary with that is connected to it with the strongest edge. @@ -616,7 +579,7 @@ def edge_assignment_score(edge_index: nb.int64[:,:], """ # If there is no edge, do not bother if not len(edge_index): - return (np.empty((0,2), dtype=np.int64), + return (np.empty((0,2), dtype=np.int64), np.arange(num_nodes, dtype=np.int64), 0.) # Build an input adjacency matrix to constrain the edge selection to @@ -734,7 +697,7 @@ def node_purity_mask(group_ids: nb.int64[:], ill-defined. Note: It is possible that the single true primary has been broken into - several nodes. In that case, the primary is also ambiguous, skip. + several nodes. In that case, the primary is also ambiguous, skip. TODO: pick the most sensible primary in that case, too restrictive otherwise (complicated, though). @@ -765,8 +728,8 @@ def edge_purity_mask(edge_index: nb.int64[:,:], group_ids: nb.int64[:], primary_ids: nb.int64[:]) -> nb.boolean[:]: """Creates a mask that is `True` only for edges which connect two nodes - that both belong to a common group which has a single clear primary. - + that both belong to a common group which has a single clear primary. + This is useful for shower clustering only, for which there can be no or multiple primaries in the group, making the the edge classification ill-defined (no primary typically indicates a shower which originates diff --git a/spine/utils/numba_local.py b/spine/utils/numba_local.py index fccab88c..8804e585 100644 --- a/spine/utils/numba_local.py +++ b/spine/utils/numba_local.py @@ -394,10 +394,11 @@ def pdist(x: nb.float32[:,:], """ # Initialize the return matrix assert x.shape[1] == 3, "Only supports 3D points for now." - res = np.zeros((len(x), len(x)), dtype=x.dtype) + res = np.empty((len(x), len(x)), dtype=x.dtype) if metric == 'euclidean': for i in range(x.shape[0]): + res[i, i] = 0. 
for j in range(i+1, x.shape[0]): res[i, j] = res[j, i] = np.sqrt( (x[i][0] - x[j][0])**2 + @@ -406,6 +407,7 @@ elif metric == 'cityblock': for i in range(x.shape[0]): + res[i, i] = 0. for j in range(i+1, x.shape[0]): res[i, j] = res[j, i] = ( abs(x[i][0] - x[j][0]) + @@ -414,10 +416,11 @@ elif metric == 'chebyshev': for i in range(x.shape[0]): + res[i, i] = 0. for j in range(i+1, x.shape[0]): res[i, j] = res[j, i] = max( - max(abs(x[i][0] - x[j][0]), - abs(x[i][1] - x[j][1])), + abs(x[i][0] - x[j][0]), + abs(x[i][1] - x[j][1]), abs(x[i][2] - x[j][2])) else: @@ -470,6 +473,73 @@ return res +@nb.njit(cache=True) +def radius_graph(x: nb.float32[:,:], + radius: nb.float32, + metric: str = 'euclidean') -> nb.int64[:,:]: + """Numba implementation of a radius-graph construction. + + This function generates a list of edges in a graph which connects all nodes + within some radius R of each other. + + Parameters + ---------- + x : np.ndarray + (N, 3) array of node coordinates + radius : float + Radius within which to build connections in the graph + metric : str, default 'euclidean' + Distance metric + + Returns + ------- + np.ndarray + (E, 2) array of edges in the radius graph + """ + # Initialize an empty list of edges to add to the graph + assert x.shape[1] == 3, "Only supports 3D points for now."
+ edges = nb.typed.List.empty_list(nb.int64[:]) + + if metric == 'euclidean': + for i in range(x.shape[0]): + for j in range(i+1, x.shape[0]): + dist = np.sqrt( + (x[i][0] - x[j][0])**2 + + (x[i][1] - x[j][1])**2 + + (x[i][2] - x[j][2])**2) + if dist <= radius: + edges.append(np.array([i, j])) + + elif metric == 'cityblock': + for i in range(x.shape[0]): + for j in range(i+1, x.shape[0]): + dist = ( + abs(x[i][0] - x[j][0]) + + abs(x[i][1] - x[j][1]) + + abs(x[i][2] - x[j][2])) + if dist <= radius: + edges.append(np.array([i, j])) + + elif metric == 'chebyshev': + for i in range(x.shape[0]): + for j in range(i+1, x.shape[0]): + dist = max( + abs(x[i][0] - x[j][0]), + abs(x[i][1] - x[j][1]), + abs(x[i][2] - x[j][2])) + if dist <= radius: + edges.append(np.array([i, j])) + + else: + raise ValueError("Distance metric not recognized.") + + edge_index = np.empty((len(edges), 2), dtype=np.int64) + for i, e in enumerate(edges): + edge_index[i] = e + + return edge_index + + @nb.njit(cache=True) def union_find(edge_index: nb.int64[:,:], count: nb.int64, @@ -492,11 +562,17 @@ def union_find(edge_index: nb.int64[:,:], ------- np.ndarray (C) Group assignments for each of the nodes in the graph + Dict[int, np.ndarray] + Dictionary which maps groups to indexes """ labels = np.arange(count) + groups = {i: np.array([i]) for i in labels} for e in edge_index: - if labels[e[0]] != labels[e[1]]: - labels[labels == labels[e[1]]] = labels[e[0]] + li, lj = labels[e[0]], labels[e[1]] + if li != lj: + labels[groups[lj]] = li + groups[li] = np.concatenate((groups[li], groups[lj])) + del groups[lj] if return_inverse: mask = np.zeros(count, dtype=np.bool_) @@ -505,7 +581,7 @@ def union_find(edge_index: nb.int64[:,:], mapping[mask] = np.arange(np.sum(mask)) labels = mapping[labels] - return labels + return labels, groups @nb.njit(cache=True) @@ -533,10 +609,10 @@ def dbscan(x: nb.float32[:, :], (N) Group assignments """ # Produce a sparse adjacency matrix (edge index) - edges = 
np.vstack(np.where(pdist(x, metric) < eps)).T - + edge_index = radius_graph(x, eps, metric) + # Build groups - return union_find(edges, len(x), return_inverse=True) + return union_find(edge_index, len(x), return_inverse=True)[0] @nb.njit(cache=True)