Skip to content

Commit c54ee26

Browse files
committed
Functional WNet training
1 parent 385552b commit c54ee26

File tree

2 files changed

+4
-15
lines changed

2 files changed

+4
-15
lines changed

napari_cellseg3d/code_models/worker_training.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,6 @@
8686
# https://www.pythoncentral.io/pysidepyqt-tutorial-creating-your-own-signals-and-slots/
8787
# https://napari-staging-site.github.io/guides/stable/threading.html
8888

89-
# TODO list for WNet training :
90-
# 1. Create a custom base worker for training to avoid code duplication
91-
# 2. Create a custom worker for WNet training
92-
# 3. Adapt UI for WNet training (Advanced tab + model choice on first tab)
93-
# 4. Adapt plots and TrainingReport for WNet training
94-
# 5. log_parameters function
95-
9689

9790
class TrainingWorkerBase(GeneratorWorker):
9891
"""A basic worker abstract class, to run training jobs in.

napari_cellseg3d/code_plugins/plugin_model_training.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,13 +1295,9 @@ def _display_results(self, images_dict, complete_missing=False):
12951295
self.result_layers[i].refresh()
12961296
self.result_layers[i].reset_contrast_limits()
12971297

1298-
def on_yield(self, report: TrainingReport): # TODO refactor for dict
1299-
# logger.info(
1300-
# f"\nCatching results : for epoch {data['epoch']},
1301-
# loss is {data['losses']} and validation is {data['val_metrics']}"
1302-
# )
1298+
def on_yield(self, report: TrainingReport):
13031299
if report == TrainingReport():
1304-
return
1300+
return # skip empty reports
13051301

13061302
if report.show_plot:
13071303
try:
@@ -1375,7 +1371,7 @@ def _make_csv(self):
13751371
dice_metric = self.loss_1_values["Dice metric"]
13761372
self.df = pd.DataFrame(
13771373
{
1378-
"epoch": size_column,
1374+
"Epoch": size_column,
13791375
"Ncuts loss": ncuts_loss,
13801376
"Dice metric": dice_metric,
13811377
"Reconstruction loss": self.loss_2_values,
@@ -1384,7 +1380,7 @@ def _make_csv(self):
13841380
except KeyError:
13851381
self.df = pd.DataFrame(
13861382
{
1387-
"epoch": size_column,
1383+
"Epoch": size_column,
13881384
"Ncuts loss": ncuts_loss,
13891385
"Reconstruction loss": self.loss_2_values,
13901386
}

0 commit comments

Comments
 (0)