Skip to content

Commit 2558444

Browse files
authored
fix: Fix fsdp1 grad clipping and log grad norm (#251)
Signed-off-by: ashors1 <[email protected]>
1 parent c8f0a01 commit 2558444

File tree

7 files changed

+27
-6
lines changed

7 files changed

+27
-6
lines changed

nemo_reinforcer/algorithms/dpo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ def dpo_train(
482482
losses = train_results["loss"]
483483
metrics = {
484484
"loss": train_results["loss"].numpy(),
485+
"grad_norm": train_results["grad_norm"].numpy(),
485486
}
486487
metrics.update(train_results["all_mb_metrics"])
487488
for k, v in metrics.items():

nemo_reinforcer/algorithms/grpo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,7 @@ def grpo_train(
576576
metrics = {
577577
"loss": train_results["loss"].numpy(),
578578
"reward": rewards.numpy(),
579+
"grad_norm": train_results["grad_norm"].numpy(),
579580
}
580581
metrics.update(train_results["all_mb_metrics"])
581582
for k, v in metrics.items():

nemo_reinforcer/algorithms/sft.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,7 @@ def sft_train(
486486
losses = train_results["loss"]
487487
metrics = {
488488
"loss": train_results["loss"].numpy(),
489+
"grad_norm": train_results["grad_norm"].numpy(),
489490
}
490491
metrics.update(train_results["all_mb_metrics"])
491492
for k, v in metrics.items():

nemo_reinforcer/models/policy/dtensor_policy_worker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ def train(
369369
total_norm=grad_norm,
370370
dtype=torch.float32,
371371
)
372+
grad_norm = torch.tensor([grad_norm])
372373

373374
# Update parameters
374375
self.optimizer.step()

nemo_reinforcer/models/policy/fsdp1_policy_worker.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -310,10 +310,20 @@ def train(
310310
all_mb_metrics.append(loss_metrics)
311311

312312
# Clip gradients
313+
grad_norm = None
313314
if not eval_mode:
314-
torch.nn.utils.clip_grad_norm_(
315-
self.model.parameters(), max_norm=self.cfg["max_grad_norm"]
316-
)
315+
if isinstance(self.model, FullyShardedDataParallel):
316+
# when using FSDP1, use FSDP's clip_grad_norm_
317+
# to ensure grad norm is being computed over all parameters
318+
# see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
319+
grad_norm = self.model.clip_grad_norm_(
320+
max_norm=self.cfg["max_grad_norm"]
321+
)
322+
else:
323+
grad_norm = torch.nn.utils.clip_grad_norm_(
324+
self.model.parameters(), max_norm=self.cfg["max_grad_norm"]
325+
)
326+
grad_norm = grad_norm.cpu()
317327

318328
# Update parameters
319329
self.optimizer.step()
@@ -336,6 +346,7 @@ def train(
336346
metrics = {
337347
"global_loss": global_loss.cpu(),
338348
"local_loss": local_loss.cpu(),
349+
"grad_norm": grad_norm,
339350
"rank": torch.distributed.get_rank(),
340351
"all_mb_metrics": dict(mb_metrics),
341352
}

nemo_reinforcer/models/policy/hf_policy.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,10 @@ def train(
177177
results = self.worker_group.get_all_worker_results(futures)
178178

179179
# Aggregate the results
180-
aggregated_results = {}
181-
aggregated_results["loss"] = results[0]["global_loss"]
180+
aggregated_results = {
181+
"loss": results[0]["global_loss"],
182+
"grad_norm": results[0]["grad_norm"],
183+
}
182184

183185
# Aggregate metrics across all workers
184186
all_mb_metrics = defaultdict(list)

tests/unit/algorithms/test_sft.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@
2424
def mock_components():
2525
# Create mock components
2626
policy = MagicMock()
27-
policy.train.return_value = {"loss": torch.tensor(0.5), "all_mb_metrics": {}}
27+
policy.train.return_value = {
28+
"loss": torch.tensor(0.5),
29+
"grad_norm": torch.tensor(1.0),
30+
"all_mb_metrics": {},
31+
}
2832

2933
# Create a proper message log structure with token_ids
3034
mock_batch = {

0 commit comments

Comments (0)