Skip to content

Commit e6d60f0

Browse files
authored
[megatron] fix pp4 (#4516)
1 parent 3feb0bc commit e6d60f0

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

swift/megatron/train/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def get_batch(data_iterator):
205205

206206
# TODO: this is pretty hacky, find a better way
207207
if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
208-
return None, None, None, None, None
208+
return {key: None for key in ['input_ids', 'attention_mask', 'position_ids']}
209209

210210
# get batches based on the TP rank you are on
211211
batch = get_batch_on_this_tp_rank(data_iterator)

0 commit comments

Comments
 (0)