2 files changed: +7 −3

nemo_rl/models/policy/workers

@@ -858,6 +858,10 @@ def train(
     ## NOTE: invalid samples should be multiplied
     ## by zero in the loss function to prevent them
     ## from affecting the gradient calculation
+
+    # when FSDP reduces the gradients over the DP dim, they're automatically averaged
+    # but we want to sum them so we cancel out the average here
+    loss *= self.dp_size * self.cp_size
     loss.backward()

     if num_valid_samples > 0:
@@ -880,8 +884,6 @@ def train(
         pp_axis_name=None,
         foreach=True,
         num_label_tokens=1,
-        # when FSDP reduces the gradients over the DP dim, they're automatically averaged
-        # but we want to sum them so we rescale the gradients by self.dp_size * self.cp_size
         dp_group_size=self.dp_size * self.cp_size,
     )
     grad_norm = torch.tensor(
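The new `loss *= self.dp_size * self.cp_size` line relies on a simple identity: if the gradient reducer averages per-rank gradients, scaling every rank's loss by the world size turns that average back into a sum. Below is a minimal single-process sketch of that arithmetic; it does not use FSDP or the worker code, and `world_size`, `rank_grads`, and the toy squared-error loss are all illustrative assumptions.

```python
# Single-process sketch: why scaling the loss by the DP world size before backward()
# cancels a mean-reduction of gradients, leaving an effective sum.
import torch

world_size = 4  # stand-in for self.dp_size * self.cp_size

w = torch.ones(1, requires_grad=True)
# One toy regression target per simulated rank.
per_rank_targets = [torch.tensor([float(r + 1)]) for r in range(world_size)]

def rank_grads(scale: float) -> torch.Tensor:
    """Gradient of a toy squared-error loss, computed separately for each simulated rank."""
    grads = []
    for target in per_rank_targets:
        w.grad = None
        loss = (w - target).pow(2).sum() * scale
        loss.backward()
        grads.append(w.grad.clone())
    return torch.stack(grads)

# An FSDP-style reducer averages gradients over the data-parallel group ...
averaged = rank_grads(scale=1.0).mean(dim=0)
# ... so scaling each rank's loss by the world size first turns that average into a sum.
scaled_then_averaged = rank_grads(scale=world_size).mean(dim=0)

assert torch.allclose(scaled_then_averaged, rank_grads(scale=1.0).sum(dim=0))
print("averaged grad:        ", averaged.item())
print("scaled-then-averaged: ", scaled_then_averaged.item())
```

With these toy values the averaged gradient is −3.0 while the scaled-then-averaged gradient is −12.0, i.e. the per-rank sum, which is the behavior the new loss rescaling restores.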

@@ -35,5 +35,7 @@ uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
 if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
     uv run tests/check_metrics.py $JSON_METRICS \
         'mean(data["train/token_mult_prob_error"]) < 1.1' \
-        'data["train/token_mult_prob_error"]["30"] < 1.1'
+        'data["train/token_mult_prob_error"]["30"] < 1.1' \
+        'data["train/grad_norm"]["30"] < 0.5' \
+        'data["train/grad_norm"]["30"] > 0.1'
 fi
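The two new bounds pin `train/grad_norm` at step 30 into (0.1, 0.5) now that the loss is summed rather than averaged over the DP group. When debugging a failed run, the same value can be read from the dumped metrics directly; the sketch below assumes the JSON layout the jq guard above already relies on (metric name → {step string → value}) and uses a placeholder path in place of `$JSON_METRICS`.

```python
# Hedged spot-check, not part of the test suite: read the dumped TB metrics JSON
# and verify the step-30 grad norm falls in the expected (0.1, 0.5) band.
import json

with open("logs/json_metrics.json") as f:  # placeholder for $JSON_METRICS
    metrics = json.load(f)

grad_norm_30 = metrics["train/grad_norm"]["30"]
assert 0.1 < grad_norm_30 < 0.5, f"grad_norm at step 30 out of range: {grad_norm_30}"
print("train/grad_norm @ step 30:", grad_norm_30)
```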