Skip to content

Commit d0651dd

Browse files
guyueh1 and root authored
fix: Fix Fp8 sequence padding for PP>1 case (#1579)
Signed-off-by: root <[email protected]>
Co-authored-by: root <[email protected]>
1 parent 91658c8 commit d0651dd

File tree

2 files changed: +10 −2 lines changed

2 files changed

+10
-2
lines changed

nemo_rl/models/megatron/common.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ def _pack_sequences_for_megatron(
8989

9090
# Round up the pad_packed_seq_to to the nearest multiple of pad_packed_seq_to_multiple_of
9191
if pad_packed_seq_to is not None:
92-
pad_packed_seq_to = _round_up_to_multiple(
93-
pad_packed_seq_to, pad_packed_seq_to_multiple_of
92+
assert pad_packed_seq_to % pad_packed_seq_to_multiple_of == 0, (
93+
f"pad_packed_seq_to ({pad_packed_seq_to}) is not a multiple of pad_packed_seq_to_multiple_of ({pad_packed_seq_to_multiple_of})."
9494
)
9595

9696
pad_factor = pad_individual_seqs_to_multiple_of
@@ -280,6 +280,12 @@ def _get_pack_sequence_parameters_for_megatron(
280280
else:
281281
pad_packed_seq_to = None
282282

283+
# make sure the pad_packed_seq_to is a multiple of the pad_packed_seq_to_multiple_of
284+
if pad_packed_seq_to is not None:
285+
pad_packed_seq_to = _round_up_to_multiple(
286+
pad_packed_seq_to, pad_packed_seq_to_multiple_of
287+
)
288+
283289
return (
284290
pad_individual_seqs_to_multiple_of,
285291
pad_packed_seq_to_multiple_of,

nemo_rl/models/policy/workers/megatron_policy_worker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,6 +1147,8 @@ def train(
11471147
self.cfg["megatron_cfg"],
11481148
seq_dim_size,
11491149
)
1150+
# if pad_full_seq_to is not None, we need to use it as the sequence length
1151+
seq_dim_size = pad_full_seq_to or seq_dim_size
11501152
else:
11511153
data_iterator = batch.make_microbatch_iterator(mbs)
11521154
data_iterator_len = local_gbs // mbs

0 commit comments

Comments
 (0)