@@ -226,13 +226,13 @@ def on_train_epoch_start(self, epoch):
         self.trainer.call_hook("on_epoch_start")
         self.trainer.call_hook("on_train_epoch_start")
 
-    def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx):
         # hook
         self.trainer.call_hook('on_batch_end')
-        self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx)
+        self.trainer.call_hook('on_train_batch_end', batch_end_outputs, batch, batch_idx, dataloader_idx)
 
         # figure out what to track for epoch end
-        self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
+        self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs)
 
         # reset batch logger internals
         self.trainer.logger_connector.on_train_batch_end()
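The batch-end hook above fans out to the registered callbacks and the LightningModule with the renamed `batch_end_outputs` forwarded as-is. A generic sketch of that kind of hook dispatch, assumed behavior rather than Lightning's actual `call_hook` implementation:

```python
# Generic hook fan-out sketch (assumption, not Lightning's real call_hook):
# look up the hook name on every target and call it with the forwarded args.
def call_hook(targets, hook_name, *args, **kwargs):
    outputs = []
    for target in targets:
        hook = getattr(target, hook_name, None)
        if callable(hook):
            outputs.append(hook(*args, **kwargs))
    return outputs


class PrintingCallback:
    def on_train_batch_end(self, batch_end_outputs, batch, batch_idx, dataloader_idx):
        print(f"batch {batch_idx}: {batch_end_outputs}")


# toy usage: one callback, one fake batch's outputs
call_hook([PrintingCallback()], "on_train_batch_end", [[{"loss": 0.1}]], None, 0, 0)
```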
@@ -244,12 +244,27 @@ def reset_train_val_dataloaders(self, model):
         if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
             self.trainer.reset_val_dataloader(model)
 
-    def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs):
+    def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):
+
         # track the outputs to reduce at the end of the epoch
-        for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
+        for opt_idx, opt_outputs in enumerate(batch_end_outputs):
+            sample_output = opt_outputs[-1]
+
+            # decide if we need to reduce at the end of the epoch automatically
+            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
+            hook_overridden = (
+                is_overridden("training_epoch_end", model=self.trainer.get_model()) or
+                is_overridden("on_train_epoch_end", model=self.trainer.get_model())
+            )
+
+            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
+            if not (hook_overridden or auto_reduce_tng_result):
+                continue
+
             # with 1 step (no tbptt) don't use a sequence at epoch end
             if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
                 opt_outputs = opt_outputs[0]
+
             epoch_output[opt_idx].append(opt_outputs)
 
     def get_optimizers_iterable(self):
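The gating that previously lived in `process_train_step_outputs` now happens in this method: a batch's outputs are kept for `epoch_output` only when the sample asks for automatic epoch-end reduction or when `training_epoch_end` / `on_train_epoch_end` is overridden on the model. A standalone sketch of that decision, with a stand-in `FakeResult` class and a plain boolean replacing Lightning's `Result` and `is_overridden`:

```python
# Standalone sketch (not the actual Lightning internals): mimics the new
# gating in track_epoch_end_reduce_metrics. Outputs are tracked for the epoch
# only when auto-reduction is requested or an epoch-end hook is overridden.
class FakeResult(dict):
    """Stand-in for pytorch_lightning.core.step_result.Result."""
    should_reduce_on_epoch_end = True


def keep_for_epoch_end(batch_end_outputs, hook_overridden):
    epoch_output = [[] for _ in batch_end_outputs]
    for opt_idx, opt_outputs in enumerate(batch_end_outputs):
        sample_output = opt_outputs[-1]
        auto_reduce = (
            isinstance(sample_output, FakeResult)
            and sample_output.should_reduce_on_epoch_end
        )
        if not (hook_overridden or auto_reduce):
            continue  # nothing downstream needs these outputs, so drop them
        # with 1 step (no tbptt) don't wrap the output in a sequence
        if len(opt_outputs) == 1 and not isinstance(opt_outputs[0], FakeResult):
            opt_outputs = opt_outputs[0]
        epoch_output[opt_idx].append(opt_outputs)
    return epoch_output


# plain dicts and no overridden hook -> nothing is tracked
print(keep_for_epoch_end([[{"loss": 0.1}]], hook_overridden=False))  # [[]]
# same outputs, but training_epoch_end is overridden -> tracked
print(keep_for_epoch_end([[{"loss": 0.1}]], hook_overridden=True))   # [[{'loss': 0.1}]]
```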
@@ -537,17 +552,14 @@ def run_training_epoch(self):
             if batch_output.signal == -1:
                 break
 
-            # only track outputs when user implements training_epoch_end
-            # otherwise we will build up unnecessary memory
-            epoch_end_outputs = self.process_train_step_outputs(
+            batch_end_outputs = self.process_train_step_outputs(
                 batch_output.training_step_output_for_epoch_end,
                 self.early_stopping_accumulator,
                 self.checkpoint_accumulator,
             )
-
             # hook
             # TODO: add outputs to batches
-            self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
+            self.on_train_batch_end(epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx)
 
             # -----------------------------------------
             # SAVE METRICS TO LOGGERS
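With the memory-saving check moved out of `run_training_epoch`, the per-batch flow becomes: build `batch_end_outputs` from the raw `training_step` results, fire the batch-end hook, and let `track_epoch_end_reduce_metrics` decide what survives to the epoch end. A heavily simplified, single-optimizer sketch of that ordering; the names and shapes are assumptions, not the real loop:

```python
def run_training_epoch_sketch(batches, training_step, on_train_batch_end):
    """Illustrative only: mirrors the reordering in this commit, not the real loop."""
    epoch_output = [[]]  # one inner list per optimizer (single optimizer here)
    for batch_idx, batch in enumerate(batches):
        step_out = training_step(batch, batch_idx)
        # process_train_step_outputs now always returns the per-batch outputs,
        # shaped as [optimizer][tbptt_step]
        batch_end_outputs = [[step_out]]
        # hook; track_epoch_end_reduce_metrics (called inside) decides what to keep
        on_train_batch_end(epoch_output, batch_end_outputs, batch, batch_idx)
    return epoch_output


# toy usage that keeps every batch's output
out = run_training_epoch_sketch(
    batches=[1.0, 2.0],
    training_step=lambda batch, idx: {"loss": batch * 0.1},
    on_train_batch_end=lambda epoch_out, b_outs, batch, idx: epoch_out[0].append(b_outs[0][0]),
)
print(out)  # [[{'loss': 0.1}, {'loss': 0.2}]]
```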
@@ -901,7 +913,7 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
         # the training step outputs a list per optimizer. The list contains the outputs at each time step
         # when no TBPTT is used, then the list has 1 item per batch
         # when TBPTT IS used, then the list has n items (1 per time step)
-        epoch_end_outputs = []
+        batch_end_outputs = []
         for optimizer_idx_outputs in all_train_step_outputs:
             # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer
             if len(optimizer_idx_outputs) == 0:
@@ -916,14 +928,9 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
             if isinstance(sample_output, dict) and "checkpoint_on" in sample_output:
                 checkpoint_accumulator.accumulate(sample_output["checkpoint_on"])
 
-            # decide if we need to reduce at the end of the epoch automatically
-            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
-
-            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
-            if is_overridden("training_epoch_end", model=self.trainer.get_model()) or auto_reduce_tng_result:
-                epoch_end_outputs.append(optimizer_idx_outputs)
+            batch_end_outputs.append(optimizer_idx_outputs)
 
-        return epoch_end_outputs
+        return batch_end_outputs
 
     def prepare_optimizers(self):
         # in manual optimization we loop over all optimizers at once
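For reference, the nested shape that `process_train_step_outputs` iterates over, per the comments in the hunk above: an outer list with one entry per optimizer, and an inner list with one entry per TBPTT time step (a single entry when TBPTT is off). A small illustration with plain dicts standing in for the real step outputs; the values are made up:

```python
# all_train_step_outputs[optimizer_idx][tbptt_step_idx] -> one training_step output
all_train_step_outputs = [
    # optimizer 0: no TBPTT, so a single item for the batch
    [{"loss": 0.31}],
    # optimizer 1: TBPTT with 3 truncated time steps
    [{"loss": 0.40}, {"loss": 0.38}, {"loss": 0.35}],
]

for optimizer_idx, optimizer_idx_outputs in enumerate(all_train_step_outputs):
    if len(optimizer_idx_outputs) == 0:
        continue
    sample_output = optimizer_idx_outputs[-1]  # representative sample (last time step)
    print(optimizer_idx, len(optimizer_idx_outputs), sample_output["loss"])
```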