Skip to content

Commit 7bb5edc

Browse files
committed
Cleanup + tests
- Removed previous train script - Fix tests - Enable test workflow on GH
1 parent 1f7c9ed commit 7bb5edc

File tree

5 files changed

+5
-1020
lines changed

5 files changed

+5
-1020
lines changed

.github/workflows/test_and_deploy.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ on:
77
push:
88
branches:
99
- main
10+
- cy/wnet-train
1011
tags:
1112
- "v*" # Push events to matching v*, i.e. v1.0, v20.15.10
1213
pull_request:

napari_cellseg3d/_tests/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_pretrained_weights_compatibility():
115115
for model_name in MODEL_LIST:
116116
file_name = MODEL_LIST[model_name].weights_file
117117
WeightsDownloader().download_weights(model_name, file_name)
118-
model = MODEL_LIST[model_name](input_img_size=[128, 128, 128])
118+
model = MODEL_LIST[model_name](input_img_size=[64, 64, 64])
119119
try:
120120
model.load_state_dict(
121121
torch.load(

napari_cellseg3d/_tests/test_unsup_training.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Trainer,
66
)
77

8+
89
def test_unsupervised_worker(make_napari_viewer_proxy):
910
im_path = Path(__file__).resolve().parent / "res/test.tif"
1011
# im_path_str = str(im_path)
@@ -34,8 +35,8 @@ def test_unsupervised_worker(make_napari_viewer_proxy):
3435
assert eval_dataloader is None
3536
assert data_shape == (6, 6, 6)
3637

37-
widget.images_filepaths = [str(im_path.parent)]
38-
widget.labels_filepaths = [str(im_path.parent)]
38+
widget.images_filepaths = [str(im_path)]
39+
widget.labels_filepaths = [str(im_path)]
3940
# widget.unsupervised_eval_data = widget.create_train_dataset_dict()
4041
worker = widget._create_worker(additional_results_description="TEST_3")
4142
dataloader, eval_dataloader, data_shape = worker._get_data()
Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +0,0 @@
1-
#######################################################
2-
# Disabled as it takes too much memory for GH actions #
3-
#######################################################
4-
5-
6-
# from pathlib import Path
7-
# from napari_cellseg3d.code_models.models.wnet import train_wnet as t
8-
#
9-
# def test_wnet_training():
10-
# config = t.Config()
11-
#
12-
# config.batch_size = 1
13-
# config.num_epochs = 1
14-
#
15-
# config.train_volume_directory = str(Path(__file__).resolve().parent / "res/wnet_test")
16-
# config.eval_volume_directory = config.train_volume_directory
17-
# config.save_every = 1
18-
# config.val_interval = 2 # skip validation
19-
# config.save_model_path = config.train_volume_directory + "/test.pth"
20-
#
21-
# ncuts_loss, rec_loss, model = t.train(train_config=config)
22-
#
23-
# assert ncuts_loss is not None
24-
# assert rec_loss is not None
25-
# assert model is not None

0 commit comments

Comments
 (0)