Skip to content

Commit 3ec97ab

Browse files
lufre1 and anwai98 authored
Allow support for custom label transforms in default_sam_dataset (#1063)
* Added feature to add custom label transform
* Make label_transform behaviour consistent with and without additional decoder

Co-authored-by: Anwai Archit <[email protected]>
1 parent ebd10e8 commit 3ec97ab

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

micro_sam/training/training.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -668,8 +668,9 @@ def default_sam_dataset(
668668
if raw_transform is None:
669669
raw_transform = require_8bit
670670

671+
# Prepare the label transform.
671672
if with_segmentation_decoder:
672-
label_transform = torch_em.transform.label.PerObjectDistanceTransform(
673+
default_label_transform = torch_em.transform.label.PerObjectDistanceTransform(
673674
distances=True,
674675
boundary_distances=True,
675676
directed_distances=False,
@@ -678,7 +679,14 @@ def default_sam_dataset(
678679
min_size=min_size,
679680
)
680681
else:
681-
label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)
682+
default_label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)
683+
684+
# Allow combining label transforms.
685+
custom_label_transform = kwargs.pop("label_transform", None)
686+
if custom_label_transform is None:
687+
label_transform = default_label_transform
688+
else:
689+
label_transform = torch_em.transform.generic.Compose(custom_label_transform, default_label_transform)
682690

683691
# Check the patch shape to add a singleton if required.
684692
patch_shape = _update_patch_shape(

0 commit comments

Comments (0)