@@ -226,13 +226,13 @@ def on_train_epoch_start(self, epoch):
         self.trainer.call_hook("on_epoch_start")
         self.trainer.call_hook("on_train_epoch_start")

-    def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx):
         # hook
         self.trainer.call_hook('on_batch_end')
-        self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx)
+        self.trainer.call_hook('on_train_batch_end', batch_end_outputs, batch, batch_idx, dataloader_idx)

         # figure out what to track for epoch end
-        self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
+        self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs)

         # reset batch logger internals
         self.trainer.logger_connector.on_train_batch_end()
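For context on the renamed argument: whatever `call_hook('on_train_batch_end', ...)` forwards is what a user-side override receives as `outputs`. A minimal sketch, assuming the LightningModule hook signature of this release; the class name and body are illustrative, not part of the diff:

    import pytorch_lightning as pl

    class VerboseModule(pl.LightningModule):
        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
            # `outputs` is the batch_end_outputs forwarded by call_hook above:
            # one entry per optimizer for the batch that just finished
            print(f"batch {batch_idx}: {len(outputs)} optimizer output list(s)")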
@@ -244,12 +244,27 @@ def reset_train_val_dataloaders(self, model):
         if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
             self.trainer.reset_val_dataloader(model)

-    def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs):
+    def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):
+
         # track the outputs to reduce at the end of the epoch
-        for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
+        for opt_idx, opt_outputs in enumerate(batch_end_outputs):
+            sample_output = opt_outputs[-1]
+
+            # decide if we need to reduce at the end of the epoch automatically
+            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
+            hook_overridden = (
+                is_overridden("training_epoch_end", model=self.trainer.get_model()) or
+                is_overridden("on_train_epoch_end", model=self.trainer.get_model())
+            )
+
+            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
+            if not (hook_overridden or auto_reduce_tng_result):
+                continue
+
             # with 1 step (no tbptt) don't use a sequence at epoch end
             if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
                 opt_outputs = opt_outputs[0]
+
             epoch_output[opt_idx].append(opt_outputs)

     def get_optimizers_iterable(self):
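The gating that moved into this method leans on `is_overridden` to detect whether the user actually consumes epoch-end outputs; when neither `training_epoch_end` nor `on_train_epoch_end` is overridden and the `Result` doesn't request auto-reduction, the outputs are skipped early so they don't accumulate in memory across the epoch. Roughly what `is_overridden` checks, as a simplified stand-alone sketch (the real helper ships with Lightning; this version is illustrative only):

    import pytorch_lightning as pl

    def is_overridden_sketch(method_name: str, model: pl.LightningModule) -> bool:
        # overridden when the subclass binds its own function object for the hook
        base_impl = getattr(pl.LightningModule, method_name)
        model_impl = getattr(type(model), method_name)
        return model_impl is not base_impl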
@@ -537,17 +552,14 @@ def run_training_epoch(self):
             if batch_output.signal == -1:
                 break

-            # only track outputs when user implements training_epoch_end
-            # otherwise we will build up unnecessary memory
-            epoch_end_outputs = self.process_train_step_outputs(
+            batch_end_outputs = self.process_train_step_outputs(
                 batch_output.training_step_output_for_epoch_end,
                 self.early_stopping_accumulator,
                 self.checkpoint_accumulator,
             )
-
             # hook
             # TODO: add outputs to batches
-            self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
+            self.on_train_batch_end(epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx)

             # -----------------------------------------
             # SAVE METRICS TO LOGGERS
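For readers tracing the data flow: `process_train_step_outputs` returns a nested structure with one inner list per optimizer, holding one entry per truncated-BPTT split (a single entry when TBPTT is off). A dummy illustration, with plain dicts standing in for the real `Result`/dict outputs:

    # two optimizers; optimizer 1 ran two TBPTT splits on this batch
    batch_end_outputs = [
        [{"loss": 0.42}],                   # optimizer 0, single time step
        [{"loss": 0.37}, {"loss": 0.35}],   # optimizer 1, two time steps
    ]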
@@ -901,7 +913,7 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
         # the training step outputs a list per optimizer. The list contains the outputs at each time step
         # when no TBPTT is used, then the list has 1 item per batch
         # when TBPTT IS used, then the list has n items (1 per time step)
-        epoch_end_outputs = []
+        batch_end_outputs = []
         for optimizer_idx_outputs in all_train_step_outputs:
             # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer
             if len(optimizer_idx_outputs) == 0:
@@ -916,14 +928,9 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
             if isinstance(sample_output, dict) and "checkpoint_on" in sample_output:
                 checkpoint_accumulator.accumulate(sample_output["checkpoint_on"])

-            # decide if we need to reduce at the end of the epoch automatically
-            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
-
-            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
-            if is_overridden("training_epoch_end", model=self.trainer.get_model()) or auto_reduce_tng_result:
-                epoch_end_outputs.append(optimizer_idx_outputs)
+            batch_end_outputs.append(optimizer_idx_outputs)

-        return epoch_end_outputs
+        return batch_end_outputs

     def prepare_optimizers(self):
         # in manual optimization we loop over all optimizers at once
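Net effect of moving the filter: `process_train_step_outputs` now hands every optimizer's outputs back unconditionally, and `track_epoch_end_reduce_metrics` decides whether to keep them. A user module like the sketch below still triggers tracking, because its `training_epoch_end` override makes `hook_overridden` true (the class and training details are illustrative, not from this diff):

    import torch
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            # the dict returned here is what ends up in batch_end_outputs
            return {"loss": self.layer(batch).sum()}

        def training_epoch_end(self, outputs):
            # receives the per-batch outputs appended to epoch_output above
            avg_loss = torch.stack([o["loss"] for o in outputs]).mean()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)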