Commit f530ded

yfw and SahilJain314 authored
feat: SFT convergence run changes (#21)
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Co-authored-by: Sahil Jain <48468750+SahilJain314@users.noreply.github.com>
1 parent 8fb070f commit f530ded

File tree: 9 files changed (+78 −31 lines)

examples/configs/grpo_math_1B.yaml

Lines changed: 8 additions & 1 deletion
@@ -27,11 +27,18 @@ policy:
   train_global_batch_size: 32
   train_micro_batch_size: 4
   generation_batch_size: 32
-  learning_rate: 5.0e-6
   logprob_batch_size: 4
   max_total_sequence_length: 512
   precision: "bfloat16"
 
+  optimizer:
+    name: "torch.optim.AdamW"
+    kwargs:
+      lr: 5.0e-6
+      weight_decay: 0.01
+      betas: [0.9, 0.999]
+      eps: 1e-8
+
   scheduler:
     - name: "torch.optim.lr_scheduler.LinearLR"
       kwargs:
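
For reference, the new optimizer block replaces the old flat learning_rate key with an explicit class name plus keyword arguments. A minimal sketch of how that block maps onto a torch optimizer (illustrative only; the nn.Linear model is a stand-in for the policy, not part of the repo):

import torch
import torch.nn as nn

# Keyword arguments copied from the config's optimizer.kwargs section.
optimizer_kwargs = {"lr": 5.0e-6, "weight_decay": 0.01, "betas": [0.9, 0.999], "eps": 1e-8}

model = nn.Linear(8, 8)  # stand-in for the policy model
optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
print(optimizer.param_groups[0]["lr"])  # 5e-06, the value the removed learning_rate key used to carry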

examples/configs/grpo_math_8B.yaml

Lines changed: 8 additions & 1 deletion
@@ -6,11 +6,18 @@ policy:
   train_global_batch_size: 32
   train_micro_batch_size: 1
   generation_batch_size: 32
-  learning_rate: 5.0e-6
   logprob_batch_size: 2
   max_total_sequence_length: 4096
   precision: "bfloat16"
 
+  optimizer:
+    name: "torch.optim.AdamW"
+    kwargs:
+      lr: 5.0e-6
+      weight_decay: 0.01
+      betas: [0.9, 0.999]
+      eps: 1e-8
+
   scheduler:
     - name: "torch.optim.lr_scheduler.LinearLR"
       kwargs:

examples/configs/sft.yaml

Lines changed: 23 additions & 19 deletions
@@ -1,11 +1,12 @@
 # SFT Algorithm Configuration
 sft:
-  max_num_steps: 20
+  max_num_steps: 1000
   val_period: 10
   val_batches: 8
-  val_global_batch_size: 32
-  val_micro_batch_size: 2
+  val_global_batch_size: 128
+  val_micro_batch_size: 1
   val_at_start: true
+  seed: 42
 
 checkpointing:
   enabled: true
@@ -16,27 +17,30 @@ checkpointing:
   save_period: 10
 
 policy:
-  model_name: "meta-llama/Llama-3.2-1B-Instruct"
-  train_global_batch_size: 32
-  train_micro_batch_size: 2
-  learning_rate: 5.0e-6
-  max_total_sequence_length: 1024
+  model_name: "meta-llama/Meta-Llama-3-8B"
+  train_global_batch_size: 128
+  train_micro_batch_size: 1
+  max_total_sequence_length: 2048
   precision: "float32"
 
+  optimizer:
+    name: "torch.optim.AdamW"
+    kwargs:
+      lr: 5.0e-6
+      weight_decay: 0.1
+      betas: [0.9, 0.98]
+      eps: 1e-5
+
   scheduler:
-    - name: "torch.optim.lr_scheduler.LinearLR"
-      kwargs:
-        start_factor: 0.1
-        end_factor: 1.0
-        total_iters: 100
-    - name: "torch.optim.lr_scheduler.CosineAnnealingLR"
-      kwargs:
-        T_max: 100
-    - milestones: [50]
+    name: "torch.optim.lr_scheduler.LinearLR"
+    kwargs:
+      start_factor: 0.0196078
+      end_factor: 1.0
+      total_iters: 50
 
 data:
   max_input_seq_length: ${policy.max_total_sequence_length}
-  dataset_name: "open_assistant"
+  dataset_name: "squad"
 
 logger:
   log_dir: "logs" # Base directory for all logs
@@ -49,5 +53,5 @@ logger:
     log_dir: "tb_logs"
 
 cluster:
-  gpus_per_node: 1
+  gpus_per_node: 8
   num_nodes: 1
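
The SFT scheduler also collapses from a LinearLR + CosineAnnealingLR chain into a single LinearLR warmup. A small sketch of what the new numbers do, using a stand-in model: start_factor 0.0196078 is roughly 1/51, so the learning rate ramps from about lr/51 up to the full 5e-6 over the first 50 steps and then stays flat.

import torch

model = torch.nn.Linear(4, 4)  # stand-in for the policy model
optimizer = torch.optim.AdamW(
    model.parameters(), lr=5.0e-6, weight_decay=0.1, betas=(0.9, 0.98), eps=1e-5
)
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.0196078, end_factor=1.0, total_iters=50
)

for step in range(60):
    optimizer.step()   # a real training step would compute gradients first
    scheduler.step()

print(optimizer.param_groups[0]["lr"])  # ~5e-06 once warmup is complete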

nemo_reinforcer/algorithms/loss_functions.py

Lines changed: 6 additions & 2 deletions
@@ -160,6 +160,10 @@ def __call__(
 
         # Only compute loss on generated tokens (not input tokens)
         # by applying the token_loss_mask (shifted by 1 since we're predicting next tokens)
-        loss = -torch.sum(token_logprobs * mask)
+        num_unmasked_tokens = torch.sum(mask)
+        if num_unmasked_tokens == 0:
+            # prevent division by zero
+            num_unmasked_tokens = torch.tensor(1)
+        loss = -torch.sum(token_logprobs * mask) / num_unmasked_tokens
 
-        return loss, {"loss": loss.item()}
+        return loss, {"loss": loss.item(), "num_unmasked_tokens": num_unmasked_tokens.item(), "total_tokens": mask.numel()}
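
The NLL loss is now averaged over unmasked (generated) tokens instead of summed, which keeps its scale independent of how many response tokens a microbatch happens to contain. A standalone sanity check of that normalization (the values mirror the updated unit test, not real model output):

import torch

token_logprobs = torch.tensor([-999.0, -999.0, -999.0])
mask = torch.tensor([1.0, 1.0, 0.0])  # only the first two tokens are generated

num_unmasked_tokens = torch.sum(mask)
if num_unmasked_tokens == 0:
    num_unmasked_tokens = torch.tensor(1)  # guard against an all-masked microbatch

loss = -torch.sum(token_logprobs * mask) / num_unmasked_tokens
print(loss)  # tensor(999.) -- the old unnormalized sum would have been 1998.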

nemo_reinforcer/algorithms/sft.py

Lines changed: 10 additions & 4 deletions
@@ -15,11 +15,13 @@
 from pathlib import Path
 from typing import Optional, Tuple, TypedDict
 
+import numpy as np
 import torch
 from torchdata.stateful_dataloader import StatefulDataLoader
 from nemo_reinforcer.algorithms.loss_functions import (
     NLLLoss,
 )
+from nemo_reinforcer.algorithms.utils import set_seed
 from nemo_reinforcer.data import DataConfig
 from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, rl_collate_fn
 from nemo_reinforcer.data.interfaces import TaskDataSpec
@@ -57,7 +59,7 @@ class SFTConfig(TypedDict):
     val_global_batch_size: int
     val_micro_batch_size: int
     val_at_start: bool
-
+    seed: int
 
 class MasterConfig(TypedDict):
     policy: PolicyConfig
@@ -91,6 +93,8 @@ def setup(
     Returns:
         Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger
     """
+    set_seed(master_config["sft"]["seed"])
+
     # Extract individual configs for easier access
     policy_config = master_config["policy"]
     data_config = master_config["data"]
@@ -176,6 +180,7 @@ def setup(
     print(f" ✓ Model initialized")
 
     logger = Logger(logger_config)
+    logger.log_hyperparams(master_config)
 
     print("\n" + "=" * 60)
     print(" " * 18 + "SETUP COMPLETE")
@@ -410,11 +415,12 @@ def sft_train(
             checkpointer.finalize_checkpoint(checkpoint_path)
 
         losses = train_results["loss"]
-        timing_metrics = timer.get_timing_metrics(reduction_op="sum")
-
         metrics = {
-            "loss": losses.numpy(),
+            "loss": train_results["loss"].numpy(),
         }
+        metrics.update(train_results["all_mb_metrics"])
+        metrics = {k: np.mean(v).item() for k, v in metrics.items()}
+        timing_metrics = timer.get_timing_metrics(reduction_op="sum")
 
         print("\n📊 Training Results:")
         print(f" • Loss: {float(metrics['loss']):.4f}")

nemo_reinforcer/algorithms/utils.py

Lines changed: 9 additions & 0 deletions
@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import random
 import warnings
 from functools import wraps
 
+import numpy as np
 import torch
 from torch.masked import as_masked_tensor
 
@@ -120,3 +122,10 @@ def masked_mean(values, mask, dim=None):
     if dim is None:
         return values[mask.bool()].mean()
     return as_masked_tensor(values, mask.bool()).mean(dim=dim).to_tensor(torch.nan)
+
+def set_seed(seed: int):
+    """Sets the seed for python, numpy, and pytorch."""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
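
A quick reproducibility check of the new helper (illustrative; it seeds Python's random module, NumPy, and PyTorch, including all CUDA devices):

import torch
from nemo_reinforcer.algorithms.utils import set_seed

set_seed(42)
a = torch.randn(3)
set_seed(42)
b = torch.randn(3)
assert torch.equal(a, b)  # same seed, same draws from the seeded RNGs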

nemo_reinforcer/data/hf_datasets/squad.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def __init__(self):
         original_ds = load_dataset("rajpurkar/squad")
         self.formatted_ds = original_ds.map(format_squad)
 
-        custom_template = "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer: '}}{%- elif message['role'] == 'assistant' %}{{message['content'].strip()}}{%- endif %}{% endfor %}"
+        custom_template = "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}"
 
         super().__init__(
             dataset_name="squad",
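
The template tweak moves the trailing space off "Answer: " and onto the assistant turn, so the prompt now ends exactly at "Answer:" and the target begins with a single leading space. One rough way to see the rendered text is with plain jinja2 (HF chat templates are Jinja-based; the messages below are made up, not real SQuAD records):

from jinja2 import Template

custom_template = (
    "{% for message in messages %}"
    "{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}"
    "{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}"
    "{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}"
    "{%- endif %}{% endfor %}"
)

messages = [
    {"role": "system", "content": "The Norman dynasty ruled England after 1066."},
    {"role": "user", "content": "Who ruled England after 1066?"},
    {"role": "assistant", "content": "The Norman dynasty"},
]

print(Template(custom_template).render(messages=messages))
# Context: The Norman dynasty ruled England after 1066. Question: Who ruled England after 1066? Answer: The Norman dynasty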

nemo_reinforcer/models/policy/hf_policy.py

Lines changed: 5 additions & 2 deletions
@@ -121,8 +121,10 @@ def do_fsdp(model):
         self._held_reference_model_params = None
         # register_fsdp_forward_method(self.model, "generate")
         if init_optimizer:
-            self.optimizer = torch.optim.AdamW(
-                self.model.parameters(), lr=self.cfg["learning_rate"]
+            optimizer_cls = import_class_from_path(self.cfg["optimizer"]["name"])
+            self.optimizer = optimizer_cls(
+                self.model.parameters(),
+                **self.cfg["optimizer"]["kwargs"]
             )
         else:
             self.optimizer = None
@@ -285,6 +287,7 @@ def train(
                 logits = outputs.logits
 
                 loss, loss_metrics = loss_fn(logits, mb)
+                loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"]
 
                 # Backward pass
                 if not eval_mode:
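
The policy now instantiates its optimizer from the dotted class path in the config rather than hard-coding AdamW. The repo's import_class_from_path helper is not shown in this diff, so the sketch below uses a plain importlib stand-in (assumed behavior: dotted path in, class object out):

import importlib

import torch.nn as nn

def _import_class_from_path(path: str):
    # Hypothetical equivalent of the repo's import_class_from_path helper.
    module_name, _, class_name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

cfg = {
    "optimizer": {
        "name": "torch.optim.AdamW",
        "kwargs": {"lr": 5.0e-6, "weight_decay": 0.01, "betas": [0.9, 0.999], "eps": 1e-8},
    }
}

model = nn.Linear(4, 4)  # stand-in for the policy model
optimizer_cls = _import_class_from_path(cfg["optimizer"]["name"])
optimizer = optimizer_cls(model.parameters(), **cfg["optimizer"]["kwargs"])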

tests/unit/algorithms/test_loss_functions.py

Lines changed: 8 additions & 1 deletion
@@ -47,6 +47,9 @@ def test_nll_loss():
     )
     loss, metrics_dict = loss_fn(next_token_logits, data)
     torch.testing.assert_allclose(loss.cpu(), torch.tensor(0.0))
+    # Check the metrics dictionary contains the expected values
+    assert metrics_dict["num_unmasked_tokens"] == 2
+    assert metrics_dict["total_tokens"] == 3
 
     ## now assume we predict the incorrect token with high probability
     next_token_logits = (
@@ -63,4 +66,8 @@ def test_nll_loss():
     )
     loss, metrics_dict = loss_fn(next_token_logits, data)
     ## loss per token is 999, and we have two unmasked tokens
-    torch.testing.assert_allclose(loss.cpu(), torch.tensor(1998.0))
+    ## with the updated loss function, we now average the loss over unmasked tokens
+    torch.testing.assert_allclose(loss.cpu(), torch.tensor(999.0))
+    # Check the metrics dictionary contains the expected values
+    assert metrics_dict["num_unmasked_tokens"] == 2
+    assert metrics_dict["total_tokens"] == 3

0 commit comments
