@@ -522,30 +522,24 @@ def get_ppn_positives(coords: torch.Tensor,
522522 closests : torch.Tensor
523523 (N) tensor of the closest label point index
524524 """
525- # Initialize the pixel assignment and the closest index tensors
526- device = coords .device
527- positives = torch .zeros (coords .shape [0 ], device = device , dtype = torch .bool )
528- closests = - torch .ones (coords .shape [0 ], device = device , dtype = torch .long )
529-
530- # Loop over all unique particles with points within this entry
531- part_ids = ppn_labels [:, PPN_LPART_COL ]
532- for part_id in torch .unique (part_ids ):
533- # Restrict the set of labels points and voxels to one particle
534- point_index = torch .where (part_ids == part_id )[0 ]
535- index = torch .where (labels == part_id )[0 ]
536- points = ppn_labels [point_index ][:, COORD_COLS ]
537-
538- # Compute the pairwise distance between the particle voxels and its
539- # label points
540- dist_mat = cdist_fast (coords [index ], points )
541-
542- # Generate a positive mask for all particle voxels within some
543- # distance of its label points
544- positives [index ] = (dist_mat < resolution ).any (dim = 1 )
545-
546- # Assign the closest label point to each particle voxel
547- min_return = torch .min (dist_mat , dim = 1 )
548- closests [index ] = point_index [min_return .indices ] + offset
525+
526+ # Compute the distance from the PPN labels to all the image points
527+ dist_mat = cdist_fast (ppn_labels [:, COORD_COLS ], coords )
528+
529+ # Mask out particle voxels for which the particle ID disagrees
530+ bad_mask = ppn_labels [:, [PPN_LPART_COL ]] != labels
531+ dist_mat [bad_mask ] = torch .inf
532+
533+ # Generate a positive mask for all particle voxels within some
534+ # distance of their label points
535+ positives = (dist_mat < resolution ).any (dim = 0 )
536+
537+ # Assign the closest label point to each postive particle voxel
538+ pos_index = torch .where (positives )[0 ]
539+ closests = torch .full (
540+ (len (labels ),), - 1 , dtype = torch .long , device = labels .device )
541+ closests [pos_index ] = offset + torch .argmin (
542+ dist_mat [:, pos_index ], dim = 0 )
549543
550544 return positives , closests
551545
0 commit comments