@@ -270,9 +270,9 @@ def train_sam(
270270 with _filter_warnings (ignore_warnings ):
271271
272272 t_start = time .time ()
273-
274- _check_loader (train_loader , with_segmentation_decoder , "train" , verify_n_labels_in_loader )
275- _check_loader (val_loader , with_segmentation_decoder , "val" , verify_n_labels_in_loader )
273+ if verify_n_labels_in_loader is not None :
274+ _check_loader (train_loader , with_segmentation_decoder , "train" , verify_n_labels_in_loader )
275+ _check_loader (val_loader , with_segmentation_decoder , "val" , verify_n_labels_in_loader )
276276
277277 device = get_device (device )
278278 # Get the trainable segment anything model.
@@ -684,7 +684,10 @@ def default_sam_dataset(
684684 if custom_label_transform is None :
685685 label_transform = default_label_transform
686686 else :
687- label_transform = torch_em .transform .generic .Compose (custom_label_transform , default_label_transform )
687+ is_multi_tensor = kwargs .pop ("is_multi_tensor" , True )
688+ label_transform = torch_em .transform .generic .Compose (
689+ custom_label_transform , default_label_transform , is_multi_tensor = is_multi_tensor
690+ )
688691
689692 # Check the patch shape to add a singleton if required.
690693 patch_shape = _update_patch_shape (
0 commit comments