diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py
index 81c82647..8f2ebedf 100644
--- a/finetrainers/trainer.py
+++ b/finetrainers/trainer.py
@@ -673,6 +673,8 @@ def train(self) -> None:
 
             self.transformer.train()
             models_to_accumulate = [self.transformer]
+            epoch_loss = 0.0
+            num_loss_updates = 0
 
             for step, batch in enumerate(self.dataloader):
                 logger.debug(f"Starting step {step + 1}")
@@ -843,7 +845,10 @@ def train(self) -> None:
                     if should_run_validation:
                         self.validate(global_step)
 
-                logs["loss"] = loss.detach().item()
+                loss_item = loss.detach().item()
+                epoch_loss += loss_item
+                num_loss_updates += 1
+                logs["step_loss"] = loss_item
                 logs["lr"] = self.lr_scheduler.get_last_lr()[0]
                 progress_bar.set_postfix(logs)
                 accelerator.log(logs, step=global_step)
@@ -851,6 +856,9 @@
                 if global_step >= self.state.train_steps:
                     break
 
+            if num_loss_updates > 0:
+                epoch_loss /= num_loss_updates
+            accelerator.log({"epoch_loss": epoch_loss}, step=global_step)
             memory_statistics = get_memory_statistics()
             logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")
 