File tree Expand file tree Collapse file tree 2 files changed +10
-2
lines changed
Expand file tree Collapse file tree 2 files changed +10
-2
lines changed Original file line number Diff line number Diff line change @@ -89,8 +89,8 @@ def _pack_sequences_for_megatron(
8989
9090 # pad_packed_seq_to must already be a multiple of pad_packed_seq_to_multiple_of; verify it here instead of rounding
9191 if pad_packed_seq_to is not None :
92- pad_packed_seq_to = _round_up_to_multiple (
93- pad_packed_seq_to , pad_packed_seq_to_multiple_of
92+ assert pad_packed_seq_to % pad_packed_seq_to_multiple_of == 0 , (
93+ f" pad_packed_seq_to ( { pad_packed_seq_to } ) is not a multiple of pad_packed_seq_to_multiple_of ( { pad_packed_seq_to_multiple_of } )."
9494 )
9595
9696 pad_factor = pad_individual_seqs_to_multiple_of
@@ -280,6 +280,12 @@ def _get_pack_sequence_parameters_for_megatron(
280280 else :
281281 pad_packed_seq_to = None
282282
283+ # Round pad_packed_seq_to up to the nearest multiple of pad_packed_seq_to_multiple_of
284+ if pad_packed_seq_to is not None :
285+ pad_packed_seq_to = _round_up_to_multiple (
286+ pad_packed_seq_to , pad_packed_seq_to_multiple_of
287+ )
288+
283289 return (
284290 pad_individual_seqs_to_multiple_of ,
285291 pad_packed_seq_to_multiple_of ,
Original file line number Diff line number Diff line change @@ -1147,6 +1147,8 @@ def train(
11471147 self .cfg ["megatron_cfg" ],
11481148 seq_dim_size ,
11491149 )
1150+ # Use pad_full_seq_to as the sequence length when it is set; otherwise (None or 0) keep seq_dim_size
1151+ seq_dim_size = pad_full_seq_to or seq_dim_size
11501152 else :
11511153 data_iterator = batch .make_microbatch_iterator (mbs )
11521154 data_iterator_len = local_gbs // mbs
You can’t perform that action at this time.
0 commit comments