 import csv
 import logging
 import time
-import torch
-from torchvision import transforms
 import numpy as np
+import torch
 import torch.nn as nn
 from datetime import datetime
+from torch.optim import lr_scheduler
+from torchvision import transforms
 from tensorboardX import SummaryWriter
 from pymic.io.nifty_dataset import ClassificationDataset
 from pymic.loss.loss_dict_cls import PyMICClsLossDict
@@ -149,7 +150,9 @@ def training(self):
             loss = self.get_loss_value(data, outputs, labels)
             loss.backward()
             self.optimizer.step()
-            self.scheduler.step()
+            if(self.scheduler is not None and \
+                not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
+                self.scheduler.step()
 
             # statistics
             sample_num += labels.size(0)
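With this change, only iteration-based schedulers are stepped after every optimizer update, while ReduceLROnPlateau is deferred to validation(). A minimal sketch of the per-iteration pattern, assuming a StepLR scheduler with illustrative hyperparameters (not taken from this repository):

import torch
from torch.optim import lr_scheduler

net = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
scheduler = lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.5)  # hypothetical values

for it in range(10):
    optimizer.zero_grad()
    loss = net(torch.randn(4, 10)).sum()  # dummy forward pass and loss
    loss.backward()
    optimizer.step()
    # step per-iteration schedulers here; ReduceLROnPlateau is skipped
    if scheduler is not None and not isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
        scheduler.step()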
@@ -185,7 +188,9 @@ def validation(self):
 
         avg_loss = running_loss / sample_num
         avg_score = running_score.double() / sample_num
-        metrics = self.config['training'].get("evaluation_metric", "accuracy")
+        metrics = self.config['training'].get("evaluation_metric", "accuracy")
+        if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
+            self.scheduler.step(avg_score)
         valid_scalers = {'loss': avg_loss, metrics: avg_score}
         return valid_scalers
 
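ReduceLROnPlateau, by contrast, is stepped once per validation pass with the monitored metric. A minimal sketch, assuming the scheduler is constructed elsewhere with mode='max' because avg_score is a higher-is-better metric (factor and patience values are illustrative, not from this repository):

import torch
from torch.optim import lr_scheduler

net = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

for epoch in range(3):
    avg_score = 0.8  # placeholder for the validation metric computed above
    if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
        scheduler.step(avg_score)  # reduces the LR when the score stops improving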
@@ -222,7 +227,15 @@ def train_valid(self):
         iter_max = self.config['training']['iter_max']
         iter_valid = self.config['training']['iter_valid']
         iter_save = self.config['training']['iter_save']
+        early_stop_it = self.config['training'].get('early_stop_patience', None)
         metrics = self.config['training'].get("evaluation_metric", "accuracy")
+        if(iter_save is None):
+            iter_save_list = [iter_max]
+        elif(isinstance(iter_save, (tuple, list))):
+            iter_save_list = iter_save
+        else:
+            iter_save_list = range(0, iter_max + 1, iter_save)
+
         self.max_val_score = 0.0
         self.max_val_it = 0
         self.best_model_wts = None
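The new lines normalize iter_save into a list of checkpoint iterations (None keeps only the final model, a list/tuple is used as given, a scalar means "every iter_save iterations") and read an optional early_stop_patience. A minimal sketch of both behaviours, with illustrative values only:

iter_max = 6000
for iter_save in [None, 2000, [1000, 5000, 6000]]:
    if iter_save is None:
        iter_save_list = [iter_max]                          # only keep the final checkpoint
    elif isinstance(iter_save, (tuple, list)):
        iter_save_list = iter_save                           # checkpoint at the listed iterations
    else:
        iter_save_list = range(0, iter_max + 1, iter_save)   # checkpoint every iter_save iterations
    print(list(iter_save_list))
# -> [6000], [0, 2000, 4000, 6000], [1000, 5000, 6000]

# Patience-based early stopping: stop once the best validation iteration lies
# more than early_stop_it iterations behind the current one (values illustrative).
early_stop_it, max_val_it, glob_it = 2000, 3000, 5500
stop_now = early_stop_it is not None and glob_it - max_val_it > early_stop_it
print(stop_now)  # -> True: training would stop after saving a checkpoint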
@@ -243,29 +256,35 @@ def train_valid(self):
 
         logging.info("{0:} training start".format(str(datetime.now())[:-7]))
         self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir'])
+        self.glob_it = iter_start
         for it in range(iter_start, iter_max, iter_valid):
             lr_value = self.optimizer.param_groups[0]['lr']
             train_scalars = self.training()
             valid_scalars = self.validation()
-            glob_it = it + iter_valid
-            self.write_scalars(train_scalars, valid_scalars, lr_value, glob_it)
+            self.glob_it = it + iter_valid
+            self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it)
 
             if(valid_scalars[metrics] > self.max_val_score):
                 self.max_val_score = valid_scalars[metrics]
-                self.max_val_it = glob_it
+                self.max_val_it = self.glob_it
                 self.best_model_wts = copy.deepcopy(self.net.state_dict())
 
-            if(glob_it % iter_save == 0):
-                save_dict = {'iteration': glob_it,
+            stop_now = True if(early_stop_it is not None and \
+                self.glob_it - self.max_val_it > early_stop_it) else False
+
+            if((self.glob_it in iter_save_list) or stop_now):
+                save_dict = {'iteration': self.glob_it,
                              'valid_pred': valid_scalars[metrics],
                              'model_state_dict': self.net.state_dict(),
                              'optimizer_state_dict': self.optimizer.state_dict()}
-                save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, glob_it)
+                save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it)
                 torch.save(save_dict, save_name)
                 txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt')
-                txt_file.write(str(glob_it))
+                txt_file.write(str(self.glob_it))
                 txt_file.close()
-
+            if(stop_now):
+                logging.info("The training is early stopped")
+                break
         # save the best performing checkpoint
         save_dict = {'iteration': self.max_val_it,
                      'valid_pred': self.max_val_score,