@@ -15,11 +15,13 @@
 from pathlib import Path
 from typing import Optional, Tuple, TypedDict

+import numpy as np
 import torch
 from torchdata.stateful_dataloader import StatefulDataLoader
 from nemo_reinforcer.algorithms.loss_functions import (
     NLLLoss,
 )
+from nemo_reinforcer.algorithms.utils import set_seed
 from nemo_reinforcer.data import DataConfig
 from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, rl_collate_fn
 from nemo_reinforcer.data.interfaces import TaskDataSpec
@@ -57,7 +59,7 @@ class SFTConfig(TypedDict):
     val_global_batch_size: int
     val_micro_batch_size: int
     val_at_start: bool
-
+    seed: int

 class MasterConfig(TypedDict):
     policy: PolicyConfig
@@ -91,6 +93,8 @@ def setup(
     Returns:
         Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger
     """
+    set_seed(master_config["sft"]["seed"])
+
     # Extract individual configs for easier access
     policy_config = master_config["policy"]
     data_config = master_config["data"]
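For reference, a seeding helper of this kind usually covers every RNG source the training path touches. The sketch below is an assumption about what such a helper does, not the actual nemo_reinforcer.algorithms.utils.set_seed implementation:

import random

import numpy as np
import torch


def set_seed(seed: int) -> None:
    # Seed every RNG the training loop may use so runs are repeatable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Covers all visible GPUs, not just the current device.
        torch.cuda.manual_seed_all(seed)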
@@ -176,6 +180,7 @@ def setup(
     print(f" ✓ Model initialized")

     logger = Logger(logger_config)
+    logger.log_hyperparams(master_config)

     print("\n" + "=" * 60)
     print(" " * 18 + "SETUP COMPLETE")
@@ -410,11 +415,12 @@ def sft_train(
             checkpointer.finalize_checkpoint(checkpoint_path)

         losses = train_results["loss"]
-        timing_metrics = timer.get_timing_metrics(reduction_op="sum")
-
         metrics = {
-            "loss": losses.numpy(),
+            "loss": train_results["loss"].numpy(),
         }
+        metrics.update(train_results["all_mb_metrics"])
+        metrics = {k: np.mean(v).item() for k, v in metrics.items()}
+        timing_metrics = timer.get_timing_metrics(reduction_op="sum")

         print("\n📊 Training Results:")
         print(f" • Loss: {float(metrics['loss']):.4f}")
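To illustrate the new metrics reduction: it assumes train_results["all_mb_metrics"] maps metric names to per-microbatch value lists, and every entry is averaged down to a single scalar before logging. A minimal standalone sketch with made-up keys and values:

import numpy as np
import torch

# Hypothetical per-step results; names and numbers are illustrative only.
train_results = {
    "loss": torch.tensor([0.92, 0.88]),
    "all_mb_metrics": {"lr": [1e-5, 1e-5], "grad_norm": [0.7, 0.9]},
}

metrics = {"loss": train_results["loss"].numpy()}
metrics.update(train_results["all_mb_metrics"])
# Reduce each per-microbatch sequence to one scalar for logging.
metrics = {k: np.mean(v).item() for k, v in metrics.items()}
# metrics is now roughly {"loss": 0.90, "lr": 1e-05, "grad_norm": 0.80}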