Skip to content

Commit e1a0fc3

Browse files
committed
[megatron] fix pp4 position_ids (#5544)
1 parent 2973b3b commit e1a0fc3

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

swift/megatron/trainers/utils.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,9 @@ def _broadcast(item):
7474
_broadcast(batch['position_ids'])
7575
_broadcast(batch['loss_scale'])
7676
else:
77-
for key in ('input_ids', 'labels', 'attention_mask', 'position_ids', 'loss_scale'):
77+
_broadcast(batch['attention_mask'])
78+
_broadcast(batch['position_ids'])
79+
for key in ('input_ids', 'labels', 'loss_scale'):
7880
batch[key] = None
7981

8082
else:
@@ -121,7 +123,9 @@ def _broadcast(item):
121123
_broadcast(position_ids) # compat packing & cp
122124
_broadcast(loss_scale)
123125
else:
124-
input_ids, labels, attention_mask, position_ids, loss_scale = (None, ) * 5
126+
_broadcast(attention_mask)
127+
_broadcast(position_ids)
128+
input_ids, labels, loss_scale = (None, ) * 3
125129

126130
batch = {
127131
'input_ids': input_ids,

0 commit comments

Comments
 (0)