Skip to content

Commit e4b10a3

Browse files
committed
Extend supervised train tests
1 parent fb1b130 commit e4b10a3

File tree

3 files changed

+78
-33
lines changed

3 files changed

+78
-33
lines changed

napari_cellseg3d/_tests/fixtures.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,23 +36,42 @@ def forward(self, x):
3636
return self.forward_encoder(x), self.forward_decoder(x)
3737

3838

39-
class OptimizerFixture:
40-
def __call__(self, x):
39+
class ModelFixture(torch.nn.Module):
40+
def __init__(self):
41+
super().__init__()
42+
self.mock_conv = torch.nn.Conv3d(1, 1, 1)
43+
self.mock_conv.requires_grad_(False)
44+
45+
def forward(self, x):
4146
return x
4247

48+
49+
class OptimizerFixture:
50+
def __init__(self):
51+
self.param_groups = []
52+
self.param_groups.append({"lr": 0})
53+
4354
def zero_grad(self):
4455
pass
4556

46-
def step(self):
57+
def step(self, *args):
58+
pass
59+
60+
61+
class SchedulerFixture:
62+
def step(self, *args):
4763
pass
4864

4965

5066
class LossFixture:
51-
def __call__(self, x):
52-
return x
67+
def __call__(self, *args):
68+
return self
5369

54-
def backward(self, x):
70+
def backward(self, *args):
5571
pass
5672

5773
def item(self):
5874
return 0
75+
76+
def detach(self):
77+
return self

napari_cellseg3d/_tests/test_training.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
from napari_cellseg3d._tests.fixtures import (
66
LogFixture,
77
LossFixture,
8+
ModelFixture,
89
OptimizerFixture,
10+
SchedulerFixture,
911
WNetFixture,
1012
)
1113
from napari_cellseg3d.code_models.models.model_test import TestModel
@@ -33,6 +35,7 @@ def test_supervised_training(make_napari_viewer_proxy):
3335
widget.labels_filepaths = [im_path_str]
3436
widget.epoch_choice.setValue(1)
3537
widget.val_interval_choice.setValue(1)
38+
widget.device_choice.setCurrentIndex(0)
3639

3740
assert widget.check_ready()
3841

@@ -49,13 +52,19 @@ def test_supervised_training(make_napari_viewer_proxy):
4952
worker.config.val_data_dict = [
5053
{"image": im_path_str, "label": im_path_str}
5154
]
52-
worker.config.max_epochs = 1
55+
worker.config.max_epochs = 2
5356
worker.config.validation_interval = 2
54-
worker.log_parameters()
55-
res = next(worker.train())
5657

57-
assert isinstance(res, TrainingReport)
58-
assert res.epoch == 0
58+
worker.log_parameters()
59+
for res_i in worker.train(
60+
provided_model=ModelFixture(),
61+
provided_optimizer=OptimizerFixture(),
62+
provided_loss=LossFixture(),
63+
provided_scheduler=SchedulerFixture(),
64+
):
65+
assert isinstance(res_i, TrainingReport)
66+
res = res_i
67+
assert res.epoch == 1
5968

6069
widget.worker = worker
6170
res.show_plot = True
@@ -86,15 +95,15 @@ def test_unsupervised_training(make_napari_viewer_proxy):
8695
additional_results_description="wnet_test"
8796
)
8897
assert widget.worker.config.train_data_dict is not None
89-
res = next(
90-
widget.worker.train(
91-
provided_model=WNetFixture(),
92-
provided_optimizer=OptimizerFixture(),
93-
provided_loss=LossFixture(),
94-
)
95-
)
96-
assert isinstance(res, TrainingReport)
97-
assert not res.show_plot
98+
widget.worker.config.max_epochs = 1
99+
for res_i in widget.worker.train(
100+
provided_model=WNetFixture(),
101+
provided_optimizer=OptimizerFixture(),
102+
provided_loss=LossFixture(),
103+
):
104+
assert isinstance(res_i, TrainingReport)
105+
res = res_i
106+
assert res.epoch == 0
98107
widget.worker._abort_requested = True
99108
res = next(
100109
widget.worker.train(

napari_cellseg3d/code_models/worker_training.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -999,7 +999,13 @@ def log_parameters(self):
999999
# self.log("\n")
10001000
# self.log("-" * 20)
10011001

1002-
def train(self):
1002+
def train(
1003+
self,
1004+
provided_model=None,
1005+
provided_optimizer=None,
1006+
provided_loss=None,
1007+
provided_scheduler=None,
1008+
):
10031009
"""Trains the PyTorch model for the given number of epochs, with the selected model and data,
10041010
using the chosen batch size, validation interval, loss function, and number of samples.
10051011
Will perform validation once every :py:obj:`val_interval` and save results if the mean dice is better
@@ -1070,13 +1076,16 @@ def train(self):
10701076
self.config.train_data_dict[0]
10711077
)
10721078
check = data_check["image"].shape
1073-
10741079
do_sampling = self.config.sampling
1075-
10761080
size = self.config.sample_size if do_sampling else check
1077-
10781081
PADDING = utils.get_padding_dim(size)
1079-
model = model_class(input_img_size=PADDING, use_checkpoint=True)
1082+
1083+
model = (
1084+
model_class(input_img_size=PADDING, use_checkpoint=True)
1085+
if provided_model is None
1086+
else provided_model
1087+
)
1088+
10801089
device = torch.device(self.config.device)
10811090
model = model.to(device)
10821091

@@ -1276,8 +1285,10 @@ def get_patch_loader_func(num_samples):
12761285
logger.info("\nDone")
12771286

12781287
logger.debug("Optimizer")
1279-
optimizer = torch.optim.Adam(
1280-
model.parameters(), self.config.learning_rate
1288+
optimizer = (
1289+
torch.optim.Adam(model.parameters(), self.config.learning_rate)
1290+
if provided_optimizer is None
1291+
else provided_optimizer
12811292
)
12821293

12831294
factor = self.config.scheduler_factor
@@ -1286,12 +1297,16 @@ def get_patch_loader_func(num_samples):
12861297
self.log("Setting it to 0.5")
12871298
factor = 0.5
12881299

1289-
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
1290-
optimizer=optimizer,
1291-
mode="min",
1292-
factor=factor,
1293-
patience=self.config.scheduler_patience,
1294-
verbose=VERBOSE_SCHEDULER,
1300+
scheduler = (
1301+
torch.optim.lr_scheduler.ReduceLROnPlateau(
1302+
optimizer=optimizer,
1303+
mode="min",
1304+
factor=factor,
1305+
patience=self.config.scheduler_patience,
1306+
verbose=VERBOSE_SCHEDULER,
1307+
)
1308+
if provided_scheduler is None
1309+
else provided_scheduler
12951310
)
12961311
dice_metric = DiceMetric(
12971312
include_background=True, reduction="mean", ignore_empty=False
@@ -1342,6 +1357,8 @@ def get_patch_loader_func(num_samples):
13421357

13431358
# device = torch.device(self.config.device)
13441359
self.set_loss_from_config()
1360+
if provided_loss is not None:
1361+
self.loss_function = provided_loss
13451362

13461363
# if model_name == "test":
13471364
# self.quit()

0 commit comments

Comments (0)