Fix GRPO tool mask alignment after tool-call retokenization #5145
MichalMraz wants to merge 1 commit into huggingface:main from
Conversation
That shouldn't happen; what model are you using?
I originally hit it with Qwen3-32B on transformers==5.0.0. |
OK, can you share the generated completion ids / tool calls when it occurs?
For example, this kind of model response causes it. The completion_ids are then shorter than in the previous round, so the old logic computes a negative delta (41 - 47 = -6). Yes, it is a somewhat degenerate case, but it sometimes happens during unstable training, causing a crash.
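To make the negative delta concrete, here is a minimal sketch of the failing length arithmetic; the variable names are illustrative, not TRL's actual internals:

```python
# Hypothetical sketch of the length bookkeeping (names are illustrative).
prev_completion_len = 47  # completion length before the tool round
new_completion_len = 41   # after retokenizing with the tool response appended

# The old logic assumed retokenization never shortens the completion,
# so it extended tool_mask by this delta without checking its sign:
delta = new_completion_len - prev_completion_len
print(delta)  # -6: "extend by -6 tokens" is meaningless, the mask stays too long
```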
Friendly bump! I recognise it's not a high-priority issue, but I'd like to see it resolved. Let me know if there's anything I should update or improve to help move this forward.
Yep, thanks for the PR. No worries, it's still a priority PR. Actually, I think the "right" solution is not to retokenize the tool call. We might need some deeper refactoring of the generation loop here :/
What does this PR do?
Fixes #5144
This PR fixes a shape-mismatch bug in the `GRPOTrainer` tool-call flow.

Root cause:
- In `_tool_call_loop`, tool-round retokenization can make the completion part shorter than the previous completion.
- `tool_mask` / `logprobs` were only extended, never truncated, so they could become longer than `completion_ids`.
- This crashes `_compute_loss` when multiplying `completion_mask * tool_mask`.
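To make the mismatch concrete, here is a pure-Python sketch (plain lists stand in for the trainer's mask tensors, and the helper name is illustrative, not TRL's actual code):

```python
def elementwise_mul(a, b):
    # Mimic the strict length check torch applies when multiplying
    # two 1-D mask tensors of different sizes.
    if len(a) != len(b):
        raise ValueError(f"size mismatch: {len(a)} vs {len(b)}")
    return [x * y for x, y in zip(a, b)]

completion_mask = [1] * 41  # matches the retokenized, shorter completion_ids
tool_mask = [1] * 47        # stale length: it was only ever extended

try:
    elementwise_mul(completion_mask, tool_mask)  # what _compute_loss effectively does
except ValueError as err:
    print(f"crash: {err}")
```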
Fix:
- In `trl/trainer/grpo_trainer.py`, align `tool_mask` and `logprobs` to the computed completion length by truncating or padding as needed before appending post-tool tokens.

Tests:
- `tests/test_grpo_trainer.py::test_training_with_tools_keeps_masks_aligned_when_retokenization_shortens_completion`
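The truncate-or-pad alignment described in the fix can be sketched in pure Python; the helper name is hypothetical and the real change operates on tensors inside `trl/trainer/grpo_trainer.py`:

```python
def align_to_length(seq, target_len, pad_value=0):
    """Truncate or pad `seq` so its length matches `target_len`.

    A minimal sketch of the truncate-or-pad idea, using plain lists
    instead of tensors; the name is illustrative, not TRL's API.
    """
    if len(seq) >= target_len:
        # Retokenization shortened the completion: drop the stale tail.
        return seq[:target_len]
    # The completion grew instead: pad up to the new length.
    return seq + [pad_value] * (target_len - len(seq))

tool_mask = [1] * 47
aligned = align_to_length(tool_mask, 41)
print(len(aligned))  # 41, same length as the retokenized completion_ids
```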