Skip to content

Commit fa5e448

Browse files
kashif and yzhangcs authored
add back grad_norm metric (#25)
--------- Co-authored-by: Yu Zhang <yzhang.cs@outlook.com>
1 parent 7f8789e commit fa5e448

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

3rdparty/torchtitan

Submodule torchtitan updated 106 files

flame/train.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,27 @@ def main(job_config: JobConfig):
729729
train_state.global_max_losses.append(global_max_loss)
730730

731731
# Log using the metric processor
732-
metric_logger.log(train_state.step, global_avg_loss, global_max_loss)
732+
last_lr = lr_schedulers.schedulers[0].get_last_lr()[0]
733+
eta = (
734+
train_state.elapsed
735+
* (job_config.training.steps - train_state.step)
736+
/ train_state.step
737+
)
738+
metric_logger.log(
739+
train_state.step,
740+
global_avg_loss,
741+
global_max_loss,
742+
extra_metrics={
743+
"optimizer/lr": last_lr,
744+
"optimizer/grad_norm": grad_norm.item(),
745+
"optimizer/skipped_step": train_state.skipped_step,
746+
},
747+
)
748+
749+
logger.info(
750+
f"{color.blue}lr: {last_lr:.4e} gnorm: {grad_norm:5.2f} "
751+
f"{color.magenta}[{str(train_state.elapsed).split('.')[0]:>8}<{str(eta).split('.')[0]:>8}]{color.reset}"
752+
)
733753

734754
checkpoint.save(
735755
train_state.step, force=(train_state.step == job_config.training.steps)

0 commit comments

Comments (0)