Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 18 additions & 24 deletions spine/model/layer/cnn/ppn.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,30 +522,24 @@ def get_ppn_positives(coords: torch.Tensor,
closests : torch.Tensor
(N) tensor of the closest label point index
"""
# Initialize the pixel assignment and the closest index tensors
device = coords.device
positives = torch.zeros(coords.shape[0], device=device, dtype=torch.bool)
closests = -torch.ones(coords.shape[0], device=device, dtype=torch.long)

# Loop over all unique particles with points within this entry
part_ids = ppn_labels[:, PPN_LPART_COL]
for part_id in torch.unique(part_ids):
# Restrict the set of labels points and voxels to one particle
point_index = torch.where(part_ids == part_id)[0]
index = torch.where(labels == part_id)[0]
points = ppn_labels[point_index][:, COORD_COLS]

# Compute the pairwise distance between the particle voxels and its
# label points
dist_mat = cdist_fast(coords[index], points)

# Generate a positive mask for all particle voxels within some
# distance of its label points
positives[index] = (dist_mat < resolution).any(dim=1)

# Assign the closest label point to each particle voxel
min_return = torch.min(dist_mat, dim=1)
closests[index] = point_index[min_return.indices] + offset

# Compute the distance from the PPN labels to all the image points
dist_mat = cdist_fast(ppn_labels[:, COORD_COLS], coords)

# Mask out particle voxels for which the particle ID disagrees
bad_mask = ppn_labels[:, [PPN_LPART_COL]] != labels
dist_mat[bad_mask] = torch.inf

# Generate a positive mask for all particle voxels within some
# distance of their label points
positives = (dist_mat < resolution).any(dim=0)

# Assign the closest label point to each postive particle voxel
pos_index = torch.where(positives)[0]
closests = torch.full(
(len(labels),), -1, dtype=torch.long, device=labels.device)
closests[pos_index] = offset + torch.argmin(
dist_mat[:, pos_index], dim=0)

return positives, closests

Expand Down
Loading