Skip to content

Commit 140cd97

Browse files
feat: log generation ISL/OSL histogram to wandb (#1594)
Signed-off-by: Youngeun Kwon <[email protected]>
1 parent ed9cab7 commit 140cd97

File tree

4 files changed

+90
-7
lines changed

4 files changed

+90
-7
lines changed

nemo_rl/algorithms/grpo.py

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1591,6 +1591,20 @@ def grpo_train(
15911591
logger,
15921592
)
15931593

1594+
# Plot ISL/OSL/ISL+OSL histograms to wandb
1595+
if (
1596+
master_config["policy"]["generation"]
1597+
.get("vllm_cfg", {})
1598+
.get("async_engine", False)
1599+
):
1600+
for metric_name in metrics.keys():
1601+
if metric_name.startswith("histogram/"):
1602+
logger.log_histogram(
1603+
metrics[metric_name],
1604+
total_steps + 1,
1605+
f"generation_metrics/{metric_name}",
1606+
)
1607+
15941608
print("\n📊 Training Results:")
15951609

15961610
print(f" • Loss: {metrics['loss']:.4f}")
@@ -2528,6 +2542,20 @@ def async_grpo_train(
25282542
logger,
25292543
)
25302544

2545+
# Plot ISL/OSL/ISL+OSL histograms to wandb
2546+
if (
2547+
master_config["policy"]["generation"]
2548+
.get("vllm_cfg", {})
2549+
.get("async_engine", False)
2550+
):
2551+
for metric_name in metrics.keys():
2552+
if metric_name.startswith("histogram/"):
2553+
logger.log_histogram(
2554+
metrics[metric_name],
2555+
step + 1,
2556+
f"generation_metrics/{metric_name}",
2557+
)
2558+
25312559
print("\n📊 Training Results:")
25322560
print(f" • Loss: {metrics['loss']:.4f}")
25332561
print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}")

nemo_rl/algorithms/utils.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -748,24 +748,24 @@ def visualize_per_worker_timeline(
748748

749749

750750
def log_generation_metrics_to_wandb(
751-
vllm_logger_metrics: dict[str, dict[int, list[Any]]],
751+
generation_logger_metrics: dict[str, dict[int, list[Any]]],
752752
step: int,
753753
timeline_interval: float,
754754
logger: Logger,
755755
) -> None:
756-
"""Log vLLM metrics to wandb.
756+
"""Log generation metrics to wandb.
757757
758758
Args:
759-
vllm_logger_metrics: Dictionary of vLLM logger metrics
759+
generation_logger_metrics: Dictionary of generation logger metrics
760760
step: Global step value
761761
timeline_interval: Interval between timeline points (in seconds)
762762
logger: Logger instance
763763
"""
764-
for vllm_metric in vllm_logger_metrics.keys():
764+
for generation_metric in generation_logger_metrics.keys():
765765
logger.log_plot_per_worker_timeline_metrics(
766-
vllm_logger_metrics[vllm_metric],
766+
generation_logger_metrics[generation_metric],
767767
step=step,
768-
prefix="vllm_metrics",
769-
name=vllm_metric,
768+
prefix="generation_metrics",
769+
name=generation_metric,
770770
timeline_interval=timeline_interval,
771771
)

nemo_rl/experience/rollouts.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -654,6 +654,8 @@ async def run_sample_multi_turn_rollout(
654654

655655
# Track per-turn metrics
656656
turn_gen_tokens = []
657+
turn_input_tokens = []
658+
turn_total_tokens = []
657659
# Track per-turn per-worker token accounting if available
658660
per_worker_token_counts = {} # worker_idx -> token_count
659661

@@ -685,6 +687,8 @@ async def run_sample_multi_turn_rollout(
685687
assistant_token_count += gen_token_count
686688
token_count += gen_token_count
687689
turn_gen_tokens.append(gen_token_count)
690+
turn_input_tokens.append(int(input_lengths))
691+
turn_total_tokens.append(int(input_lengths) + gen_token_count)
688692
# Per-worker load accounting
689693
if "gen_leader_worker_idx" in gen_metrics:
690694
worker_idx = int(gen_metrics["gen_leader_worker_idx"])
@@ -770,6 +774,8 @@ async def run_sample_multi_turn_rollout(
770774
"max_turns_reached": max_turns_reached,
771775
"total_reward": total_reward,
772776
"turn_gen_tokens": turn_gen_tokens,
777+
"turn_input_tokens": turn_input_tokens,
778+
"turn_total_tokens": turn_total_tokens,
773779
# Pass-through per-worker per-turn accounting for aggregation at batch level
774780
"per_worker_token_counts": per_worker_token_counts,
775781
}
@@ -930,6 +936,17 @@ async def run_single_sample_with_error_handling(i, sample_state):
930936
per_worker_token_counts[k] = per_worker_token_counts.get(k, 0) + v
931937
rollout_metrics["per_worker_token_counts"] = per_worker_token_counts
932938

939+
# Collect ISL, OSL, and ISL+OSL metrics for all samples
940+
rollout_metrics["histogram/gen_tokens_length"] = [
941+
t for m in all_sample_metrics for t in m["turn_gen_tokens"]
942+
]
943+
rollout_metrics["histogram/input_tokens_length"] = [
944+
t for m in all_sample_metrics for t in m["turn_input_tokens"]
945+
]
946+
rollout_metrics["histogram/total_tokens_length"] = [
947+
t for m in all_sample_metrics for t in m["turn_total_tokens"]
948+
]
949+
933950
return final_batch, rollout_metrics
934951

935952
return asyncio.run(_async_rollout_implementation())

nemo_rl/utils/logger.py

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -108,6 +108,11 @@ def log_hyperparams(self, params: Mapping[str, Any]) -> None:
108108
"""Log dictionary of hyperparameters."""
109109
pass
110110

111+
@abstractmethod
112+
def log_histogram(self, histogram: list[Any], step: int, name: str) -> None:
113+
"""Log histogram metrics."""
114+
pass
115+
111116

112117
class TensorboardLogger(LoggerInterface):
113118
"""Tensorboard logger backend."""
@@ -153,6 +158,10 @@ def log_metrics(
153158
print(f"Warning: Failed to log metric '{name}' to TensorBoard: {e}")
154159
continue
155160

161+
def log_histogram(self, histogram: list[Any], step: int, name: str) -> None:
162+
"""Log histogram metrics to Tensorboard."""
163+
return
164+
156165
def log_hyperparams(self, params: Mapping[str, Any]) -> None:
157166
"""Log hyperparameters to Tensorboard.
158167
@@ -350,6 +359,16 @@ def log_plot(self, figure: plt.Figure, step: int, name: str) -> None:
350359
"""
351360
self.run.log({name: figure}, step=step)
352361

362+
def log_histogram(self, histogram: list[Any], step: int, name: str) -> None:
363+
"""Log histogram metrics to wandb.
364+
365+
Args:
366+
histogram: List of histogram values
367+
step: Global step value
368+
name: Name of the metric
369+
"""
370+
self.run.log({name: wandb.Histogram(histogram)}, step=step)
371+
353372

354373
class SwanlabLogger(LoggerInterface):
355374
"""SwanLab logger backend."""
@@ -419,6 +438,10 @@ def log_plot(self, figure: plt.Figure, step: int, name: str) -> None:
419438
"""
420439
self.run.log({name: figure}, step=step)
421440

441+
def log_histogram(self, histogram: list[Any], step: int, name: str) -> None:
442+
"""Log histogram metrics to swanlab."""
443+
return
444+
422445

423446
class GpuMetricSnapshot(TypedDict):
424447
step: int
@@ -793,6 +816,10 @@ def log_plot(self, figure: plt.Figure, step: int, name: str) -> None:
793816
figure.savefig(tmp_file.name, format="png", bbox_inches="tight")
794817
mlflow.log_artifact(tmp_file.name, f"plots/{name}")
795818

819+
def log_histogram(self, histogram: list[Any], step: int, name: str) -> None:
820+
"""Log histogram metrics to MLflow."""
821+
return
822+
796823
def __del__(self) -> None:
797824
"""Clean up resources when the logger is destroyed."""
798825
try:
@@ -1017,6 +1044,17 @@ def log_plot_per_worker_timeline_metrics(
10171044
logger.log_plot(fig, step, f"{prefix}/average_{name}")
10181045
plt.close(fig)
10191046

1047+
def log_histogram(self, histogram: list[Any], step: int, name: str) -> None:
1048+
"""Log histogram metrics to all backends if available.
1049+
1050+
Args:
1051+
histogram: List of histogram values
1052+
step: Global step value
1053+
name: Name of the metric
1054+
"""
1055+
for logger in self.loggers:
1056+
logger.log_histogram(histogram, step, name)
1057+
10201058
def log_plot_token_mult_prob_error(
10211059
self, data: dict[str, Any], step: int, name: str
10221060
) -> None:

0 commit comments

Comments
 (0)