Skip to content

Commit b60b2be

Browse files
authored
Additional tests (#42)
* Additional tests * Fixed weight compatibility test * Added newest weight to HF for new WNet * Improving worker tests * Fix style * Improve stats display test coverage * Disable direct thread worker test for GH actions * Update tox.ini Fix ONNX reqs * Update .coveragerc Disable dev scripts coverage * More tests & fixes Tests : - Base plugin - ONNX inference - Cropping Fixes : - Fixed review image loading not working - Fixed scaling when running inference on a folder and showing originals * More worker tests * More extensive training tests * Fix anisotropy calculation
1 parent a1e363e commit b60b2be

33 files changed

+928
-738
lines changed

.coveragerc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ exclude_lines =
44

55
[run]
66
omit =
7-
napari_cellseg3d/setup.py, napari_cellseg3d/code_models/models/wnet/train_wnet.py
7+
napari_cellseg3d/setup.py, napari_cellseg3d/code_models/models/wnet/train_wnet.py, napari_cellseg3d/code_models/models/wnet/model.py,napari_cellseg3d/code_models/models/TEMPLATE_model.py, napari_cellseg3d/code_models/models/unet/*, napari_cellseg3d/dev_scripts/*

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ __pycache__/
1111
*.tiff
1212
napari_cellseg3d/_tests/res/*.csv
1313
*.pth
14+
*.pt
15+
*.onnx
16+
*.tar.gz
1417
*.db
1518

1619
# Distribution / packaging

docs/res/code/utils.rst

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,6 @@ get_time_filepath
2727
**************************************
2828
.. autofunction:: napari_cellseg3d.utils::get_time_filepath
2929

30-
save_stack
31-
**************************************
32-
.. autofunction:: napari_cellseg3d.utils::save_stack
33-
3430
get_padding_dim
3531
**************************************
3632
.. autofunction:: napari_cellseg3d.utils::get_padding_dim
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from pathlib import Path
2+
3+
from napari_cellseg3d.code_plugins.plugin_base import (
4+
BasePluginSingleImage,
5+
)
6+
7+
8+
def test_base_single_image(make_napari_viewer_proxy):
9+
viewer = make_napari_viewer_proxy()
10+
plugin = BasePluginSingleImage(viewer)
11+
12+
test_folder = Path(__file__).parent.resolve()
13+
test_image = str(test_folder / "res/test.tif")
14+
15+
assert plugin._check_results_path(str(test_folder))
16+
plugin.image_path = test_image
17+
assert plugin._default_path[0] != test_image
18+
plugin._update_default()
19+
assert plugin._default_path[0] == test_image

napari_cellseg3d/_tests/test_dock_widget.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
from napari_cellseg3d.code_plugins.plugin_review_dock import Datamanager
66

77

8-
def test_prepare(make_napari_viewer):
8+
def test_prepare(make_napari_viewer_proxy):
99
path_image = str(Path(__file__).resolve().parent / "res/test.tif")
1010
image = imread(str(path_image))
11-
viewer = make_napari_viewer()
11+
viewer = make_napari_viewer_proxy()
1212
viewer.add_image(image)
1313
widget = Datamanager(viewer)
1414
viewer.window.add_dock_widget(widget)

napari_cellseg3d/_tests/test_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from napari_cellseg3d.code_plugins.plugin_helper import Helper
22

33

4-
def test_helper(make_napari_viewer):
5-
viewer = make_napari_viewer()
4+
def test_helper(make_napari_viewer_proxy):
5+
viewer = make_napari_viewer_proxy()
66
widget = Helper(viewer)
77

88
dock = viewer.window.add_dock_widget(widget)

napari_cellseg3d/_tests/test_interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@ def test_log(qtbot):
1515

1616

1717
def test_zoom_factor():
18-
resolution = [10.0, 10.0, 5.0]
18+
resolution = [5.0, 10.0, 5.0]
1919
zoom = AnisotropyWidgets.anisotropy_zoom_factor(resolution)
20-
assert zoom == [1, 1, 0.5]
20+
assert zoom == [1, 0.5, 1]

napari_cellseg3d/_tests/test_model_framework.py

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from pathlib import Path
22

33
from napari_cellseg3d.code_models import model_framework
4+
from napari_cellseg3d.config import MODEL_LIST
45

56

67
def pth(path):
78
return str(Path(path))
89

910

10-
def test_update_default(make_napari_viewer):
11-
view = make_napari_viewer()
11+
def test_update_default(make_napari_viewer_proxy):
12+
view = make_napari_viewer_proxy()
1213
widget = model_framework.ModelFramework(view)
1314

1415
widget.images_filepaths = []
@@ -38,8 +39,8 @@ def test_update_default(make_napari_viewer):
3839
]
3940

4041

41-
def test_create_train_dataset_dict(make_napari_viewer):
42-
view = make_napari_viewer()
42+
def test_create_train_dataset_dict(make_napari_viewer_proxy):
43+
view = make_napari_viewer_proxy()
4344
widget = model_framework.ModelFramework(view)
4445

4546
widget.images_filepaths = [str(f"{i}.tif") for i in range(3)]
@@ -52,3 +53,64 @@ def test_create_train_dataset_dict(make_napari_viewer):
5253
]
5354

5455
assert widget.create_train_dataset_dict() == expect
56+
57+
58+
def test_log(make_napari_viewer_proxy):
59+
mock_test = "test"
60+
framework = model_framework.ModelFramework(
61+
viewer=make_napari_viewer_proxy()
62+
)
63+
framework.log.print_and_log(mock_test)
64+
assert len(framework.log.toPlainText()) != 0
65+
assert framework.log.toPlainText() == "\n" + mock_test
66+
67+
framework.results_path = str(Path(__file__).resolve().parent / "res")
68+
framework.save_log(do_timestamp=False)
69+
log_path = Path(__file__).resolve().parent / "res/Log_report.txt"
70+
assert log_path.is_file()
71+
with Path.open(log_path.resolve(), "r") as f:
72+
assert f.read() == "\n" + mock_test
73+
74+
# remove log file
75+
log_path.unlink(missing_ok=False)
76+
log_path = Path(__file__).resolve().parent / "res/Log_report.txt"
77+
framework.save_log_to_path(str(log_path.parent), do_timestamp=False)
78+
assert log_path.is_file()
79+
with Path.open(log_path.resolve(), "r") as f:
80+
assert f.read() == "\n" + mock_test
81+
log_path.unlink(missing_ok=False)
82+
83+
84+
def test_display_elements(make_napari_viewer_proxy):
85+
framework = model_framework.ModelFramework(
86+
viewer=make_napari_viewer_proxy()
87+
)
88+
89+
framework.display_status_report()
90+
framework.display_status_report()
91+
92+
framework.custom_weights_choice.setChecked(False)
93+
framework._toggle_weights_path()
94+
assert not framework.weights_filewidget.isVisible()
95+
96+
97+
def test_available_models_retrieval(make_napari_viewer_proxy):
98+
framework = model_framework.ModelFramework(
99+
viewer=make_napari_viewer_proxy()
100+
)
101+
assert framework.get_available_models() == MODEL_LIST
102+
103+
104+
def test_update_weights_path(make_napari_viewer_proxy):
105+
framework = model_framework.ModelFramework(
106+
viewer=make_napari_viewer_proxy()
107+
)
108+
assert (
109+
framework._update_weights_path(framework._default_weights_folder)
110+
is None
111+
)
112+
name = str(Path.home() / "test/weight.pth")
113+
framework._update_weights_path([name])
114+
assert framework.weights_config.path == name
115+
assert framework.weights_filewidget.text_field.text() == name
116+
assert framework._default_weights_folder == str(Path.home() / "test")

napari_cellseg3d/_tests/test_models.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from pathlib import Path
2+
13
import numpy as np
4+
import pytest
25
import torch
36
from numpy.random import PCG64, Generator
47

@@ -8,6 +11,7 @@
811
crf_batch,
912
crf_with_config,
1013
)
14+
from napari_cellseg3d.code_models.models.model_TRAILMAP_MS import TRAILMAP_MS_
1115
from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss
1216
from napari_cellseg3d.config import MODEL_LIST, CRFConfig
1317

@@ -51,6 +55,15 @@ def test_soft_ncuts_loss():
5155
assert isinstance(res, torch.Tensor)
5256
assert 0 <= res <= 1 # ASSUMES NUMBER OF CLASS IS 2, NOT CORRECT IF K>2
5357

58+
loss = SoftNCutsLoss(
59+
data_shape=[dims, dims, dims],
60+
device="cpu",
61+
intensity_sigma=4,
62+
spatial_sigma=4,
63+
radius=None, # test radius=None init
64+
)
65+
assert loss.radius == 5
66+
5467

5568
def test_crf_batch():
5669
dims = 8
@@ -95,3 +108,33 @@ def on_yield(result):
95108

96109
result = next(crf._run_crf_job())
97110
on_yield(result)
111+
112+
113+
def test_pretrained_weights_compatibility():
114+
from napari_cellseg3d.code_models.workers import WeightsDownloader
115+
from napari_cellseg3d.config import MODEL_LIST, PRETRAINED_WEIGHTS_DIR
116+
117+
for model_name in MODEL_LIST:
118+
file_name = MODEL_LIST[model_name].weights_file
119+
WeightsDownloader().download_weights(model_name, file_name)
120+
model = MODEL_LIST[model_name](input_img_size=[128, 128, 128])
121+
try:
122+
model.load_state_dict(
123+
torch.load(
124+
str(Path(PRETRAINED_WEIGHTS_DIR) / file_name),
125+
map_location="cpu",
126+
),
127+
strict=True,
128+
)
129+
except RuntimeError:
130+
pytest.fail(f"Failed to load weights for {model_name}")
131+
132+
133+
def test_trailmap_init():
134+
test = TRAILMAP_MS_(
135+
input_img_size=[128, 128, 128],
136+
in_channels=1,
137+
out_channels=1,
138+
dropout_prob=0.3,
139+
)
140+
assert isinstance(test, TRAILMAP_MS_)

napari_cellseg3d/_tests/test_plugin_inference.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from pathlib import Path
22

3+
from numpy.random import PCG64, Generator
34
from tifffile import imread
45

56
from napari_cellseg3d._tests.fixtures import LogFixture
67
from napari_cellseg3d.code_models.instance_segmentation import (
78
INSTANCE_SEGMENTATION_METHOD_LIST,
9+
volume_stats,
810
)
911
from napari_cellseg3d.code_models.models.model_test import TestModel
1012
from napari_cellseg3d.code_plugins.plugin_model_inference import (
@@ -13,14 +15,16 @@
1315
)
1416
from napari_cellseg3d.config import MODEL_LIST
1517

18+
rand_gen = Generator(PCG64(12345))
1619

17-
def test_inference(make_napari_viewer, qtbot):
20+
21+
def test_inference(make_napari_viewer_proxy, qtbot):
1822
im_path = str(Path(__file__).resolve().parent / "res/test.tif")
1923
image = imread(im_path)
2024

2125
assert image.shape == (6, 6, 6)
2226

23-
viewer = make_napari_viewer()
27+
viewer = make_napari_viewer_proxy()
2428
widget = Inferer(viewer)
2529
widget.log = LogFixture()
2630
viewer.window.add_dock_widget(widget)
@@ -48,8 +52,8 @@ def test_inference(make_napari_viewer, qtbot):
4852
assert widget.worker_config is not None
4953
assert widget.model_info is not None
5054
widget.window_infer_box.setChecked(False)
51-
worker = widget._create_worker_from_config(widget.worker_config)
5255

56+
worker = widget._create_worker_from_config(widget.worker_config)
5357
assert worker.config is not None
5458
assert worker.config.model_info is not None
5559
worker.config.layer = viewer.layers[0].data
@@ -65,5 +69,31 @@ def test_inference(make_napari_viewer, qtbot):
6569
assert isinstance(res, InferenceResult)
6670
assert res.result.shape == (8, 8, 8)
6771
assert res.instance_labels.shape == (8, 8, 8)
68-
6972
widget.on_yield(res)
73+
74+
mock_image = rand_gen.random(size=(10, 10, 10))
75+
mock_labels = rand_gen.integers(0, 10, (10, 10, 10))
76+
mock_results = InferenceResult(
77+
image_id=0,
78+
original=mock_image,
79+
instance_labels=mock_labels,
80+
crf_results=mock_image,
81+
stats=[volume_stats(mock_labels)],
82+
result=mock_image,
83+
model_name="test",
84+
)
85+
num_layers = len(viewer.layers)
86+
widget.worker_config.post_process_config.instance.enabled = True
87+
widget._display_results(mock_results)
88+
assert len(viewer.layers) == num_layers + 4
89+
90+
# assert widget.check_ready()
91+
# widget._setup_worker()
92+
# # widget.config.show_results = True
93+
# with qtbot.waitSignal(widget.worker.yielded, timeout=10000) as blocker:
94+
# blocker.connect(
95+
# widget.worker.errored
96+
# ) # Can add other signals to blocker
97+
# widget.worker.start()
98+
99+
assert widget.on_finish()

0 commit comments

Comments
 (0)