diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py index 364a80db81e..5d159aa91de 100644 --- a/megatron/rl/rl_utils.py +++ b/megatron/rl/rl_utils.py @@ -795,9 +795,13 @@ def maybe_log_training_metrics( advantages = wandb_writer.Table( columns=['advantages'], data=[[x] for x in group_stats.advantages] ) + stats = torch.cuda.memory_stats() + # 1024*1024 = 1048576 + n_split_megabytes = stats.get("inactive_split_bytes.all.current", 0)/1048576 wandb_writer.log( { **{ + 'split_megabytes': n_split_megabytes, 'group_means_hist': wandb_writer.plot.histogram( group_table, 'group_means', 'Group Means' ),