Skip to content

Commit 0e0fc6a

Browse files
authored
[bugfix] fix megatron cp (real_position_ids) (#5683)
1 parent a8dd61f commit 0e0fc6a

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

swift/llm/template/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1682,6 +1682,8 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
16821682
cp_size = self.sequence_parallel_size
16831683
if cp_size > 1:
16841684
for key in ['position_ids', 'real_position_ids']:
1685+
if key not in res:
1686+
continue
16851687
padding_len = padding_to - seq_lens[0]
16861688
position_ids = res[key][0]
16871689
extended_position_ids = torch.arange(cp_size * 2).repeat(padding_len // (cp_size * 2))

0 commit comments

Comments
 (0)