Skip to content

Commit c1aecb8

Browse files
committed
Test unsupervised training and raise coverage
1 parent 912e6bd commit c1aecb8

File tree

4 files changed

+149
-58
lines changed

4 files changed

+149
-58
lines changed

napari_cellseg3d/_tests/fixtures.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import torch
12
from qtpy.QtWidgets import QTextEdit
23

34
from napari_cellseg3d.utils import LOGGER as logger
@@ -17,3 +18,41 @@ def warn(self, warning):
1718

1819
def error(self, e):
1920
raise (e)
21+
22+
23+
class WNetFixture(torch.nn.Module):
24+
def __init__(self):
25+
super().__init__()
26+
self.mock_conv = torch.nn.Conv3d(1, 1, 1)
27+
self.mock_conv.requires_grad_(False)
28+
29+
def forward_encoder(self, x):
30+
return x
31+
32+
def forward_decoder(self, x):
33+
return x
34+
35+
def forward(self, x):
36+
return self.forward_encoder(x), self.forward_decoder(x)
37+
38+
39+
class OptimizerFixture:
40+
def __call__(self, x):
41+
return x
42+
43+
def zero_grad(self):
44+
pass
45+
46+
def step(self):
47+
pass
48+
49+
50+
class LossFixture:
51+
def __call__(self, x):
52+
return x
53+
54+
def backward(self, x):
55+
pass
56+
57+
def item(self):
58+
return 0

napari_cellseg3d/_tests/test_training_plugin.py renamed to napari_cellseg3d/_tests/test_plugin_training.py

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
from pathlib import Path
22

33
from napari_cellseg3d import config
4-
from napari_cellseg3d._tests.fixtures import LogFixture
5-
from napari_cellseg3d.code_models.models.model_test import TestModel
6-
from napari_cellseg3d.code_models.workers_utils import TrainingReport
74
from napari_cellseg3d.code_plugins.plugin_model_training import (
85
Trainer,
96
)
@@ -109,51 +106,3 @@ def test_check_matching_losses():
109106
worker = plugin._create_supervised_worker_from_config(config)
110107

111108
assert plugin.loss_list == list(worker.loss_dict.keys())
112-
113-
114-
def test_training(make_napari_viewer_proxy, qtbot):
115-
viewer = make_napari_viewer_proxy()
116-
widget = Trainer(viewer)
117-
widget.log = LogFixture()
118-
viewer.window.add_dock_widget(widget)
119-
120-
widget.images_filepath = []
121-
widget.labels_filepaths = []
122-
123-
assert not widget.check_ready()
124-
125-
widget.images_filepaths = [im_path_str]
126-
widget.labels_filepaths = [im_path_str]
127-
widget.epoch_choice.setValue(1)
128-
widget.val_interval_choice.setValue(1)
129-
130-
assert widget.check_ready()
131-
132-
MODEL_LIST["test"] = TestModel
133-
widget.model_choice.addItem("test")
134-
widget.model_choice.setCurrentText("test")
135-
widget.unsupervised_mode = False
136-
worker_config = widget._set_worker_config()
137-
assert worker_config.model_info.name == "test"
138-
worker = widget._create_supervised_worker_from_config(worker_config)
139-
worker.config.train_data_dict = [
140-
{"image": im_path_str, "label": im_path_str}
141-
]
142-
worker.config.val_data_dict = [
143-
{"image": im_path_str, "label": im_path_str}
144-
]
145-
worker.config.max_epochs = 1
146-
worker.config.validation_interval = 2
147-
worker.log_parameters()
148-
res = next(worker.train())
149-
150-
assert isinstance(res, TrainingReport)
151-
assert res.epoch == 0
152-
153-
widget.worker = worker
154-
res.show_plot = True
155-
res.loss_1_values = {"loss": [1, 1, 1, 1]}
156-
res.loss_2_values = [1, 1, 1, 1]
157-
widget.on_yield(res)
158-
assert widget.loss_1_values["loss"] == [1, 1, 1, 1]
159-
assert widget.loss_2_values == [1, 1, 1, 1]
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from pathlib import Path
2+
3+
from napari_cellseg3d._tests.fixtures import (
4+
LogFixture,
5+
LossFixture,
6+
OptimizerFixture,
7+
WNetFixture,
8+
)
9+
from napari_cellseg3d.code_models.models.model_test import TestModel
10+
from napari_cellseg3d.code_models.workers_utils import TrainingReport
11+
from napari_cellseg3d.code_plugins.plugin_model_training import (
12+
Trainer,
13+
)
14+
from napari_cellseg3d.config import MODEL_LIST
15+
16+
im_path = Path(__file__).resolve().parent / "res/test.tif"
17+
im_path_str = str(im_path)
18+
19+
20+
def test_supervised_training(make_napari_viewer_proxy):
21+
viewer = make_napari_viewer_proxy()
22+
widget = Trainer(viewer)
23+
widget.log = LogFixture()
24+
25+
widget.images_filepath = []
26+
widget.labels_filepaths = []
27+
28+
assert not widget.check_ready()
29+
30+
widget.images_filepaths = [im_path_str]
31+
widget.labels_filepaths = [im_path_str]
32+
widget.epoch_choice.setValue(1)
33+
widget.val_interval_choice.setValue(1)
34+
35+
assert widget.check_ready()
36+
37+
MODEL_LIST["test"] = TestModel
38+
widget.model_choice.addItem("test")
39+
widget.model_choice.setCurrentText("test")
40+
widget.unsupervised_mode = False
41+
worker_config = widget._set_worker_config()
42+
assert worker_config.model_info.name == "test"
43+
worker = widget._create_supervised_worker_from_config(worker_config)
44+
worker.config.train_data_dict = [
45+
{"image": im_path_str, "label": im_path_str}
46+
]
47+
worker.config.val_data_dict = [
48+
{"image": im_path_str, "label": im_path_str}
49+
]
50+
worker.config.max_epochs = 1
51+
worker.config.validation_interval = 2
52+
worker.log_parameters()
53+
res = next(worker.train())
54+
55+
assert isinstance(res, TrainingReport)
56+
assert res.epoch == 0
57+
58+
widget.worker = worker
59+
res.show_plot = True
60+
res.loss_1_values = {"loss": [1, 1, 1, 1]}
61+
res.loss_2_values = [1, 1, 1, 1]
62+
widget.on_yield(res)
63+
assert widget.loss_1_values["loss"] == [1, 1, 1, 1]
64+
assert widget.loss_2_values == [1, 1, 1, 1]
65+
66+
67+
def test_unsupervised_training(make_napari_viewer_proxy):
68+
viewer = make_napari_viewer_proxy()
69+
widget = Trainer(viewer)
70+
widget.log = LogFixture()
71+
widget.worker = None
72+
widget._toggle_unsupervised_mode(enabled=True)
73+
widget.model_choice.setCurrentText("WNet")
74+
75+
widget.patch_choice.setChecked(True)
76+
[w.setValue(4) for w in widget.patch_size_widgets]
77+
78+
widget.unsupervised_images_filewidget.text_field.setText(
79+
str(im_path.parent)
80+
)
81+
# widget.start()
82+
widget.data = widget.create_dataset_dict_no_labs()
83+
widget.worker = widget._create_worker(
84+
additional_results_description="wnet_test"
85+
)
86+
assert widget.worker.config.train_data_dict is not None
87+
res = next(
88+
widget.worker.train(
89+
provided_model=WNetFixture(),
90+
provided_optimizer=OptimizerFixture(),
91+
provided_loss=LossFixture(),
92+
)
93+
)
94+
assert isinstance(res, TrainingReport)

napari_cellseg3d/code_models/worker_training.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,9 @@ def log_parameters(self):
362362
for k, v in d.items()
363363
]
364364

365-
def train(self):
365+
def train(
366+
self, provided_model=None, provided_optimizer=None, provided_loss=None
367+
):
366368
try:
367369
if self.config is None:
368370
self.config = config.WNetTrainingWorkerConfig()
@@ -395,11 +397,15 @@ def train(self):
395397
###################################################
396398
self.log("- Getting the model")
397399
# Initialize the model
398-
model = WNet(
399-
in_channels=self.config.in_channels,
400-
out_channels=self.config.out_channels,
401-
num_classes=self.config.num_classes,
402-
dropout=self.config.dropout,
400+
model = (
401+
WNet(
402+
in_channels=self.config.in_channels,
403+
out_channels=self.config.out_channels,
404+
num_classes=self.config.num_classes,
405+
dropout=self.config.dropout,
406+
)
407+
if provided_model is None
408+
else provided_model
403409
)
404410
model.to(device)
405411

@@ -458,7 +464,8 @@ def train(self):
458464
optimizer = torch.optim.Adam(
459465
model.parameters(), lr=self.config.learning_rate
460466
)
461-
467+
if provided_optimizer is not None:
468+
optimizer = provided_optimizer
462469
self.log("- Getting the loss functions")
463470
# Initialize the Ncuts loss function
464471
criterionE = SoftNCutsLoss(
@@ -538,6 +545,8 @@ def train(self):
538545
beta = self.config.rec_loss_weight
539546

540547
loss = alpha * Ncuts + beta * reconstruction_loss
548+
if provided_loss is not None:
549+
loss = provided_loss
541550
epoch_loss += loss.item()
542551
# if WANDB_INSTALLED:
543552
# wandb.log({"Weighted sum of losses": loss.item()})

0 commit comments

Comments
 (0)