diff --git a/spine/model/layer/cnn/ppn.py b/spine/model/layer/cnn/ppn.py index b0e5b83b..564cb7c0 100644 --- a/spine/model/layer/cnn/ppn.py +++ b/spine/model/layer/cnn/ppn.py @@ -522,30 +522,24 @@ def get_ppn_positives(coords: torch.Tensor, closests : torch.Tensor (N) tensor of the closest label point index """ - # Initialize the pixel assignment and the closest index tensors - device = coords.device - positives = torch.zeros(coords.shape[0], device=device, dtype=torch.bool) - closests = -torch.ones(coords.shape[0], device=device, dtype=torch.long) - - # Loop over all unique particles with points within this entry - part_ids = ppn_labels[:, PPN_LPART_COL] - for part_id in torch.unique(part_ids): - # Restrict the set of labels points and voxels to one particle - point_index = torch.where(part_ids == part_id)[0] - index = torch.where(labels == part_id)[0] - points = ppn_labels[point_index][:, COORD_COLS] - - # Compute the pairwise distance between the particle voxels and its - # label points - dist_mat = cdist_fast(coords[index], points) - - # Generate a positive mask for all particle voxels within some - # distance of its label points - positives[index] = (dist_mat < resolution).any(dim=1) - - # Assign the closest label point to each particle voxel - min_return = torch.min(dist_mat, dim=1) - closests[index] = point_index[min_return.indices] + offset + + # Compute the distance from the PPN labels to all the image points + dist_mat = cdist_fast(ppn_labels[:, COORD_COLS], coords) + + # Mask out particle voxels for which the particle ID disagrees + bad_mask = ppn_labels[:, [PPN_LPART_COL]] != labels + dist_mat[bad_mask] = torch.inf + + # Generate a positive mask for all particle voxels within some + # distance of their label points + positives = (dist_mat < resolution).any(dim=0) + + # Assign the closest label point to each postive particle voxel + pos_index = torch.where(positives)[0] + closests = torch.full( + (len(labels),), -1, dtype=torch.long, device=labels.device) + closests[pos_index] = offset + torch.argmin( + dist_mat[:, pos_index], dim=0) return positives, closests