@@ -145,7 +145,8 @@ def training(self):
145 145         self.optimizer.zero_grad()
146 146         # forward + backward + optimize
147 147         outputs = self.net(inputs)
148     -       loss = self.get_loss_value(data, inputs, outputs, labels)
    148 +
    149 +       loss = self.get_loss_value(data, outputs, labels)
149 150         loss.backward()
150 151         self.optimizer.step()
151 152         self.scheduler.step()
@@ -175,7 +176,7 @@ def validation(self):
175 176         self.optimizer.zero_grad()
176 177         # forward + backward + optimize
177 178         outputs = self.net(inputs)
178     -       loss = self.get_loss_value(data, inputs, outputs, labels)
    179 +       loss = self.get_loss_value(data, outputs, labels)
179 180
180 181         # statistics
181 182         sample_num += labels.size(0)
@@ -243,10 +244,11 @@ def train_valid(self):
243 244         logging.info("{0:} training start".format(str(datetime.now())[:-7]))
244 245         self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir'])
245 246         for it in range(iter_start, iter_max, iter_valid):
    247 +           lr_value = self.optimizer.param_groups[0]['lr']
246 248             train_scalars = self.training()
247 249             valid_scalars = self.validation()
248 250             glob_it = it + iter_valid
249     -           self.write_scalars(train_scalars, valid_scalars, glob_it)
    251 +           self.write_scalars(train_scalars, valid_scalars, lr_value, glob_it)
250 252
251 253         if (valid_scalars[metrics] > self.max_val_score):
252 254             self.max_val_score = valid_scalars[metrics]
0 commit comments