Skip to content

Commit f33ccd8

Browse files
committed
Fix loss function instantiation
1 parent d29cdce commit f33ccd8

File tree

3 files changed

+46
-32
lines changed

3 files changed

+46
-32
lines changed

napari_cellseg3d/_tests/test_training.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ def test_update_loss_plot(make_napari_viewer_proxy):
5353
assert widget.train_loss_plot is not None
5454

5555

56+
def test_check_matching_losses():
57+
plugin = Trainer(None)
58+
config = plugin._set_worker_config()
59+
worker = plugin._create_worker_from_config(config)
60+
61+
assert plugin.loss_list == list(worker.loss_dict.keys())
62+
63+
5664
def test_training(make_napari_viewer_proxy, qtbot):
5765
im_path = str(Path(__file__).resolve().parent / "res/test.tif")
5866

@@ -73,9 +81,6 @@ def test_training(make_napari_viewer_proxy, qtbot):
7381

7482
assert widget.check_ready()
7583

76-
#################
77-
# Training is too long to test properly this way. Do not use on Github
78-
#################
7984
MODEL_LIST["test"] = TestModel
8085
widget.model_choice.addItem("test")
8186
widget.model_choice.setCurrentText("test")

napari_cellseg3d/code_models/worker_training.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@
1515
pad_list_data_collate,
1616
)
1717
from monai.inferers import sliding_window_inference
18+
from monai.losses import (
19+
DiceCELoss,
20+
DiceLoss,
21+
GeneralizedDiceLoss,
22+
TverskyLoss,
23+
)
1824
from monai.metrics import DiceMetric
1925
from monai.transforms import (
2026
# AsDiscrete,
@@ -125,6 +131,25 @@ def __init__(
125131
#######################################
126132
self.downloader = WeightsDownloader()
127133

134+
self.loss_dict = {
135+
"Dice": DiceLoss(sigmoid=True),
136+
# "BCELoss": torch.nn.BCELoss(), # dev
137+
# "BCELogits": torch.nn.BCEWithLogitsLoss(),
138+
"Generalized Dice": GeneralizedDiceLoss(sigmoid=True),
139+
"DiceCE": DiceCELoss(sigmoid=True, lambda_ce=0.5),
140+
"Tversky": TverskyLoss(sigmoid=True),
141+
# "Focal loss": FocalLoss(),
142+
# "Dice-Focal loss": DiceFocalLoss(sigmoid=True, lambda_dice=0.5),
143+
}
144+
self.loss_function = None
145+
146+
def set_loss_from_config(self):
147+
try:
148+
self.loss_function = self.loss_dict[self.config.loss_function]
149+
except KeyError as e:
150+
self.raise_error(e, "Loss function not found, aborting job")
151+
return self.loss_function
152+
128153
def set_download_log(self, widget):
129154
self.downloader.log_widget = widget
130155

@@ -532,6 +557,7 @@ def get_loader_func(num_samples):
532557
self.log_parameters()
533558

534559
device = torch.device(self.config.device)
560+
self.set_loss_from_config()
535561

536562
# if model_name == "test":
537563
# self.quit()
@@ -571,7 +597,7 @@ def get_loader_func(num_samples):
571597
] # TODO(cyril): adapt if additional channels
572598
if len(outputs.shape) < 4:
573599
outputs = outputs.unsqueeze(0)
574-
loss = self.config.loss_function(outputs, labels)
600+
loss = self.loss_function(outputs, labels)
575601
loss.backward()
576602
optimizer.step()
577603
epoch_loss += loss.detach().item()

napari_cellseg3d/code_plugins/plugin_model_training.py

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,6 @@
1515
if TYPE_CHECKING:
1616
import napari
1717

18-
# MONAI
19-
from monai.losses import (
20-
DiceCELoss,
21-
DiceLoss,
22-
GeneralizedDiceLoss,
23-
TverskyLoss,
24-
)
25-
2618
# Qt
2719
from qtpy.QtWidgets import QSizePolicy
2820

@@ -139,17 +131,15 @@ def __init__(
139131
"""Whether the worker should stop or not"""
140132
self.start_time = None
141133

142-
self.loss_dict = {
143-
"Dice": DiceLoss(sigmoid=True),
144-
# "BCELoss": torch.nn.BCELoss(), # dev
145-
# "BCELogits": torch.nn.BCEWithLogitsLoss(),
146-
"Generalized Dice": GeneralizedDiceLoss(sigmoid=True),
147-
"DiceCE": DiceCELoss(sigmoid=True, lambda_ce=0.5),
148-
"Tversky": TverskyLoss(sigmoid=True),
149-
# "Focal loss": FocalLoss(),
150-
# "Dice-Focal loss": DiceFocalLoss(sigmoid=True, lambda_dice=0.5),
151-
}
152-
"""Dict of loss functions"""
134+
self.loss_list = [ # MUST BE MATCHED WITH THE LOSS FUNCTIONS IN THE TRAINING WORKER DICT
135+
"Dice",
136+
"Generalized Dice",
137+
"DiceCE",
138+
"Tversky",
139+
# "Focal loss",
140+
# "Dice-Focal loss",
141+
]
142+
"""List of loss functions"""
153143

154144
self.canvas = None
155145
"""Canvas to plot loss and dice metric in"""
@@ -192,10 +182,7 @@ def __init__(
192182
)
193183

194184
self.loss_choice = ui.DropdownMenu(
195-
# sorted(
196-
list(
197-
self.loss_dict.keys(),
198-
),
185+
self.loss_list,
199186
text_label="Loss function",
200187
)
201188
self.lbl_loss_choice = self.loss_choice.label
@@ -383,10 +370,6 @@ def _update_validation_choice(self):
383370
elif validation.maximum() < max_epoch:
384371
validation.setMaximum(max_epoch)
385372

386-
def get_loss(self, key):
387-
"""Getter for loss function selected by user"""
388-
return self.loss_dict[key]
389-
390373
def _toggle_patch_dims(self):
391374
if self.patch_choice.isChecked():
392375
[w.setVisible(True) for w in self.patch_size_widgets]
@@ -899,7 +882,7 @@ def _set_worker_config(self) -> config.TrainingWorkerConfig:
899882
train_data_dict=self.data,
900883
validation_percent=validation_percent,
901884
max_epochs=self.epoch_choice.value(),
902-
loss_function=self.get_loss(self.loss_choice.currentText()),
885+
loss_function=self.loss_choice.currentText(),
903886
learning_rate=float(self.learning_rate_choice.currentText()),
904887
scheduler_patience=self.scheduler_patience_choice.value(),
905888
scheduler_factor=self.scheduler_factor_choice.slider_value,

0 commit comments

Comments
 (0)