
Commit 60633f4

feat: add throughput/prompt_length/total_num_tokens metrics (#781)
Signed-off-by: Zhiyu Li <zhiyul@nvidia.com>
Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>

1 parent: bec9cde

File tree: 1 file changed (+23 −1 lines)


nemo_rl/algorithms/grpo.py

Lines changed: 23 additions & 1 deletion
@@ -765,10 +765,19 @@ def grpo_train(
             "loss": train_results["loss"].numpy(),
             "reward": rewards.numpy(),
             "grad_norm": train_results["grad_norm"].numpy(),
+            "mean_prompt_length": repeated_batch["length"].numpy(),
+            "total_num_tokens": input_lengths.numpy(),
         }
         metrics.update(train_results["all_mb_metrics"])
         for k, v in metrics.items():
-            if k in {"lr", "wd", "reward", "global_valid_seqs", "global_valid_toks"}:
+            if k in {
+                "lr",
+                "wd",
+                "reward",
+                "global_valid_seqs",
+                "global_valid_toks",
+                "mean_prompt_length",
+            }:
                 metrics[k] = np.mean(v).item()
             else:
                 metrics[k] = np.sum(v).item()
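
For context, the hunk above follows the existing aggregation rule: keys in the allow-list are averaged across micro-batches, everything else is summed, so the new mean_prompt_length stays a mean while total_num_tokens accumulates. A minimal standalone sketch of that rule, with illustrative values rather than data from a real run:

import numpy as np

# Keys that are averaged across micro-batches; all other metrics are summed.
MEAN_KEYS = {
    "lr",
    "wd",
    "reward",
    "global_valid_seqs",
    "global_valid_toks",
    "mean_prompt_length",
}

# Illustrative per-micro-batch values (placeholders, not real training data).
metrics = {
    "reward": np.array([0.5, 1.0, 0.75]),
    "mean_prompt_length": np.array([212.0, 198.0, 230.0]),
    "total_num_tokens": np.array([4096, 3900, 4010]),
}

for k, v in metrics.items():
    metrics[k] = np.mean(v).item() if k in MEAN_KEYS else np.sum(v).item()

# reward and mean_prompt_length end up as means; total_num_tokens is a sum.
print(metrics)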
@@ -801,6 +810,19 @@ def grpo_train(
         print("\n⏱️ Timing:")
         # Display total time first, separately
         total_time = timing_metrics.get("total_step_time", 0)
+
+        total_num_gpus = (
+            master_config["cluster"]["num_nodes"]
+            * master_config["cluster"]["gpus_per_node"]
+        )
+        metrics.update(
+            {
+                "tokens_per_sec_per_gpu": metrics["total_num_tokens"]
+                / total_time
+                / total_num_gpus
+            }
+        )
+
         print(f" • Total step time: {total_time:.2f}s")

         # Display all other timing metrics
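
The second hunk derives a per-GPU throughput from the summed token count and the step wall-clock time. A minimal sketch of that arithmetic, reusing the cluster config keys from the diff; the token count, step time, and cluster size below are hypothetical placeholders:

# Hypothetical cluster config mirroring master_config["cluster"] in the diff.
master_config = {"cluster": {"num_nodes": 2, "gpus_per_node": 8}}

total_num_tokens = 1_048_576  # placeholder: tokens processed in this step
total_time = 42.0             # placeholder: total_step_time in seconds

total_num_gpus = (
    master_config["cluster"]["num_nodes"]
    * master_config["cluster"]["gpus_per_node"]
)
tokens_per_sec_per_gpu = total_num_tokens / total_time / total_num_gpus
print(f"tokens_per_sec_per_gpu: {tokens_per_sec_per_gpu:.1f}")  # prints ≈ 1560.4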
