-
Notifications
You must be signed in to change notification settings - Fork 11
Open
Description
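Hi,
For multi-class segmentation, extracting a skeleton for each label separately may differ from skeletonizing the binary foreground mask and then multiplying the result by the label mask. The skeleton of the merged foreground does not necessarily run through the centre of each individual label region; where two labels touch, it can even run along their shared boundary.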
class SkeletonTransform(BasicTransform):
    def __init__(self, do_tube: bool = True, num_classes: int = 1):
        """
        Calculates the skeleton of the segmentation (plus an optional 2 px tube around it)
        and adds it to the dict with the key "skel".

        Args:
            do_tube: if True, the skeleton is dilated twice to form a tube around it.
            num_classes: labels 1..num_classes are skeletonized individually by ``apply``.
                Labels above ``num_classes`` are ignored by ``apply`` (so the default of 1
                only handles label 1).

        Raises:
            ValueError: if ``num_classes`` is smaller than 1.
        """
        super().__init__()
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        self.do_tube = do_tube
        self.num_classes = num_classes  # needed for compatibility with 3D data

    def _make_tube(self, skel):
        """Return ``skel`` binarized as int16, dilated twice (2 px tube) when ``do_tube`` is set."""
        skel = (skel > 0).astype(np.int16)
        if self.do_tube:
            skel = dilation(dilation(skel))
        return skel

    def apply(self, data_dict, **params):
        """
        Skeletonize each label 1..num_classes separately and store the result under "skel".

        Only channel 0 of ``data_dict['segmentation']`` is processed; the output has the
        same shape as the segmentation, dtype int16, and holds the label id on the
        (optionally tubed) skeleton pixels of that label and 0 elsewhere.
        """
        seg_all = data_dict['segmentation'].numpy()
        seg = seg_all[0]
        # Add tubed skeleton GT
        seg_all_skel = np.zeros_like(seg_all, dtype=np.int16)
        out = seg_all_skel[0]  # view into channel 0 of the output
        for labelid in range(1, self.num_classes + 1):
            mask = seg == labelid
            if not mask.any():
                continue
            tube = self._make_tube(skeletonize(mask))
            # Restrict the tube to this label's own region: without this the dilation spills
            # into background and neighbouring labels (and later labels would overwrite
            # earlier ones). Label masks are disjoint, so the assignment order is irrelevant.
            out[(tube > 0) & mask] = labelid
        data_dict["skel"] = torch.from_numpy(seg_all_skel)
        return data_dict

    def apply_old(self, data_dict, **params):
        """
        Legacy variant: skeletonize the binary foreground of channel 0 once, optionally
        tube it, then multiply by the segmentation so each pixel carries its label id.
        """
        seg_all = data_dict['segmentation'].numpy()
        # Add tubed skeleton GT
        bin_seg = seg_all > 0
        seg_all_skel = np.zeros_like(bin_seg, dtype=np.int16)
        if bin_seg[0].any():
            skel = self._make_tube(skeletonize(bin_seg[0]))
            # Multiplying by the label map both restricts the tube to the foreground and
            # assigns each pixel the label it falls on.
            skel *= seg_all[0].astype(np.int16)
            seg_all_skel[0] = skel
        data_dict["skel"] = torch.from_numpy(seg_all_skel)
        return data_dict
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels