File tree Expand file tree Collapse file tree 1 file changed +3
-6
lines changed Expand file tree Collapse file tree 1 file changed +3
-6
lines changed Original file line number Diff line number Diff line change @@ -142,12 +142,9 @@ def get_batch(data_iterator):
142142 batch = get_batch_on_this_tp_rank (data_iterator )
143143 args = get_args ()
144144 num_samples = batch .pop ('num_samples' )
145- position_ids = batch ['position_ids' ]
146- if position_ids .ndim == 3 :
147- text_position_ids = position_ids [0 ]
148- batch ['position_ids' ] = position_ids [1 :]
149- else :
150- text_position_ids = position_ids
145+ text_position_ids = batch .pop ('text_position_ids' , None )
146+ if text_position_ids is None :
147+ text_position_ids = batch .get ('position_ids' )
151148 if args .padding_free and text_position_ids is not None :
152149 batch ['packed_seq_params' ] = get_packed_seq_params (text_position_ids )
153150 batch ['packed_seq_params' ].num_samples = num_samples
You can’t perform that action at this time.
0 commit comments