Skip to content

Commit 416687f

Browse files
Fix for sequence packing combined with sequence parallelism: pad the sequence length to a multiple of the tensor-parallel (TP) world size (#2574)
Co-authored-by: Jon Barker <19699370+jon-barker@users.noreply.github.com>
1 parent bcf07a2 commit 416687f

File tree

1 file changed

+25
-2
lines changed

1 file changed

+25
-2
lines changed

megatron/rl/rl_utils.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from megatron.core.num_microbatches_calculator import get_num_microbatches
2929
from megatron.core.optimizer import MegatronOptimizer
3030
from megatron.core.packed_seq_params import PackedSeqParams
31-
from megatron.core.parallel_state import get_tensor_model_parallel_src_rank
31+
from megatron.core.parallel_state import get_tensor_model_parallel_src_rank, get_tensor_model_parallel_world_size
3232
from megatron.core.rerun_state_machine import RerunDataIterator
3333
from megatron.core.transformer.cuda_graphs import _CudagraphGlobalRecord
3434
from megatron.core.transformer.utils import toggle_cuda_graphs
@@ -793,7 +793,30 @@ def get_logprobs(model, tokens, position_ids, attention_mask, no_grad=False, pac
793793
# No real tokens, skip packed path
794794
packed_seq_params = None
795795
else:
796-
# Slice inputs to remove padding
796+
# When sequence parallelism is enabled, the sequence length must be
797+
# divisible by the tensor parallel world size for reduce-scatter ops
798+
if model.config.sequence_parallel:
799+
tp_world_size = get_tensor_model_parallel_world_size()
800+
if actual_len % tp_world_size != 0:
801+
actual_len = ((actual_len + tp_world_size - 1) // tp_world_size) * tp_world_size
802+
# Update cu_seqlens to match the padded length.
803+
# The last entry of cu_seqlens must equal the tensor's sequence dimension.
804+
# Without this, TE attention/rotary ops see mismatched dimensions.
805+
if packed_seq_params.cu_seqlens_q[-1].item() != actual_len:
806+
# Clone to avoid modifying cached params
807+
new_cu_seqlens = packed_seq_params.cu_seqlens_q.clone()
808+
new_cu_seqlens[-1] = actual_len
809+
packed_seq_params = PackedSeqParams(
810+
qkv_format=packed_seq_params.qkv_format,
811+
cu_seqlens_q=new_cu_seqlens,
812+
cu_seqlens_kv=new_cu_seqlens,
813+
cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q_padded,
814+
cu_seqlens_kv_padded=packed_seq_params.cu_seqlens_kv_padded,
815+
max_seqlen_q=packed_seq_params.max_seqlen_q,
816+
max_seqlen_kv=packed_seq_params.max_seqlen_kv,
817+
)
818+
819+
# Slice inputs to remove padding (or pad if needed for SP alignment)
797820
# dimension 0 is batch, with seq packing BS=1
798821
tokens = tokens[:, :actual_len]
799822
position_ids = position_ids[:, :actual_len]

0 commit comments

Comments (0)