Skip to content
Merged
8 changes: 7 additions & 1 deletion bin/larcv_check_valid.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@ def main(source, source_list, output):
keys_list, unique_counts = [], []
for file_path in tqdm(source):
# Count the number of entries in each tree
f = TFile(file_path)
try:
f = TFile(file_path)
except OSError:
keys_list.append([])
unique_counts.append([])
continue

keys = [key.GetName() for key in f.GetListOfKeys()]
trees = [f.Get(key) for key in keys]
num_entries = [tree.GetEntries() for tree in trees]
Expand Down
2 changes: 1 addition & 1 deletion spine/data/out/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ class TruthParticle(Particle, ParticleBase, TruthBase):
orig_interaction_id: int = -1
orig_parent_id: int = -1
orig_group_id: int = -1
orig_children_id: np.ndarray = -1
orig_children_id: np.ndarray = None
children_counts: np.ndarray = None
reco_length: float = -1.
reco_start_dir: np.ndarray = None
Expand Down
12 changes: 5 additions & 7 deletions spine/model/full_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from spine.utils.calib import CalibrationManager
from spine.utils.logger import logger
from spine.utils.ppn import get_particle_points
from spine.utils.ghost import (
compute_rescaled_charge_batch, adapt_labels_batch)
from spine.utils.ghost import compute_rescaled_charge_batch
from spine.utils.cluster.label import ClusterLabelAdapter
from spine.utils.gnn.cluster import (
form_clusters_batch, get_cluster_label_batch)
from spine.utils.gnn.evaluation import primary_assignment_batch
Expand Down Expand Up @@ -173,8 +173,7 @@ def __init__(self, chain, uresnet_deghost=None, uresnet=None,
self.uresnet_ppn = UResNetPPN(**uresnet_ppn)

# Initialize the relabeling process (adapt to the semantic predictions)
# TODO: make this a class which holds onto these parameters?
self.adapt_params = adapt_labels if adapt_labels is not None else {}
self.label_adapter = ClusterLabelAdapter(**(adapt_labels or {}))

# Initialize the dense clustering model
self.fragment_shapes = []
Expand Down Expand Up @@ -495,9 +494,8 @@ def run_segmentation_ppn(self, data, seg_label=None, clust_label=None):
if seg_label is not None and clust_label is not None:
seg_pred = self.result['seg_pred']
ghost_pred = self.result.get('ghost_pred', None)
clust_label = adapt_labels_batch(
clust_label, seg_label, seg_pred, ghost_pred,
**self.adapt_params)
clust_label = self.label_adapter(
clust_label, seg_label, seg_pred, ghost_pred)

self.result['clust_label_adapt'] = clust_label

Expand Down
306 changes: 306 additions & 0 deletions spine/utils/cluster/label.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
"""Class which adapts clustering labels given upstream semantic predictions."""

import numpy as np
import torch
from torch_cluster import knn
from scipy.spatial.distance import cdist

from spine.data import TensorBatch

from spine.utils.gnn.cluster import form_clusters, break_clusters
from spine.utils.globals import (
COORD_COLS, VALUE_COL, CLUST_COL, SHAPE_COL, SHOWR_SHP, TRACK_SHP,
MICHL_SHP, DELTA_SHP, GHOST_SHP)

__all__ = ['ClusterLabelAdapter']


class ClusterLabelAdapter:
    """Adapts the cluster labels to account for the predicted semantics.

    Points wrongly predicted get the cluster label of the closest touching
    cluster, if there is one. Points that are predicted as ghosts get invalid
    (-1) cluster labels everywhere.

    Instances that have been broken up by the deghosting or semantic
    segmentation process get assigned distinct cluster labels for each
    effective fragment, provided they appear in the `break_classes` list.

    Notes
    -----
    This class supports both Numpy arrays and Torch tensors.

    It uses the GPU implementation from `torch_cluster.knn` to speed up the
    label adaptation computation (instead of cdist).

    """

    # Distance (in voxel units) under which a mispredicted point is
    # considered to be touching a compatible labeled cluster. Matches the
    # default `break_eps` of the break-up procedure.
    TOUCH_DIST = 1.1

    def __init__(self, break_eps=1.1, break_metric='chebyshev',
                 break_classes=(SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP)):
        """Initialize the adapter class.

        Parameters
        ----------
        break_eps : float, default 1.1
            Distance scale used in the break up procedure
        break_metric : str, default 'chebyshev'
            Distance metric used in the break up procedure
        break_classes : List[int], default
            (SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP)
            Classes to run DBSCAN on to break up
        """
        # Store relevant parameters. The `break_classes` default is a tuple
        # (not a list) so that no mutable default is shared across instances.
        self.break_eps = break_eps
        self.break_metric = break_metric
        self.break_classes = break_classes

        # Attributes used to fetch the correct functions; set on every call
        # from the type of the input tensors.
        self.torch, self.dtype, self.device = None, None, None

    def __call__(self, clust_label, seg_label, seg_pred, ghost_pred=None):
        """Adapts the cluster labels for one entry or a batch of entries.

        Parameters
        ----------
        clust_label : Union[TensorBatch, np.ndarray, torch.Tensor]
            (N, N_l) Cluster label tensor
        seg_label : Union[TensorBatch, np.ndarray, torch.Tensor]
            (M, 5) Segmentation label tensor
        seg_pred : Union[TensorBatch, np.ndarray, torch.Tensor]
            (M/N_deghost) Segmentation predictions for each voxel
        ghost_pred : Union[TensorBatch, np.ndarray, torch.Tensor], optional
            (M) Ghost predictions for each voxel

        Returns
        -------
        Union[TensorBatch, np.ndarray, torch.Tensor]
            (N_deghost, N_l) Adapted cluster label tensor
        """
        # Set the data type/device based on the input. Fetch them from the
        # unwrapped reference tensor rather than the batch wrapper, so this
        # works whether or not TensorBatch proxies `dtype`/`device`.
        ref_tensor = clust_label
        if isinstance(ref_tensor, TensorBatch):
            ref_tensor = ref_tensor.tensor
        self.torch = isinstance(ref_tensor, torch.Tensor)

        self.dtype = ref_tensor.dtype
        if self.torch:
            self.device = ref_tensor.device

        # Dispatch depending on the data type
        if isinstance(clust_label, TensorBatch):
            # If it is batch data, call the main process function on each
            # entry. NOTE(review): this path allocates a torch tensor, i.e.
            # it assumes batched inputs are always torch-based — confirm.
            shape = (seg_pred.shape[0], clust_label.shape[1])
            clust_label_adapted = torch.empty(
                    shape, dtype=self.dtype, device=self.device)
            for b in range(clust_label.batch_size):
                lower, upper = seg_pred.edges[b], seg_pred.edges[b+1]
                ghost_pred_b = ghost_pred[b] if ghost_pred is not None else None
                clust_label_adapted[lower:upper] = self._process(
                        clust_label[b], seg_label[b], seg_pred[b], ghost_pred_b)

            return TensorBatch(clust_label_adapted, seg_pred.counts)

        else:
            # Otherwise, call the main process function directly
            return self._process(clust_label, seg_label, seg_pred, ghost_pred)

    def _process(self, clust_label, seg_label, seg_pred, ghost_pred=None):
        """Adapts the cluster labels for a single entry.

        Parameters
        ----------
        clust_label : Union[np.ndarray, torch.Tensor]
            (N, N_l) Cluster label tensor
        seg_label : Union[np.ndarray, torch.Tensor]
            (M, 5) Segmentation label tensor
        seg_pred : Union[np.ndarray, torch.Tensor]
            (M/N_deghost) Segmentation predictions for each voxel
        ghost_pred : Union[np.ndarray, torch.Tensor], optional
            (M) Ghost predictions for each voxel

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            (N_deghost, N_l) Adapted cluster label tensor
        """
        # If there are no points in this event, nothing to do
        coords = seg_label[:, :VALUE_COL]
        num_cols = clust_label.shape[1]
        if not len(coords):
            return self._ones((0, num_cols))

        # If there are no points left after deghosting, nothing to do
        if ghost_pred is not None:
            deghost_index = self._where(ghost_pred == 0)[0]
            if not len(deghost_index):
                return self._ones((0, num_cols))

        # If there are no label points in this event, return dummy labels:
        # the (deghosted) coordinates with -1 in every label column.
        if not len(clust_label):
            if ghost_pred is None:
                shape = (len(coords), num_cols)
                dummy_labels = -self._ones(shape)
                dummy_labels[:, :VALUE_COL] = coords

            else:
                shape = (len(deghost_index), num_cols)
                dummy_labels = -self._ones(shape)
                dummy_labels[:, :VALUE_COL] = coords[deghost_index]

            return dummy_labels

        # Build a tensor of predicted segmentation that includes ghost points
        # (pad the predictions with GHOST_SHP wherever a voxel was deghosted)
        seg_label = self._to_long(seg_label[:, SHAPE_COL])
        if ghost_pred is not None and (len(ghost_pred) != len(seg_pred)):
            seg_pred_long = self._to_long(GHOST_SHP*self._ones(len(coords)))
            seg_pred_long[deghost_index] = seg_pred
            seg_pred = seg_pred_long

        # Prepare new labels, initialized to invalid (-1) everywhere
        new_label = -self._ones((len(coords), num_cols))
        new_label[:, :VALUE_COL] = coords

        # Check if the segment labels and predictions are compatible. If they
        # are compatible, store the cluster labels as is. Track points do not
        # mix with other classes, but EM classes are allowed to.
        compat_mat = self._eye(GHOST_SHP + 1)
        compat_mat[([SHOWR_SHP, SHOWR_SHP, MICHL_SHP, DELTA_SHP],
                    [MICHL_SHP, DELTA_SHP, SHOWR_SHP, SHOWR_SHP])] = True

        # NOTE(review): the assignment below assumes `clust_label` rows line
        # up 1:1 with the non-ghost rows of `seg_label` — confirm upstream.
        true_deghost = seg_label < GHOST_SHP
        seg_mismatch = ~compat_mat[(seg_pred, seg_label)]
        new_label[true_deghost] = clust_label
        new_label[true_deghost & seg_mismatch, VALUE_COL:] = -self._ones(1)

        # For mismatched predictions, attempt to find a touching instance of
        # the same class to assign it sensible cluster labels.
        for s in self._unique(seg_pred):
            # Skip predicted ghosts (they keep their invalid labels)
            if s == GHOST_SHP:
                continue

            # Restrict to points in this class that have incompatible segment
            # labels. Track points do not mix, EM points are allowed to.
            bad_index = self._where(
                    (seg_pred == s) & (~true_deghost | seg_mismatch))[0]
            if len(bad_index) == 0:
                continue

            # Find points in clust_label that have compatible segment labels
            seg_clust_mask = compat_mat[s][self._to_long(clust_label[:, SHAPE_COL])]
            X_true = clust_label[seg_clust_mask]
            if len(X_true) == 0:
                continue

            # Loop over the set of unlabeled predicted points. Each pass
            # propagates labels one "touching" shell outward, so labels can
            # flow along a chain of adjacent voxels.
            X_pred = coords[bad_index]
            tagged_voxels_count = 1
            while tagged_voxels_count > 0 and len(X_pred) > 0:
                # Find the nearest neighbor to each predicted point
                closest_ids = self._compute_neighbor(X_pred, X_true)

                # Compute Chebyshev distance between predicted and closest true
                distances = self._compute_distances(X_pred, X_true[closest_ids])

                # Label unlabeled voxels that touch a compatible true voxel
                select_mask = distances < self.TOUCH_DIST
                select_index = self._where(select_mask)[0]
                tagged_voxels_count = len(select_index)
                if tagged_voxels_count > 0:
                    # Use the label of the touching true voxel
                    additional_clust_label = self._cat(
                            [X_pred[select_index],
                             X_true[closest_ids[select_index], VALUE_COL:]], 1)
                    new_label[bad_index[select_index]] = additional_clust_label

                    # Update the mask to not include the new assigned points
                    leftover_index = self._where(~select_mask)[0]
                    bad_index = bad_index[leftover_index]

                    # The new true available points are the ones we just added.
                    # The new pred points are those not yet labeled
                    X_true = additional_clust_label
                    X_pred = X_pred[leftover_index]

        # Remove predicted ghost points from the labels, set the shape
        # column of the label to the segmentation predictions.
        if ghost_pred is not None:
            new_label = new_label[deghost_index]
            new_label[:, SHAPE_COL] = seg_pred[deghost_index]
        else:
            new_label[:, SHAPE_COL] = seg_pred

        # Build a list of cluster indexes to break (done in Numpy, as the
        # break up procedure itself is CPU-bound)
        new_label_np = new_label
        if torch.is_tensor(new_label):
            new_label_np = new_label.detach().cpu().numpy()

        clusts = []
        labels = new_label_np[:, CLUST_COL]
        shapes = new_label_np[:, SHAPE_COL]
        for break_class in self.break_classes:
            index_s = np.where(shapes == break_class)[0]
            labels_s = labels[index_s]
            for c in np.unique(labels_s):
                # If the cluster ID is invalid, skip
                if c < 0:
                    continue

                # Append cluster
                clusts.append(index_s[labels_s == c])

        # Now if an instance was broken up, assign it different cluster IDs
        new_label[:, CLUST_COL] = break_clusters(
                new_label, clusts, self.break_eps, self.break_metric)

        return new_label

    def _where(self, x):
        """Dispatch `where` to torch or numpy."""
        if self.torch:
            return torch.where(x)
        else:
            return np.where(x)

    def _cat(self, x, axis):
        """Dispatch concatenation to torch or numpy."""
        if self.torch:
            return torch.cat(x, axis)
        else:
            return np.concatenate(x, axis)

    def _ones(self, x):
        """Build a tensor of ones matching the input dtype (and device)."""
        if self.torch:
            return torch.ones(x, dtype=self.dtype, device=self.device)
        else:
            # Match the input dtype, rather than defaulting to float64
            return np.ones(x, dtype=self.dtype)

    def _eye(self, x):
        """Build a boolean identity matrix with torch or numpy."""
        if self.torch:
            return torch.eye(x, dtype=torch.bool, device=self.device)
        else:
            return np.eye(x, dtype=bool)

    def _unique(self, x):
        """Return the unique values of x as 64-bit integers."""
        if self.torch:
            return torch.unique(x).long()
        else:
            return np.unique(x).astype(np.int64)

    def _to_long(self, x):
        """Cast x to a 64-bit integer type."""
        if self.torch:
            return x.long()
        else:
            # Bug fix: was a bare `int64`, which is not a defined name
            return x.astype(np.int64)

    def _compute_neighbor(self, x, y):
        """Return, for each point in x, the index of its closest point in y."""
        if self.torch:
            return knn(y[:, COORD_COLS], x[:, COORD_COLS], 1)[1]
        else:
            return cdist(x[:, COORD_COLS], y[:, COORD_COLS]).argmin(axis=1)

    def _compute_distances(self, x, y):
        """Return the row-wise Chebyshev (max-coordinate) distance x to y."""
        if self.torch:
            return torch.amax(torch.abs(y[:, COORD_COLS] - x[:, COORD_COLS]), dim=1)
        else:
            return np.amax(np.abs(x[:, COORD_COLS] - y[:, COORD_COLS]), axis=1)
Loading
Loading