Skip to content

Commit fba44fd

Browse files
committed
Improving testing
1 parent 93e2654 commit fba44fd

File tree

6 files changed

+73
-15
lines changed

6 files changed

+73
-15
lines changed

napari_cellseg3d/_tests/test_plugin_inference.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from tifffile import imread
22
from pathlib import Path
33

4+
from napari_cellseg3d.config import MODEL_LIST
45
from napari_cellseg3d._tests.fixtures import LogFixture
56
from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer
6-
7+
from napari_cellseg3d.code_models.models.model_test import TestModel
78

89
def test_inference(make_napari_viewer, qtbot):
910

@@ -26,10 +27,14 @@ def test_inference(make_napari_viewer, qtbot):
2627

2728
assert widget.check_ready()
2829

29-
# widget.start() # takes too long on Github Actions
30-
# assert widget.worker is not None
30+
MODEL_LIST["test"] = TestModel
31+
widget.model_choice.addItem("test")
32+
widget.setCurrentIndex(-1)
33+
34+
widget.start() # takes too long on Github Actions
35+
assert widget.worker is not None
3136

32-
# with qtbot.waitSignal(signal=widget.worker.yielded, timeout=60000, raising=False) as blocker:
33-
# blocker.connect(widget.worker.errored)
37+
with qtbot.waitSignal(signal=widget.worker.finished, timeout=60000, raising=False) as blocker:
38+
blocker.connect(widget.worker.errored)
3439

3540
# assert len(viewer.layers) == 2

napari_cellseg3d/_tests/test_training.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from napari_cellseg3d import config
44
from napari_cellseg3d.code_plugins.plugin_model_training import Trainer
55
from napari_cellseg3d._tests.fixtures import LogFixture
6+
from napari_cellseg3d.config import MODEL_LIST
7+
from napari_cellseg3d.code_models.models.model_test import TestModel
68

79

810
def test_training(make_napari_viewer, qtbot):
@@ -30,12 +32,15 @@ def test_training(make_napari_viewer, qtbot):
3032
#################
3133
# Training is too long to test properly this way. Do not use on Github
3234
#################
35+
MODEL_LIST["test"] = TestModel()
36+
widget.model_choice.addItem("test")
37+
widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys())-1)
3338

34-
# widget.start()
35-
# assert widget.worker is not None
39+
widget.start()
40+
assert widget.worker is not None
3641

37-
# with qtbot.waitSignal(signal=widget.worker.yielded, timeout=60000, raising=False) as blocker: # wait only for 60 seconds.
38-
# blocker.connect(widget.worker.errored)
42+
with qtbot.waitSignal(signal=widget.worker.finished, timeout=10000, raising=False) as blocker: # wait only for 60 seconds.
43+
blocker.connect(widget.worker.errored)
3944

4045

4146
def test_update_loss_plot(make_napari_viewer):

napari_cellseg3d/code_models/model_workers.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def model_output(
455455
inputs = inputs.to("cpu")
456456

457457
model_output = lambda inputs: post_process_transforms(
458-
self.config.model_info.get_model().get_output(model, inputs)
458+
self.config.model_info.get_model().get_output(model, inputs) # TODO(cyril) refactor those functions
459459
)
460460

461461
def model_output(inputs):
@@ -870,7 +870,7 @@ def inference(self):
870870
model.to("cpu")
871871

872872
except Exception as e:
873-
self.log(f"Error : {e}")
873+
self.log(f"Error during inference : {e}")
874874
self.quit()
875875
finally:
876876
self.quit()
@@ -1078,6 +1078,7 @@ def train(self):
10781078
torch.set_num_threads(1)
10791079
self.log("Number of threads has been set to 1 for macOS")
10801080

1081+
self.log(f"config model : {self.config.model_info.name}")
10811082
model_name = model_config.name
10821083
model_class = model_config.get_model()
10831084

@@ -1314,7 +1315,7 @@ def train(self):
13141315
)
13151316
)
13161317
except RuntimeError as e:
1317-
logger.error(f"Error : {e}")
1318+
logger.error(f"Error when loading weights : {e}")
13181319
warn = (
13191320
"WARNING:\nIt'd seem that the weights were incompatible with the model,\n"
13201321
"the model will be trained from random weights"
@@ -1333,6 +1334,9 @@ def train(self):
13331334

13341335
device = self.config.device
13351336

1337+
if model_name == "test":
1338+
self.quit()
1339+
13361340
for epoch in range(self.config.max_epochs):
13371341
# self.log("\n")
13381342
self.log("-" * 10)
@@ -1472,7 +1476,7 @@ def train(self):
14721476
model.to("cpu")
14731477

14741478
except Exception as e:
1475-
self.log(f"Error : {e}")
1479+
self.log(f"Error in training : {e}")
14761480
self.quit()
14771481
finally:
14781482
self.quit()
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import torch
2+
from torch import nn
3+
4+
5+
def get_weights_file():
6+
return "test.pth"
7+
8+
9+
class TestModel(nn.Module):
10+
11+
def __init__(self):
12+
super().__init__()
13+
self.linear = nn.Linear(1, 1)
14+
15+
def forward(self, x):
16+
return self.linear(x)
17+
18+
def get_net(self):
19+
return self
20+
21+
def get_output(self, _, input):
22+
return input
23+
24+
def get_validation(self, val_inputs):
25+
return val_inputs
26+
27+
if __name__ == "__main__":
28+
29+
model = TestModel()
30+
model.train()
31+
model.zero_grad()
32+
from napari_cellseg3d.config import WEIGHTS_DIR
33+
torch.save(
34+
model.state_dict(),
35+
WEIGHTS_DIR + f"/{get_weights_file()}"
36+
)

napari_cellseg3d/code_plugins/plugin_model_training.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,7 @@ def start(self):
791791
self.data = None
792792
raise err
793793

794+
794795
model_config = config.ModelInfo(
795796
name=self.model_choice.currentText()
796797
)
@@ -820,6 +821,7 @@ def start(self):
820821

821822
patch_size = [w.value() for w in self.patch_size_widgets]
822823

824+
logger.debug("Loading config...")
823825
self.worker_config = config.TrainingWorkerConfig(
824826
device=self.get_device(),
825827
model_info=model_config,

napari_cellseg3d/config.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
model_TRAILMAP_MS as TRAILMAP_MS,
1919
)
2020
from napari_cellseg3d.code_models.models import model_VNet as VNet
21+
from napari_cellseg3d.utils import LOGGER
22+
23+
logger = LOGGER
2124

2225
# TODO(cyril) DOCUMENT !!! and add default values
2326
# TODO(cyril) add JSON load/save
@@ -28,6 +31,7 @@
2831
# "TRAILMAP": TRAILMAP,
2932
"TRAILMAP_MS": TRAILMAP_MS,
3033
"SwinUNetR": SwinUNetR,
34+
# "test" : DO NOT USE, reserved for testing
3135
}
3236

3337
INSTANCE_SEGMENTATION_METHOD_LIST = {
@@ -81,12 +85,14 @@ def get_model(self):
8185
try:
8286
return MODEL_LIST[self.name]
8387
except KeyError as e:
84-
warnings.warn(f"Model {self.name} is not defined")
88+
msg = f"Model {self.name} is not defined"
89+
warnings.warn(msg)
90+
logger.warning(msg)
8591
raise KeyError(e)
8692

8793
@staticmethod
8894
def get_model_name_list():
89-
print(
95+
logger.info(
9096
f"Model list :\n" + str(f"{name}\n" for name in MODEL_LIST.keys())
9197
)
9298
return MODEL_LIST.keys()

0 commit comments

Comments
 (0)