Commit 96656c3

fix: Fix the logger error in non-colocated sync-grpo code path (#1355)
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
1 parent 8f6e00e commit 96656c3

2 files changed: 9 additions, 3 deletions

nemo_rl/algorithms/grpo.py

Lines changed: 0 additions & 3 deletions
@@ -1798,9 +1798,6 @@ def async_grpo_train(
 train_results, metrics, timing_metrics, master_config
 )
 
-if "per_worker_token_counts" in metrics:
-    del metrics["per_worker_token_counts"]
-
 logger.log_metrics(performance_metrics, step + 1, prefix="performance")
 logger.log_metrics(metrics, step + 1, prefix="train")
 logger.log_metrics(timing_metrics, step + 1, prefix="timing/train")

nemo_rl/algorithms/utils.py

Lines changed: 9 additions & 0 deletions
@@ -553,6 +553,15 @@ def visualize_per_worker_load(per_worker_token_counts: dict[int, int]) -> float:
 total_tflops / theoretical_tflops
 )
 
+# =====================================================
+# Clean up metrics
+# =====================================================
+
+# Clean up metrics to avoid wandb logging errors
+# Dict structures cannot be logged to wandb
+if "per_worker_token_counts" in metrics:
+    del metrics["per_worker_token_counts"]
+
 # =====================================================
 # Logging
 # =====================================================
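The cleanup above can be illustrated in isolation. The following is a minimal sketch, not code from the repository: `strip_dict_metrics` and the sample `metrics` values are hypothetical names and data chosen for illustration. It shows the same idea as the diff (dropping the dict-valued `per_worker_token_counts` entry before the metrics reach the logger), along with a more general variant that filters out any dict-valued entry.

```python
from typing import Any


def strip_dict_metrics(metrics: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of `metrics` without dict-valued entries.

    Hypothetical helper: the commit itself only deletes the single
    "per_worker_token_counts" key in place.
    """
    return {key: value for key, value in metrics.items() if not isinstance(value, dict)}


# Illustrative data only; the real metrics dict holds more entries.
metrics = {
    "loss": 0.25,
    "per_worker_token_counts": {0: 1024, 1: 987},
}

# General variant: build a filtered copy and leave the original untouched.
cleaned = strip_dict_metrics(metrics)

# Equivalent to the `if ... in metrics: del ...` pattern in the diff:
# pop with a default removes the key if present and is a no-op otherwise.
in_place = dict(metrics)
in_place.pop("per_worker_token_counts", None)
```

Judging from the two hunks, moving the cleanup out of `async_grpo_train` and into the shared utility means every caller of that utility gets the key removed before logging, rather than only the async path. The filtered-copy variant trades a small allocation for not mutating the caller's dict.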
