Skip to content

Commit 65ebd8b

Browse files
authored
Add is_multi_tensor argument to default_sam_dataset to split kwargs in loader (#1092)
1 parent 20d74d7 commit 65ebd8b

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

micro_sam/training/training.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,7 @@ def default_sam_dataset(
582582
min_size: int = 25,
583583
max_sampling_attempts: Optional[int] = None,
584584
rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
585+
is_multi_tensor: bool = True,
585586
**kwargs,
586587
) -> Dataset:
587588
"""Create a PyTorch Dataset for training a SAM model.
@@ -608,6 +609,7 @@ def default_sam_dataset(
608609
min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'.
609610
max_sampling_attempts: Number of sampling attempts to make from a dataset.
610611
rois: The region of interest(s) for the data.
612+
is_multi_tensor: Whether the input data to data transforms is multiple tensors or not.
611613
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
612614
613615
Returns:
@@ -684,7 +686,6 @@ def default_sam_dataset(
684686
if custom_label_transform is None:
685687
label_transform = default_label_transform
686688
else:
687-
is_multi_tensor = kwargs.pop("is_multi_tensor", True)
688689
label_transform = torch_em.transform.generic.Compose(
689690
custom_label_transform, default_label_transform, is_multi_tensor=is_multi_tensor
690691
)

0 commit comments

Comments
 (0)