Skip to content

Commit 406a716

Browse files
committed
[megatron] fix text_position_ids (#5783)
1 parent b1dbace commit 406a716

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

swift/megatron/trainers/utils.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,9 @@ def get_batch(data_iterator):
142142
batch = get_batch_on_this_tp_rank(data_iterator)
143143
args = get_args()
144144
num_samples = batch.pop('num_samples')
145-
position_ids = batch['position_ids']
146-
if position_ids.ndim == 3:
147-
text_position_ids = position_ids[0]
148-
batch['position_ids'] = position_ids[1:]
149-
else:
150-
text_position_ids = position_ids
145+
text_position_ids = batch.pop('text_position_ids', None)
146+
if text_position_ids is None:
147+
text_position_ids = batch.get('position_ids')
151148
if args.padding_free and text_position_ids is not None:
152149
batch['packed_seq_params'] = get_packed_seq_params(text_position_ids)
153150
batch['packed_seq_params'].num_samples = num_samples

0 commit comments

Comments
 (0)