File tree Expand file tree Collapse file tree 2 files changed +25
-0
lines changed
Expand file tree Collapse file tree 2 files changed +25
-0
lines changed Original file line number Diff line number Diff line change 1+ from pathlib import Path
2+ from napari_cellseg3d .code_models .models .wnet import train_wnet as t
3+
4+ def test_wnet_training ():
5+ config = t .Config ()
6+
7+ config .batch_size = 1
8+ config .num_epochs = 1
9+
10+ config .train_volume_directory = str (Path (__file__ ).resolve ().parent / "res/wnet_test" )
11+ config .eval_volume_directory = config .train_volume_directory
12+ config .save_every = 1
13+ config .val_interval = 2 # skip validation
14+ config .save_model_path = config .train_volume_directory + "/test.pth"
15+
16+ ncuts_loss , rec_loss , model = t .train (train_config = config )
17+
18+ assert ncuts_loss is not None
19+ assert rec_loss is not None
20+ assert model is not None
Original file line number Diff line number Diff line change @@ -720,8 +720,13 @@ def get_dataset(config):
720720 volume_directory = config .train_volume_directory
721721 )
722722 train_files = [d .get ("image" ) for d in train_files ]
723+ # logger.debug(f"train_files: {train_files}")
723724 volumes = tiff .imread (train_files ).astype (np .float32 )
724725 volume_shape = volumes .shape
726+ # logger.debug(f"volume_shape: {volume_shape}")
727+
728+ if len (volume_shape ) == 3 :
729+ volumes = np .expand_dims (volumes , axis = 0 )
725730
726731 if config .normalize_input :
727732 volumes = np .array (
You can’t perform that action at this time.
0 commit comments