Skip to content

Commit aeeb937

Browse files
committed
Added WNet training test
1 parent 23b54d3 commit aeeb937

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from pathlib import Path
2+
from napari_cellseg3d.code_models.models.wnet import train_wnet as t
3+
4+
def test_wnet_training():
5+
config = t.Config()
6+
7+
config.batch_size = 1
8+
config.num_epochs = 1
9+
10+
config.train_volume_directory = str(Path(__file__).resolve().parent / "res/wnet_test")
11+
config.eval_volume_directory = config.train_volume_directory
12+
config.save_every = 1
13+
config.val_interval = 2 # skip validation
14+
config.save_model_path = config.train_volume_directory + "/test.pth"
15+
16+
ncuts_loss, rec_loss, model = t.train(train_config=config)
17+
18+
assert ncuts_loss is not None
19+
assert rec_loss is not None
20+
assert model is not None

napari_cellseg3d/code_models/models/wnet/train_wnet.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,8 +720,13 @@ def get_dataset(config):
720720
volume_directory=config.train_volume_directory
721721
)
722722
train_files = [d.get("image") for d in train_files]
723+
# logger.debug(f"train_files: {train_files}")
723724
volumes = tiff.imread(train_files).astype(np.float32)
724725
volume_shape = volumes.shape
726+
# logger.debug(f"volume_shape: {volume_shape}")
727+
728+
if len(volume_shape) == 3:
729+
volumes = np.expand_dims(volumes, axis=0)
725730

726731
if config.normalize_input:
727732
volumes = np.array(

0 commit comments

Comments
 (0)