Skip to content

Commit b9a6492

Browse files
authored
Epoch loss (#201)
* update
* update
1 parent f311f16 commit b9a6492

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

finetrainers/trainer.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -673,6 +673,8 @@ def train(self) -> None:
673 673

674 674
self.transformer.train()
675 675
models_to_accumulate = [self.transformer]
676 +
epoch_loss = 0.0
677 +
num_loss_updates = 0
676 678

677 679
for step, batch in enumerate(self.dataloader):
678 680
logger.debug(f"Starting step {step + 1}")
@@ -843,14 +845,20 @@ def train(self) -> None:
843 845
if should_run_validation:
844 846
self.validate(global_step)
845 847

846 -
logs["loss"] = loss.detach().item()
848 +
loss_item = loss.detach().item()
849 +
epoch_loss += loss_item
850 +
num_loss_updates += 1
851 +
logs["step_loss"] = loss_item
847 852
logs["lr"] = self.lr_scheduler.get_last_lr()[0]
848 853
progress_bar.set_postfix(logs)
849 854
accelerator.log(logs, step=global_step)
850 855

851 856
if global_step >= self.state.train_steps:
852 857
break
853 858

859 +
if num_loss_updates > 0:
860 +
epoch_loss /= num_loss_updates
861 +
accelerator.log({"epoch_loss": epoch_loss}, step=global_step)
854 862
memory_statistics = get_memory_statistics()
855 863
logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")
856 864

0 commit comments

Comments
 (0)