Skip to content

Commit 8bcc014

Browse files
committed
added learning rate param + number of cells
1 parent 599fb73 commit 8bcc014

File tree

9 files changed

+94
-16
lines changed

9 files changed

+94
-16
lines changed

src/napari_cellseg3d/model_framework.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import napari
55
import torch
6+
67
# Qt
78
from qtpy.QtWidgets import QLineEdit
89
from qtpy.QtWidgets import QProgressBar

src/napari_cellseg3d/model_workers.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import numpy as np
66
import torch
7+
78
# MONAI
89
from monai.data import CacheDataset
910
from monai.data import DataLoader
@@ -30,6 +31,7 @@
3031
from monai.transforms import Zoom
3132
from napari.qt.threading import GeneratorWorker
3233
from napari.qt.threading import WorkerBaseSignals
34+
3335
# Qt
3436
from qtpy.QtCore import Signal
3537
from tifffile import imwrite
@@ -378,6 +380,7 @@ def __init__(
378380
data_dicts,
379381
max_epochs,
380382
loss_function,
383+
learning_rate,
381384
val_interval,
382385
batch_size,
383386
results_path,
@@ -401,6 +404,8 @@ def __init__(
401404
402405
* loss_function : the loss function to use for training
403406
407+
* learning_rate : the learning rate of the optimizer
408+
404409
* val_interval : the interval at which to perform validation (e.g. if 2 will validate once every 2 epochs.) Also determines frequency of saving, depending on whether the metric is better or not
405410
406411
* batch_size : the batch size to use for training
@@ -428,6 +433,7 @@ def __init__(
428433
self.data_dicts = data_dicts
429434
self.max_epochs = max_epochs
430435
self.loss_function = loss_function
436+
self.learning_rate = learning_rate
431437
self.val_interval = val_interval
432438
self.batch_size = batch_size
433439
self.results_path = results_path
@@ -486,6 +492,8 @@ def train(self):
486492
487493
* loss_function : the loss function to use for training
488494
495+
* learning rate : the learning rate of the optimizer
496+
489497
* val_interval : the interval at which to perform validation (e.g. if 2 will validate once every 2 epochs.) Also determines frequency of saving, depending on whether the metric is better or not
490498
491499
* batch_size : the batch size to use for training
@@ -653,7 +661,7 @@ def train(self):
653661
)
654662
print("\nDone")
655663

656-
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
664+
optimizer = torch.optim.Adam(model.parameters(), self.learning_rate)
657665
dice_metric = DiceMetric(include_background=True, reduction="mean")
658666

659667
best_metric = -1

src/napari_cellseg3d/plugin_dock.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import os
22
import warnings
3+
34
# import shutil
45
from pathlib import Path
56

67
import pandas as pd
8+
79
# Qt
810
from qtpy.QtWidgets import QVBoxLayout
911
from qtpy.QtWidgets import QWidget

src/napari_cellseg3d/plugin_helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import napari
2+
23
# Qt
34
from qtpy.QtWidgets import QSizePolicy
45
from qtpy.QtWidgets import QVBoxLayout

src/napari_cellseg3d/plugin_metrics.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ def remove_plots(self):
157157

158158
def compute_dice(self):
159159
"""Computes the dice metric between pairs of labels. Rotates the prediction label to find matching orientation as well."""
160-
u = 0
161-
t = 0
160+
# u = 0
161+
# t = 0
162162
total_metrics = []
163163
self.canvas = (
164164
None # kind of unsafe way to stack plots... but it works.
@@ -205,12 +205,29 @@ def compute_dice(self):
205205
f"Padded sizes of images do not match ! Padded ground label : {ground.shape} Padded pred label : {pred.shape}"
206206
)
207207
# if u < 1:
208-
self._viewer.add_image(ground, name="ground", colormap="blue",opacity=0.7)
209-
self._viewer.add_image(pred, name="pred", colormap="red")
210-
self._viewer.add_image(np.rot90(pred[0][0], axes=(0,1)), name="pred flip 0", colormap="red",opacity=0.7)
211-
self._viewer.add_image(np.rot90(pred[0][0], axes=(1,2)), name="pred flip 1", colormap="red",opacity=0.7)
212-
self._viewer.add_image(np.rot90(pred[0][0], axes=(0,2)), name="pred flip 2", colormap="red",opacity=0.7)
213-
u+=1
208+
# self._viewer.add_image(
209+
# ground, name="ground", colormap="blue", opacity=0.7
210+
# )
211+
# self._viewer.add_image(pred, name="pred", colormap="red")
212+
# self._viewer.add_image(
213+
# np.rot90(pred[0][0], axes=(0, 1)),
214+
# name="pred flip 0",
215+
# colormap="red",
216+
# opacity=0.7,
217+
# )
218+
# self._viewer.add_image(
219+
# np.rot90(pred[0][0], axes=(1, 2)),
220+
# name="pred flip 1",
221+
# colormap="red",
222+
# opacity=0.7,
223+
# )
224+
# self._viewer.add_image(
225+
# np.rot90(pred[0][0], axes=(0, 2)),
226+
# name="pred flip 2",
227+
# colormap="red",
228+
# opacity=0.7,
229+
# )
230+
# u += 1
214231

215232
# TODO add rotation toggle
216233
pred_flip_x = np.rot90(pred[0][0], axes=(0, 1))
@@ -224,10 +241,15 @@ def compute_dice(self):
224241
for i in range(3):
225242
scores.append(utils.dice_coeff(np.flip(p, axis=i), ground))
226243

227-
if t <1 :
228-
for i in range(3):
229-
self._viewer.add_image(np.flip(pred_flip_x,axis=i), name=f"flip", colormap="green",opacity=0.7)
230-
t+=1
244+
# if t < 1:
245+
# for i in range(3):
246+
# self._viewer.add_image(
247+
# np.flip(pred_flip_x, axis=i),
248+
# name=f"flip",
249+
# colormap="green",
250+
# opacity=0.7,
251+
# )
252+
# t += 1
231253

232254
# print(scores)
233255
score = max(scores)

src/napari_cellseg3d/plugin_model_inference.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44

55
import napari
6+
67
# Qt
78
from qtpy.QtWidgets import QLabel
89
from qtpy.QtWidgets import QSizePolicy
@@ -626,7 +627,9 @@ def on_yield(data, widget):
626627

627628
if data["instance_labels"] is not None:
628629

629-
widget.log.print_and_log(f"\nNUMBER OF CELLS : {np.amax(data['instance_labels'])}\n")
630+
widget.log.print_and_log(
631+
f"\nNUMBER OF CELLS : {np.amax(data['instance_labels'])}\n"
632+
)
630633

631634
name = f"instance_labels_{image_id}"
632635
instance_layer = viewer.add_labels(

src/napari_cellseg3d/plugin_model_training.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
FigureCanvasQTAgg as FigureCanvas,
1111
)
1212
from matplotlib.figure import Figure
13+
1314
# MONAI
1415
from monai.losses import DiceCELoss
1516
from monai.losses import DiceFocalLoss
1617
from monai.losses import DiceLoss
1718
from monai.losses import FocalLoss
1819
from monai.losses import GeneralizedDiceLoss
1920
from monai.losses import TverskyLoss
21+
2022
# Qt
2123
from qtpy.QtWidgets import QLabel
2224
from qtpy.QtWidgets import QProgressBar
@@ -163,6 +165,7 @@ def __init__(
163165
"""At which epochs to perform validation. E.g. if 2, will run validation on epochs 2,4,6..."""
164166
self.patch_size = []
165167
"""The size of samples to be extracted from images"""
168+
self.learning_rate = 1e-3
166169

167170
self.model = None # TODO : custom model loading ?
168171
self.worker = None
@@ -222,6 +225,20 @@ def __init__(
222225
)
223226
self.lbl_val_interv_choice = QLabel("Validation interval : ", self)
224227

228+
self.learning_rate_dict = {
229+
"1e-3": 1e-3,
230+
"1e-4": 1e-4,
231+
"1e-5": 1e-5,
232+
"1e-6": 1e-6,
233+
}
234+
235+
(
236+
self.learning_rate_choice,
237+
self.lbl_learning_rate_choice,
238+
) = ui.make_combobox(
239+
self.learning_rate_dict.keys(), label="Learning rate"
240+
)
241+
225242
self.augment_choice = ui.make_checkbox("Augment data")
226243

227244
# TODO add self.tabs, self.close_buttons etc...
@@ -528,8 +545,20 @@ def build(self):
528545
r=5,
529546
b=5,
530547
),
531-
alignment=ui.LEFT_AL,
548+
# alignment=ui.LEFT_AL,
532549
) # batch size
550+
train_param_group_l.addWidget(
551+
ui.combine_blocks(
552+
self.learning_rate_choice,
553+
self.lbl_learning_rate_choice,
554+
min_spacing=spacing,
555+
horizontal=False,
556+
l=5,
557+
t=5,
558+
r=5,
559+
b=5,
560+
)
561+
) # learning rate
533562
train_param_group_l.addWidget(
534563
ui.combine_blocks(
535564
self.epoch_choice,
@@ -676,6 +705,8 @@ def start(self):
676705
self.data = self.create_train_dataset_dict()
677706
self.max_epochs = self.epoch_choice.value()
678707

708+
self.learning_rate = self.learning_rate_dict[self.learning_rate_choice.currentText()]
709+
679710
self.patch_size = []
680711
[
681712
self.patch_size.append(w.value())
@@ -715,6 +746,7 @@ def start(self):
715746
data_dicts=self.data,
716747
max_epochs=self.max_epochs,
717748
loss_function=self.get_loss(self.loss_choice.currentText()),
749+
learning_rate=self.learning_rate,
718750
val_interval=self.val_interval,
719751
batch_size=self.batch_size,
720752
results_path=self.results_path,
@@ -804,7 +836,13 @@ def on_yield(data, widget):
804836
widget.update_loss_plot(data["losses"], data["val_metrics"])
805837

806838
if widget.stop_requested:
807-
torch.save(data["weights"], os.path.join(widget.results_path, f"latest_weights_aborted_training_{utils.get_date_time()}.pth"))
839+
torch.save(
840+
data["weights"],
841+
os.path.join(
842+
widget.results_path,
843+
f"latest_weights_aborted_training_{utils.get_date_time()}.pth",
844+
),
845+
)
808846
widget.stop_requested = False
809847

810848
# def clean_cache(self):

src/napari_cellseg3d/plugin_review.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import pims
77
import skimage.io as io
8+
89
# Qt
910
from qtpy import QtGui
1011
from qtpy.QtWidgets import QLabel

src/napari_cellseg3d/plugin_utilities.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import napari
2+
23
# Qt
34
from qtpy.QtWidgets import QTabWidget
45

56
from napari_cellseg3d.plugin_convert import ConvertUtils
7+
68
# local
79
from napari_cellseg3d.plugin_crop import Cropping
810
from napari_cellseg3d.plugin_metrics import MetricsUtils

0 commit comments

Comments
 (0)