Skip to content

Commit c541525

Browse files
committed
enable ReduceLROnPlateau
1 parent ac1c8fc commit c541525

File tree

2 files changed

+32
-13
lines changed

2 files changed

+32
-13
lines changed

pymic/net_run/agent_cls.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
import csv
66
import logging
77
import time
8-
import torch
9-
from torchvision import transforms
108
import numpy as np
9+
import torch
1110
import torch.nn as nn
1211
from datetime import datetime
12+
from torch.optim import lr_scheduler
13+
from torchvision import transforms
1314
from tensorboardX import SummaryWriter
1415
from pymic.io.nifty_dataset import ClassificationDataset
1516
from pymic.loss.loss_dict_cls import PyMICClsLossDict
@@ -149,7 +150,9 @@ def training(self):
149150
loss = self.get_loss_value(data, outputs, labels)
150151
loss.backward()
151152
self.optimizer.step()
152-
self.scheduler.step()
153+
if(self.scheduler is not None and \
154+
not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
155+
self.scheduler.step()
153156

154157
# statistics
155158
sample_num += labels.size(0)
@@ -185,7 +188,9 @@ def validation(self):
185188

186189
avg_loss = running_loss / sample_num
187190
avg_score= running_score.double() / sample_num
188-
metrics =self.config['training'].get("evaluation_metric", "accuracy")
191+
metrics = self.config['training'].get("evaluation_metric", "accuracy")
192+
if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
193+
self.scheduler.step(avg_score)
189194
valid_scalers = {'loss': avg_loss, metrics: avg_score}
190195
return valid_scalers
191196

@@ -222,7 +227,15 @@ def train_valid(self):
222227
iter_max = self.config['training']['iter_max']
223228
iter_valid = self.config['training']['iter_valid']
224229
iter_save = self.config['training']['iter_save']
230+
early_stop_it = self.config['training'].get('early_stop_patience', None)
225231
metrics = self.config['training'].get("evaluation_metric", "accuracy")
232+
if(iter_save is None):
233+
iter_save_list = [iter_max]
234+
elif(isinstance(iter_save, (tuple, list))):
235+
iter_save_list = iter_save
236+
else:
237+
iter_save_list = range(0, iter_max + 1, iter_save)
238+
226239
self.max_val_score = 0.0
227240
self.max_val_it = 0
228241
self.best_model_wts = None
@@ -243,29 +256,35 @@ def train_valid(self):
243256

244257
logging.info("{0:} training start".format(str(datetime.now())[:-7]))
245258
self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir'])
259+
self.glob_it = iter_start
246260
for it in range(iter_start, iter_max, iter_valid):
247261
lr_value = self.optimizer.param_groups[0]['lr']
248262
train_scalars = self.training()
249263
valid_scalars = self.validation()
250-
glob_it = it + iter_valid
251-
self.write_scalars(train_scalars, valid_scalars, lr_value, glob_it)
264+
self.glob_it = it + iter_valid
265+
self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it)
252266

253267
if(valid_scalars[metrics] > self.max_val_score):
254268
self.max_val_score = valid_scalars[metrics]
255-
self.max_val_it = glob_it
269+
self.max_val_it = self.glob_it
256270
self.best_model_wts = copy.deepcopy(self.net.state_dict())
257271

258-
if (glob_it % iter_save == 0):
259-
save_dict = {'iteration': glob_it,
272+
stop_now = True if(early_stop_it is not None and \
273+
self.glob_it - self.max_val_it > early_stop_it) else False
274+
275+
if ((self.glob_it in iter_save_list) or stop_now):
276+
save_dict = {'iteration': self.glob_it,
260277
'valid_pred': valid_scalars[metrics],
261278
'model_state_dict': self.net.state_dict(),
262279
'optimizer_state_dict': self.optimizer.state_dict()}
263-
save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, glob_it)
280+
save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it)
264281
torch.save(save_dict, save_name)
265282
txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt')
266-
txt_file.write(str(glob_it))
283+
txt_file.write(str(self.glob_it))
267284
txt_file.close()
268-
285+
if(stop_now):
286+
logging.info("The training is early stopped")
287+
break
269288
# save the best performing checkpoint
270289
save_dict = {'iteration': self.max_val_it,
271290
'valid_pred': self.max_val_score,

pymic/net_run/agent_seg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
import numpy as np
1111
import torch.nn as nn
1212
import torch.optim as optim
13-
from torch.optim import lr_scheduler
1413
import torch.nn.functional as F
1514
from datetime import datetime
15+
from torch.optim import lr_scheduler
1616
from tensorboardX import SummaryWriter
1717
from pymic.io.image_read_write import save_nd_array_as_image
1818
from pymic.io.nifty_dataset import NiftyDataset

0 commit comments

Comments (0)