Skip to content

Commit c1d12f5

Browse files
Merge pull request #86 from francois-drielsma/develop
Vectorized PPN loss cluster restriction routine, now much cheaper
2 parents 8e1b2d5 + c7d7a19 commit c1d12f5

File tree

1 file changed

+18
-24
lines changed

1 file changed

+18
-24
lines changed

spine/model/layer/cnn/ppn.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -522,30 +522,24 @@ def get_ppn_positives(coords: torch.Tensor,
522522
closests : torch.Tensor
523523
(N) tensor of the closest label point index
524524
"""
525-
# Initialize the pixel assignment and the closest index tensors
526-
device = coords.device
527-
positives = torch.zeros(coords.shape[0], device=device, dtype=torch.bool)
528-
closests = -torch.ones(coords.shape[0], device=device, dtype=torch.long)
529-
530-
# Loop over all unique particles with points within this entry
531-
part_ids = ppn_labels[:, PPN_LPART_COL]
532-
for part_id in torch.unique(part_ids):
533-
# Restrict the set of labels points and voxels to one particle
534-
point_index = torch.where(part_ids == part_id)[0]
535-
index = torch.where(labels == part_id)[0]
536-
points = ppn_labels[point_index][:, COORD_COLS]
537-
538-
# Compute the pairwise distance between the particle voxels and its
539-
# label points
540-
dist_mat = cdist_fast(coords[index], points)
541-
542-
# Generate a positive mask for all particle voxels within some
543-
# distance of its label points
544-
positives[index] = (dist_mat < resolution).any(dim=1)
545-
546-
# Assign the closest label point to each particle voxel
547-
min_return = torch.min(dist_mat, dim=1)
548-
closests[index] = point_index[min_return.indices] + offset
525+
526+
# Compute the distance from the PPN labels to all the image points
527+
dist_mat = cdist_fast(ppn_labels[:, COORD_COLS], coords)
528+
529+
# Mask out particle voxels for which the particle ID disagrees
530+
bad_mask = ppn_labels[:, [PPN_LPART_COL]] != labels
531+
dist_mat[bad_mask] = torch.inf
532+
533+
# Generate a positive mask for all particle voxels within some
534+
# distance of their label points
535+
positives = (dist_mat < resolution).any(dim=0)
536+
537+
# Assign the closest label point to each postive particle voxel
538+
pos_index = torch.where(positives)[0]
539+
closests = torch.full(
540+
(len(labels),), -1, dtype=torch.long, device=labels.device)
541+
closests[pos_index] = offset + torch.argmin(
542+
dist_mat[:, pos_index], dim=0)
549543

550544
return positives, closests
551545

0 commit comments

Comments
 (0)