|
19 | 19 | import time |
20 | 20 | import threading |
21 | 21 | import requests |
| 22 | +import json |
22 | 23 | from abc import ABC, abstractmethod |
23 | 24 | import logging |
24 | 25 | from typing import List, Any, Dict, Optional, TypedDict, Union |
|
27 | 28 | from rich.panel import Panel |
28 | 29 | from rich.box import ROUNDED |
29 | 30 | from rich.logging import RichHandler |
| 31 | +import torch |
30 | 32 |
|
31 | 33 | from nemo_reinforcer.data.interfaces import LLMMessageLogType |
| 34 | +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict |
32 | 35 | from torch.utils.tensorboard import SummaryWriter |
33 | 36 |
|
34 | 37 | import ray |
@@ -568,6 +571,32 @@ def log_hyperparams(self, params: Dict[str, Any]) -> None: |
568 | 571 | for logger in self.loggers: |
569 | 572 | logger.log_hyperparams(params) |
570 | 573 |
|
| 574 | + def log_batched_dict_as_jsonl( |
| 575 | + self, to_log: BatchedDataDict | Dict[str, Any], filename: str |
| 576 | + ) -> None: |
| 577 | + """Log a BatchedDataDict to a JSONL file, one sample per line. |
| 578 | + |
| 579 | + Args: |
| 580 | + to_log: BatchedDataDict to log (a plain dict is wrapped automatically) |
| 581 | + filename: Filename to log to (within the log directory) |
| 582 | + """ |
| 583 | + if not isinstance(to_log, BatchedDataDict): |
| 584 | + to_log = BatchedDataDict(to_log) |
| 585 | + |
| 586 | + # Create full path within log directory |
| 587 | + filepath = os.path.join(self.base_log_dir, filename) |
| 588 | + os.makedirs(os.path.dirname(filepath), exist_ok=True) |
| 589 | + |
| 590 | + # Write to JSONL file |
| 591 | + with open(filepath, "w") as f: |
| 592 | + for i, sample in enumerate(to_log.make_microbatch_iterator(1)): |
| 593 | + for key, value in sample.items(): |
| 594 | + if isinstance(value, torch.Tensor): |
| 595 | + sample[key] = value.tolist() |
| 596 | + f.write(json.dumps({**sample, "idx": i}) + "\n") |
| 597 | + |
| 598 | + print(f"Logged data to {filepath}") |
| 599 | + |
571 | 600 | def __del__(self): |
572 | 601 | """Clean up resources when the logger is destroyed.""" |
573 | 602 | if self.gpu_monitor: |
|
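A minimal standalone sketch of what the new helper does, with a plain dict standing in for `BatchedDataDict`, and field names and output path made up for illustration (not part of this change):

```python
import json

import torch

# A toy "batch": equal-length columns, one entry per sample.
batch = {
    "prompt": ["2+2=", "3+3="],
    "reward": torch.tensor([1.0, 0.0]),
}

with open("samples.jsonl", "w") as f:
    for i in range(len(batch["prompt"])):
        # Slice out one sample, mirroring make_microbatch_iterator(1).
        sample = {k: v[i : i + 1] for k, v in batch.items()}
        # Convert tensors to plain lists so json.dumps can serialize them.
        sample = {
            k: v.tolist() if isinstance(v, torch.Tensor) else v
            for k, v in sample.items()
        }
        f.write(json.dumps({**sample, "idx": i}) + "\n")
```

Each sample becomes one JSON line, e.g. `{"prompt": ["2+2="], "reward": [1.0], "idx": 0}`. Note that the size-1 slices keep the batch dimension, so scalar fields serialize as single-element lists.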