Skip to content

Commit 8780093

Browse files
authored
feat: Add total logging of generations in training (#172)
Signed-off-by: Sahil Jain <[email protected]>
1 parent ce2d121 commit 8780093

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

nemo_reinforcer/algorithms/grpo.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def generate_responses(
290290
tokenizer,
291291
input_lengths: torch.Tensor,
292292
include_logprobs: bool = True,
293-
) -> Tuple[List[torch.Tensor], List[str], torch.Tensor]:
293+
) -> Tuple[BatchedDataDict[DatumSpec], List[List[int]], Dict[str, float | int]]:
294294
"""Generate responses from policy."""
295295
# Generate responses
296296
generation_outputs = policy_generation.generate(generation_input_data)
@@ -452,6 +452,7 @@ def grpo_train(
452452
logger.log_metrics(validation_timings, step, prefix="timing/validation")
453453

454454
# Run grpo training (single-turn)
455+
batch: BatchedDataDict[DatumSpec]
455456
for batch in dataloader:
456457
print(
457458
f"\n{'=' * 25} Step {step + 1}/{min(len(dataloader), master_config['grpo']['max_num_steps'])} {'=' * 25}"
@@ -645,6 +646,14 @@ def grpo_train(
645646
policy.offload_after_refit()
646647

647648
# Logging
649+
# Log training data
650+
log_data = {"content": flat_messages["content"]}
651+
log_data["rewards"] = rewards.tolist()
652+
log_data["generation_logprobs"] = train_data["generation_logprobs"].tolist()
653+
log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist()
654+
log_data["input_lengths"] = input_lengths.tolist()
655+
logger.log_batched_dict_as_jsonl(log_data, f"train_data_step{step}.jsonl")
656+
648657
print("\n📊 Training Results:")
649658
metrics = {
650659
"loss": train_results["loss"].numpy(),

nemo_reinforcer/utils/logger.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import time
2020
import threading
2121
import requests
22+
import json
2223
from abc import ABC, abstractmethod
2324
import logging
2425
from typing import List, Any, Dict, Optional, TypedDict, Union
@@ -27,8 +28,10 @@
2728
from rich.panel import Panel
2829
from rich.box import ROUNDED
2930
from rich.logging import RichHandler
31+
import torch
3032

3133
from nemo_reinforcer.data.interfaces import LLMMessageLogType
34+
from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
3235
from torch.utils.tensorboard import SummaryWriter
3336

3437
import ray
@@ -568,6 +571,32 @@ def log_hyperparams(self, params: Dict[str, Any]) -> None:
568571
for logger in self.loggers:
569572
logger.log_hyperparams(params)
570573

574+
def log_batched_dict_as_jsonl(
575+
self, to_log: BatchedDataDict | Dict[str, Any], filename: str
576+
) -> None:
577+
"""Log a list of dictionaries to a JSONL file.
578+
579+
Args:
580+
to_log: BatchedDataDict to log
581+
filename: Filename to log to (within the log directory)
582+
"""
583+
if not isinstance(to_log, BatchedDataDict):
584+
to_log = BatchedDataDict(to_log)
585+
586+
# Create full path within log directory
587+
filepath = os.path.join(self.base_log_dir, filename)
588+
os.makedirs(os.path.dirname(filepath), exist_ok=True)
589+
590+
# Write to JSONL file
591+
with open(filepath, "w") as f:
592+
for i, sample in enumerate(to_log.make_microbatch_iterator(1)):
593+
for key, value in sample.items():
594+
if isinstance(value, torch.Tensor):
595+
sample[key] = value.tolist()
596+
f.write(json.dumps({**sample, "idx": i}) + "\n")
597+
598+
print(f"Logged data to {filepath}")
599+
571600
def __del__(self):
572601
"""Clean up resources when the logger is destroyed."""
573602
if self.gpu_monitor:

0 commit comments

Comments (0)