|
10 | 10 | FigureCanvasQTAgg as FigureCanvas, |
11 | 11 | ) |
12 | 12 | from matplotlib.figure import Figure |
| 13 | + |
13 | 14 | # MONAI |
14 | 15 | from monai.losses import DiceCELoss |
15 | 16 | from monai.losses import DiceFocalLoss |
16 | 17 | from monai.losses import DiceLoss |
17 | 18 | from monai.losses import FocalLoss |
18 | 19 | from monai.losses import GeneralizedDiceLoss |
19 | 20 | from monai.losses import TverskyLoss |
| 21 | + |
20 | 22 | # Qt |
21 | 23 | from qtpy.QtWidgets import QLabel |
22 | 24 | from qtpy.QtWidgets import QProgressBar |
@@ -163,6 +165,7 @@ def __init__( |
163 | 165 | """At which epochs to perform validation. E.g. if 2, will run validation on epochs 2,4,6...""" |
164 | 166 | self.patch_size = [] |
165 | 167 | """The size of samples to be extracted from images""" |
| 168 | + self.learning_rate = 1e-3 |
166 | 169 |
|
167 | 170 | self.model = None # TODO : custom model loading ? |
168 | 171 | self.worker = None |
@@ -222,6 +225,20 @@ def __init__( |
222 | 225 | ) |
223 | 226 | self.lbl_val_interv_choice = QLabel("Validation interval : ", self) |
224 | 227 |
|
| 228 | + self.learning_rate_dict = { |
| 229 | + "1e-3": 1e-3, |
| 230 | + "1e-4": 1e-4, |
| 231 | + "1e-5": 1e-5, |
| 232 | + "1e-6": 1e-6, |
| 233 | + } |
| 234 | + |
| 235 | + ( |
| 236 | + self.learning_rate_choice, |
| 237 | + self.lbl_learning_rate_choice, |
| 238 | + ) = ui.make_combobox( |
| 239 | + self.learning_rate_dict.keys(), label="Learning rate" |
| 240 | + ) |
| 241 | + |
225 | 242 | self.augment_choice = ui.make_checkbox("Augment data") |
226 | 243 |
|
227 | 244 | # TODO add self.tabs, self.close_buttons etc... |
@@ -528,8 +545,20 @@ def build(self): |
528 | 545 | r=5, |
529 | 546 | b=5, |
530 | 547 | ), |
531 | | - alignment=ui.LEFT_AL, |
| 548 | + # alignment=ui.LEFT_AL, |
532 | 549 | ) # batch size |
| 550 | + train_param_group_l.addWidget( |
| 551 | + ui.combine_blocks( |
| 552 | + self.learning_rate_choice, |
| 553 | + self.lbl_learning_rate_choice, |
| 554 | + min_spacing=spacing, |
| 555 | + horizontal=False, |
| 556 | + l=5, |
| 557 | + t=5, |
| 558 | + r=5, |
| 559 | + b=5, |
| 560 | + ) |
| 561 | + ) # learning rate |
533 | 562 | train_param_group_l.addWidget( |
534 | 563 | ui.combine_blocks( |
535 | 564 | self.epoch_choice, |
@@ -676,6 +705,8 @@ def start(self): |
676 | 705 | self.data = self.create_train_dataset_dict() |
677 | 706 | self.max_epochs = self.epoch_choice.value() |
678 | 707 |
|
| 708 | + self.learning_rate = self.learning_rate_dict[self.learning_rate_choice.currentText()] |
| 709 | + |
679 | 710 | self.patch_size = [] |
680 | 711 | [ |
681 | 712 | self.patch_size.append(w.value()) |
@@ -715,6 +746,7 @@ def start(self): |
715 | 746 | data_dicts=self.data, |
716 | 747 | max_epochs=self.max_epochs, |
717 | 748 | loss_function=self.get_loss(self.loss_choice.currentText()), |
| 749 | + learning_rate=self.learning_rate, |
718 | 750 | val_interval=self.val_interval, |
719 | 751 | batch_size=self.batch_size, |
720 | 752 | results_path=self.results_path, |
@@ -804,7 +836,13 @@ def on_yield(data, widget): |
804 | 836 | widget.update_loss_plot(data["losses"], data["val_metrics"]) |
805 | 837 |
|
806 | 838 | if widget.stop_requested: |
807 | | - torch.save(data["weights"], os.path.join(widget.results_path, f"latest_weights_aborted_training_{utils.get_date_time()}.pth")) |
| 839 | + torch.save( |
| 840 | + data["weights"], |
| 841 | + os.path.join( |
| 842 | + widget.results_path, |
| 843 | + f"latest_weights_aborted_training_{utils.get_date_time()}.pth", |
| 844 | + ), |
| 845 | + ) |
808 | 846 | widget.stop_requested = False |
809 | 847 |
|
810 | 848 | # def clean_cache(self): |
|
0 commit comments