Skip to content

Commit 0dcd6c1

Browse files
hjh0119Jintao-Huang
authored andcommitted
[bugfix] fix grpo-padding-free get_logps(#6275)
1 parent 05e7137 commit 0dcd6c1

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1838,6 +1838,9 @@ def _get_per_token_logps_and_entropies_single(self,
18381838
use_local_entropy = not hasattr(super(), '_get_per_token_logps_and_entropies') and compute_entropy
18391839

18401840
can_use_super = (not self.is_multimodal and 'logits_to_keep' in parameters and not use_local_entropy)
1841+
if 'attention_mask' not in inputs:
1842+
# when set padding_free true, the attention_mask is not in inputs
1843+
can_use_super = False
18411844

18421845
if can_use_super:
18431846
# save memory

0 commit comments

Comments
 (0)