Skip to content

Commit 20d74d7

Browse files
authored
Update default_sam_dataset to allow parsing expected parameters
Added a parameter to allow configuring `is_multi_tensor` in torch_em's generic transform `Compose`, and corrected the loader check so it only runs when `verify_n_labels_in_loader` is not None.
1 parent 9aa455b commit 20d74d7

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

micro_sam/training/training.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -270,9 +270,9 @@ def train_sam(
270270
with _filter_warnings(ignore_warnings):
271271

272272
t_start = time.time()
273-
274-
_check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
275-
_check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)
273+
if verify_n_labels_in_loader is not None:
274+
_check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
275+
_check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)
276276

277277
device = get_device(device)
278278
# Get the trainable segment anything model.
@@ -684,7 +684,10 @@ def default_sam_dataset(
684684
if custom_label_transform is None:
685685
label_transform = default_label_transform
686686
else:
687-
label_transform = torch_em.transform.generic.Compose(custom_label_transform, default_label_transform)
687+
is_multi_tensor = kwargs.pop("is_multi_tensor", True)
688+
label_transform = torch_em.transform.generic.Compose(
689+
custom_label_transform, default_label_transform, is_multi_tensor=is_multi_tensor
690+
)
688691

689692
# Check the patch shape to add a singleton if required.
690693
patch_shape = _update_patch_shape(

0 commit comments

Comments
 (0)