Skip to content

Commit a76037c

Browse files
committed
Update worker_training.py
1 parent b7aa88b commit a76037c

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

napari_cellseg3d/code_models/worker_training.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Contains the workers used to train the models."""
2+
23
import platform
34
import time
45
from abc import abstractmethod
@@ -280,7 +281,7 @@ def get_dataset(self, train_transforms):
280281
load_single_images = Compose(
281282
[
282283
LoadImaged(keys=["image"]),
283-
EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
284+
EnsureChannelFirstd(keys=["image"], channel_dim="no_channel", strict_check=False),
284285
Orientationd(keys=["image"], axcodes="PLI"),
285286
SpatialPadd(
286287
keys=["image"],
@@ -1345,9 +1346,9 @@ def get_patch_loader_func(num_samples):
13451346
)
13461347
sample_loader_eval = get_patch_loader_func(num_val_samples)
13471348
else:
1348-
num_train_samples = (
1349-
num_val_samples
1350-
) = self.config.num_samples
1349+
num_train_samples = num_val_samples = (
1350+
self.config.num_samples
1351+
)
13511352

13521353
sample_loader_train = get_patch_loader_func(
13531354
num_train_samples

0 commit comments

Comments
 (0)