Commit 9902db0

cp: fix: grad norm calculation for dtensor v2 (1693) into r0.5.0 (#1696)
Signed-off-by: Hemil Desai <[email protected]>
Signed-off-by: NeMo Bot <[email protected]>
Co-authored-by: Hemil Desai <[email protected]>
1 parent cd3b423 commit 9902db0

2 files changed: 7 additions, 3 deletions

nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Lines changed: 4 additions & 2 deletions
@@ -858,6 +858,10 @@ def train(
 ## NOTE: invalid samples should be multiplied
 ## by zero in the loss function to prevent them
 ## from affecting the gradient calculation
+
+# when FSDP reduces the gradients over the DP dim, they're automatically averaged
+# but we want to sum them so we cancel out the average here
+loss *= self.dp_size * self.cp_size
 loss.backward()
 
 if num_valid_samples > 0:
@@ -880,8 +884,6 @@ def train(
     pp_axis_name=None,
     foreach=True,
     num_label_tokens=1,
-    # when FSDP reduces the gradients over the DP dim, they're automatically averaged
-    # but we want to sum them so we rescale the gradients by self.dp_size * self.cp_size
     dp_group_size=self.dp_size * self.cp_size,
 )
 grad_norm = torch.tensor(
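
The code comments in this file's diff say that FSDP averages gradients across the data-parallel dimension, while the training loop wants their sum, so the loss is multiplied by `self.dp_size * self.cp_size` before `backward()`. The following is a minimal pure-Python sketch of that arithmetic only. It does not use FSDP or the NeMo RL code; the helper names (`local_grad`, `mean_reduce`, `sum_reduce`, `l2_norm`), the toy squared-error loss, and the shard data are all hypothetical and chosen purely for illustration.

```python
import math


def local_grad(w, batch, scale=1.0):
    """Gradient of scale * sum_i 0.5 * (w . x_i - y_i)^2 with respect to w."""
    grad = [0.0] * len(w)
    for x, y in batch:
        err = sum(wi * xi for wi, xi in zip(w, x)) - y
        for j, xj in enumerate(x):
            grad[j] += scale * err * xj
    return grad


def mean_reduce(grads):
    """Stand-in for a reduction that averages gradients across ranks."""
    n = len(grads)
    return [sum(g[j] for g in grads) / n for j in range(len(grads[0]))]


def sum_reduce(grads):
    """The reduction we actually want: a plain sum across ranks."""
    return [sum(g[j] for g in grads) for j in range(len(grads[0]))]


def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))


# Toy setup (assumption): 4 ranks, each holding its own shard of (x, y) samples.
w = [0.5, -1.0]
shards = [
    [([1.0, 2.0], 1.0)],
    [([0.0, 1.0], -2.0), ([3.0, 1.0], 0.5)],
    [([2.0, 2.0], 0.0)],
    [([1.0, -1.0], 1.5)],
]
world_size = len(shards)  # plays the role of dp_size * cp_size

# What we want: the sum of the per-rank gradients.
target = sum_reduce([local_grad(w, b) for b in shards])

# What an averaging reduction gives for the unscaled loss.
averaged = mean_reduce([local_grad(w, b) for b in shards])

# Scaling each local loss by world_size cancels the average.
rescaled = mean_reduce([local_grad(w, b, scale=world_size) for b in shards])

print("target  :", target, "norm:", l2_norm(target))
print("averaged:", averaged, "norm:", l2_norm(averaged))
print("rescaled:", rescaled, "norm:", l2_norm(rescaled))
```

In this toy setting the averaged gradient, and therefore its norm, is smaller than the summed gradient by exactly a factor of `world_size`, while the gradient of the pre-scaled loss matches the sum. Scaling the loss before the backward pass means the norm is computed on gradients that already have the intended magnitude.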

tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh

Lines changed: 3 additions & 1 deletion
@@ -35,5 +35,7 @@ uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
 if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
     uv run tests/check_metrics.py $JSON_METRICS \
         'mean(data["train/token_mult_prob_error"]) < 1.1' \
-        'data["train/token_mult_prob_error"]["30"] < 1.1'
+        'data["train/token_mult_prob_error"]["30"] < 1.1' \
+        'data["train/grad_norm"]["30"] < 0.5' \
+        'data["train/grad_norm"]["30"] > 0.1'
 fi
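
The two added checks bound `train/grad_norm` at step 30 to the open interval (0.1, 0.5). As a rough illustration of how expression-style checks like these can be evaluated against a metrics dictionary, here is a small Python sketch. It is not the repository's `tests/check_metrics.py`, whose implementation is not shown in this commit; the `check` and `mean` helpers are hypothetical, and the metric values are made-up placeholders, not results from any run.

```python
from statistics import mean as _mean


def mean(series):
    """Average the values of a {step: value} mapping (hypothetical helper)."""
    return _mean(series.values())


def check(expr, data):
    """Evaluate one check expression with only `data` and `mean` in scope.

    eval() is used here only because the checks are written as Python
    expressions; this is a sketch, not a hardened evaluator.
    """
    return bool(eval(expr, {"__builtins__": {}}, {"data": data, "mean": mean}))


# Placeholder metrics (invented for illustration), keyed by step as strings.
data = {
    "train/token_mult_prob_error": {"29": 1.02, "30": 1.03},
    "train/grad_norm": {"29": 0.31, "30": 0.27},
}

checks = [
    'mean(data["train/token_mult_prob_error"]) < 1.1',
    'data["train/token_mult_prob_error"]["30"] < 1.1',
    'data["train/grad_norm"]["30"] < 0.5',
    'data["train/grad_norm"]["30"] > 0.1',
]

results = {expr: check(expr, data) for expr in checks}
for expr, ok in results.items():
    print("PASS" if ok else "FAIL", expr)
```

Bounding the gradient norm from both sides means the test fails if the norm drifts by a large factor in either direction, which is the kind of regression a scaling error of `dp_size * cp_size` would produce.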
