Skip to content

Commit 912e6bd

Browse files
committed
Testing fixes
Due to Singleton Trainer widget
1 parent 420a641 commit 912e6bd

File tree

3 files changed

+39
-49
lines changed

3 files changed

+39
-49
lines changed

napari_cellseg3d/_tests/test_supervised_training.py renamed to napari_cellseg3d/_tests/test_training_plugin.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
im_path_str = str(im_path)
1414

1515

16-
def test_create_supervised_worker_from_config(make_napari_viewer_proxy):
16+
def test_worker_configs(make_napari_viewer_proxy):
1717
viewer = make_napari_viewer_proxy()
1818
widget = Trainer(viewer=viewer)
19+
# test supervised config and worker
1920
widget.device_choice.setCurrentIndex(0)
2021
widget.model_choice.setCurrentIndex(0)
2122
widget._toggle_unsupervised_mode(enabled=False)
@@ -34,6 +35,36 @@ def test_create_supervised_worker_from_config(make_napari_viewer_proxy):
3435
assert getattr(default_config, attr) == getattr(
3536
worker.config, attr
3637
)
38+
# test unsupervised config and worker
39+
widget.model_choice.setCurrentText("WNet")
40+
widget._toggle_unsupervised_mode(enabled=True)
41+
default_config = config.WNetTrainingWorkerConfig()
42+
worker = widget._create_worker(additional_results_description="TEST_1")
43+
excluded = ["results_path_folder", "sample_size", "weights_info"]
44+
for attr in dir(default_config):
45+
if not attr.startswith("__") and attr not in excluded:
46+
assert getattr(default_config, attr) == getattr(
47+
worker.config, attr
48+
)
49+
widget.unsupervised_images_filewidget.text_field.setText(
50+
str(im_path.parent)
51+
)
52+
widget.data = widget.create_dataset_dict_no_labs()
53+
worker = widget._create_worker(additional_results_description="TEST_2")
54+
dataloader, eval_dataloader, data_shape = worker._get_data()
55+
assert eval_dataloader is None
56+
assert data_shape == (6, 6, 6)
57+
58+
widget.images_filepaths = [str(im_path)]
59+
widget.labels_filepaths = [str(im_path)]
60+
# widget.unsupervised_eval_data = widget.create_train_dataset_dict()
61+
worker = widget._create_worker(additional_results_description="TEST_3")
62+
dataloader, eval_dataloader, data_shape = worker._get_data()
63+
assert widget.unsupervised_eval_data is not None
64+
assert eval_dataloader is not None
65+
assert widget.unsupervised_eval_data[0]["image"] is not None
66+
assert widget.unsupervised_eval_data[0]["label"] is not None
67+
assert data_shape == (6, 6, 6)
3768

3869

3970
def test_update_loss_plot(make_napari_viewer_proxy):
@@ -86,8 +117,8 @@ def test_training(make_napari_viewer_proxy, qtbot):
86117
widget.log = LogFixture()
87118
viewer.window.add_dock_widget(widget)
88119

89-
widget.images_filepath = None
90-
widget.labels_filepaths = None
120+
widget.images_filepath = []
121+
widget.labels_filepaths = []
91122

92123
assert not widget.check_ready()
93124

napari_cellseg3d/_tests/test_unsup_training.py

Lines changed: 0 additions & 45 deletions
This file was deleted.

napari_cellseg3d/code_plugins/plugin_model_training.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,11 @@ def check_ready(self):
431431
* False and displays a warning if not
432432
433433
"""
434-
if self.images_filepaths == [] and self.labels_filepaths != []:
434+
if (
435+
self.images_filepaths == []
436+
or self.labels_filepaths == []
437+
or len(self.images_filepaths) != len(self.labels_filepaths)
438+
):
435439
logger.warning("Image and label paths are not correctly set")
436440
return False
437441
return True

0 commit comments

Comments
 (0)