Commit ecea03e

Fix GRPO completion length equal to zero (#3857)
Co-authored-by: hjh <[email protected]>
1 parent f51ae37 commit ecea03e

1 file changed: +1 -1 lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -1062,7 +1062,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         completions_length = completion_mask.sum()
         if completions_length == 0:
             # Prevent division by zero issues after all completions are filtered by the overlong filter
-            completions_length += 1e-4
+            completions_length = completions_length.float() + 1e-4
         loss = (per_token_loss * completion_mask).sum() / completions_length

         # Log the metrics
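
The commit message does not say why the in-place form failed. One likely reason, assuming `completion_mask` has an integer (or boolean) dtype, is that `completion_mask.sum()` is then an integer tensor, and PyTorch refuses to add a float to an integer tensor in place because the result cannot be stored in the integer dtype. The sketch below reproduces that under the stated assumption; `masked_mean_loss` is a hypothetical helper that mirrors the patched lines and is not part of the trainer.

```python
import torch


def masked_mean_loss(per_token_loss: torch.Tensor, completion_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the patched lines; not the trainer's API."""
    completions_length = completion_mask.sum()
    if completions_length == 0:
        # Out-of-place: cast to float first, then add the epsilon.
        completions_length = completions_length.float() + 1e-4
    return (per_token_loss * completion_mask).sum() / completions_length


# Assumption: the mask is an integer tensor and every completion was filtered out.
per_token_loss = torch.tensor([[0.5, 1.5], [2.0, 1.0]])
empty_mask = torch.zeros(2, 2, dtype=torch.long)

# Pre-fix pattern: in-place float addition on an integer tensor.
length = empty_mask.sum()
try:
    length += 1e-4
    inplace_failed = False
except RuntimeError:
    # PyTorch cannot cast the float result back into the integer tensor.
    inplace_failed = True

# Post-fix pattern: the loss is computed without error.
loss = masked_mean_loss(per_token_loss, empty_mask)
```

With an all-zero mask the numerator is also zero, so the epsilon only keeps the division well-defined and does not change the loss value.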
