
Commit 22af21c

feat: FSDP2 SFT (#206)

Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>

1 parent e36f488 · commit 22af21c

File tree (4 files changed, +28 −8 lines):

  examples/configs/sft.yaml
  nemo_reinforcer/algorithms/loss_functions.py
  nemo_reinforcer/algorithms/sft.py
  nemo_reinforcer/models/policy/dtensor_policy_worker.py

examples/configs/sft.yaml

Lines changed: 4 additions & 0 deletions
@@ -47,6 +47,10 @@ policy:
     weight_decay: 0.1
     betas: [0.9, 0.98]
     eps: 1e-5
+    # when using Dtensor, we need to set foreach
+    # and fused to False
+    foreach: False
+    fused: False
 
 data:
   max_input_seq_length: ${policy.max_total_sequence_length}
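The comment in the diff reflects a current PyTorch limitation: the fused and foreach multi-tensor optimizer paths do not handle the DTensor parameters produced by FSDP2, so the optimizer has to fall back to the plain per-parameter loop. A minimal sketch of what this config amounts to when the optimizer is built; the model and learning rate below are illustrative placeholders, not values from the repo:

import torch

# Placeholder module; in practice this would be an FSDP2-sharded model
# whose parameters are DTensors.
model = torch.nn.Linear(16, 16)

# Mirrors the YAML above: foreach/fused disabled so AdamW uses the
# per-parameter implementation, which is compatible with DTensor params.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,              # illustrative value, not from the config shown
    weight_decay=0.1,
    betas=(0.9, 0.98),
    eps=1e-5,
    foreach=False,
    fused=False,
)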

nemo_reinforcer/algorithms/loss_functions.py

Lines changed: 15 additions & 7 deletions
@@ -112,7 +112,7 @@ def __call__(
         next_token_logprobs = torch.nn.functional.log_softmax(
             next_token_logits, dim=-1
         )
-        next_tokens = data["input_ids"][:, 1:]  # Skip first token
+        next_tokens = data.get("input_ids")[:, 1:].cuda()  # Skip first token
         curr_logprobs = next_token_logprobs.gather(
             dim=-1, index=next_tokens.unsqueeze(-1)
         ).squeeze(-1)
@@ -168,14 +168,22 @@ def __call__(
         sample_mask = data["sample_mask"]
         mask = token_mask * sample_mask.unsqueeze(-1)
 
-        next_tokens = data.get("input_ids")[:, 1:].cuda()  # Skip first token
-        next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)
-        logprobs = next_token_logprobs[:, :-1]  # Remove last position's logits
+        next_token_logits = next_token_logits.to(torch.float32)
 
         # Gather the logprobs for the actual next tokens
-        token_logprobs = logprobs.gather(
-            dim=-1, index=next_tokens.unsqueeze(-1)
-        ).squeeze(-1)
+        if isinstance(next_token_logits, torch.distributed.tensor.DTensor):
+            token_logprobs = get_logprobs_from_vocab_parallel_logits(
+                next_token_logits, data["input_ids"]
+            )
+        else:
+            next_tokens = data.get("input_ids")[:, 1:].cuda()  # Skip first token
+            next_token_logprobs = torch.nn.functional.log_softmax(
+                next_token_logits, dim=-1
+            )
+            logprobs = next_token_logprobs[:, :-1]  # Remove last position's logits
+            token_logprobs = logprobs.gather(
+                dim=-1, index=next_tokens.unsqueeze(-1)
+            ).squeeze(-1)
 
         # Only compute loss on generated tokens (not input tokens)
         # by applying the token_loss_mask (shifted by 1 since we're predicting next tokens)
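For reference, the dense (non-DTensor) branch above is the standard next-token log-probability gather; the vocab-parallel DTensor case is delegated to the repo's get_logprobs_from_vocab_parallel_logits helper and is not reproduced here. A self-contained sketch of the dense computation with made-up shapes (the .cuda() call is dropped so it runs on CPU):

import torch

batch, seq_len, vocab = 2, 5, 11
next_token_logits = torch.randn(batch, seq_len, vocab)
input_ids = torch.randint(0, vocab, (batch, seq_len))

# Cast to float32 for a numerically stable log_softmax, as in the diff.
logits = next_token_logits.to(torch.float32)

next_tokens = input_ids[:, 1:]                  # Skip first token
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
logprobs = logprobs[:, :-1]                     # Remove last position's logits
token_logprobs = logprobs.gather(
    dim=-1, index=next_tokens.unsqueeze(-1)
).squeeze(-1)                                   # shape: [batch, seq_len - 1]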

nemo_reinforcer/algorithms/sft.py

Lines changed: 7 additions & 0 deletions
@@ -237,6 +237,7 @@ def validate(
 
     val_metrics = {"val_loss": 0.0}
 
+    policy.prepare_for_training()
     for batch_idx, val_batch in enumerate(val_dataloader):
         ## add loss mask based on role to every message
         add_loss_mask_to_message_log(
@@ -247,6 +248,9 @@ def validate(
         cat_and_padded, input_lengths = batched_message_log_to_flat_message(
             val_batch["message_log"],
             pad_value_dict={"token_ids": tokenizer.pad_token_id},
+            make_sequence_length_divisible_by=master_config["policy"][
+                "make_sequence_length_divisible_by"
+            ],
         )
 
         val_data: BatchedDataDict = BatchedDataDict(
@@ -358,6 +362,9 @@ def sft_train(
         cat_and_padded, input_lengths = batched_message_log_to_flat_message(
             batch["message_log"],
             pad_value_dict={"token_ids": tokenizer.pad_token_id},
+            make_sequence_length_divisible_by=master_config["policy"][
+                "make_sequence_length_divisible_by"
+            ],
         )
 
         train_data: BatchedDataDict = BatchedDataDict(
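Passing make_sequence_length_divisible_by through to the flattening step pads each batch so the sequence length lands on a fixed multiple, which sharded/parallel layouts typically require. A rough sketch of the rounding involved, using a hypothetical pad_to_multiple helper rather than the repo's batched_message_log_to_flat_message internals:

import torch

def pad_to_multiple(token_ids: torch.Tensor, multiple: int, pad_id: int) -> torch.Tensor:
    """Right-pad a [batch, seq] tensor so seq is divisible by `multiple`."""
    seq_len = token_ids.shape[1]
    padded_len = ((seq_len + multiple - 1) // multiple) * multiple
    if padded_len == seq_len:
        return token_ids
    pad = token_ids.new_full((token_ids.shape[0], padded_len - seq_len), pad_id)
    return torch.cat([token_ids, pad], dim=1)

# e.g. a 13-token batch is padded up to 16 when divisibility by 8 is required
ids = torch.randint(0, 100, (2, 13))
print(pad_to_multiple(ids, multiple=8, pad_id=0).shape)  # torch.Size([2, 16])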

nemo_reinforcer/models/policy/dtensor_policy_worker.py

Lines changed: 2 additions & 1 deletion
@@ -321,6 +321,7 @@ def train(
             mb_losses.append(loss.item())
             all_mb_metrics.append(loss_metrics)
 
+        grad_norm = None
         if not eval_mode:
             with torch.no_grad():
                 grad_norm = get_grad_norm(
@@ -347,7 +348,7 @@ def train(
         with torch.no_grad():
             local_loss = torch.tensor(losses, device="cuda")
             global_loss = torch.zeros_like(local_loss)
-            torch.distributed.all_reduce(local_loss)
+            torch.distributed.all_reduce(local_loss, group=self.dp_mesh.get_group())
             global_loss = local_loss / self.dp_size
 
         # Aggregate metrics across all microbatches
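The loss all-reduce now runs over the data-parallel sub-mesh rather than the default world group, which matters once the world also contains non-DP (e.g. tensor-parallel) ranks. A hedged sketch of that averaging, assuming torch.distributed is already initialized and a 2-D DeviceMesh with a "dp" dimension; the names and mesh shape are illustrative, not taken from the repo:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

def average_loss_over_dp(losses: list[float], dp_group, dp_size: int) -> torch.Tensor:
    """Sum per-rank microbatch losses across the DP group, then average."""
    local_loss = torch.tensor(losses, device="cuda")
    dist.all_reduce(local_loss, group=dp_group)  # default op is SUM
    return local_loss / dp_size

# Illustrative setup (run under torchrun with 4 GPUs): 2-way DP x 2-way TP.
# mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
# avg = average_loss_over_dp(mb_losses, mesh.get_group("dp"), mesh["dp"].size())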
