 from megatron.core.num_microbatches_calculator import get_num_microbatches
 from megatron.core.optimizer import MegatronOptimizer
 from megatron.core.packed_seq_params import PackedSeqParams
-from megatron.core.parallel_state import get_tensor_model_parallel_src_rank
+from megatron.core.parallel_state import get_tensor_model_parallel_src_rank, get_tensor_model_parallel_world_size
 from megatron.core.rerun_state_machine import RerunDataIterator
 from megatron.core.transformer.cuda_graphs import _CudagraphGlobalRecord
 from megatron.core.transformer.utils import toggle_cuda_graphs
@@ -793,7 +793,30 @@ def get_logprobs(model, tokens, position_ids, attention_mask, no_grad=False, pac
         # No real tokens, skip packed path
         packed_seq_params = None
     else:
-        # Slice inputs to remove padding
+        # When sequence parallelism is enabled, the sequence length must be
+        # divisible by the tensor parallel world size for reduce-scatter ops.
+        if model.config.sequence_parallel:
+            tp_world_size = get_tensor_model_parallel_world_size()
+            if actual_len % tp_world_size != 0:
+                actual_len = ((actual_len + tp_world_size - 1) // tp_world_size) * tp_world_size
+                # Update cu_seqlens to match the padded length.
+                # The last entry of cu_seqlens must equal the tensor's sequence dimension.
+                # Without this, TE attention/rotary ops see mismatched dimensions.
+                if packed_seq_params.cu_seqlens_q[-1].item() != actual_len:
+                    # Clone to avoid modifying cached params
+                    new_cu_seqlens = packed_seq_params.cu_seqlens_q.clone()
+                    new_cu_seqlens[-1] = actual_len
+                    packed_seq_params = PackedSeqParams(
+                        qkv_format=packed_seq_params.qkv_format,
+                        cu_seqlens_q=new_cu_seqlens,
+                        cu_seqlens_kv=new_cu_seqlens,
+                        cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q_padded,
+                        cu_seqlens_kv_padded=packed_seq_params.cu_seqlens_kv_padded,
+                        max_seqlen_q=packed_seq_params.max_seqlen_q,
+                        max_seqlen_kv=packed_seq_params.max_seqlen_kv,
+                    )
+
+        # Slice inputs to remove padding (keeping any retained for SP alignment)
         # dimension 0 is batch, with seq packing BS=1
         tokens = tokens[:, :actual_len]
         position_ids = position_ids[:, :actual_len]
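
For intuition, here is a minimal standalone sketch of the alignment logic this hunk adds. The values and the two-sequence packing are hypothetical, and only torch is required; it is an illustration of the round-up arithmetic and the cu_seqlens patch, not the function itself:

import torch

# Hypothetical packed batch: two sequences totalling 1001 real tokens,
# with sequence parallelism at tensor-parallel world size 8.
tp_world_size = 8
actual_len = 1001
cu_seqlens_q = torch.tensor([0, 400, 1001], dtype=torch.int32)

if actual_len % tp_world_size != 0:
    # Round up to the next multiple of tp_world_size: 1001 -> 1008.
    actual_len = ((actual_len + tp_world_size - 1) // tp_world_size) * tp_world_size
    if cu_seqlens_q[-1].item() != actual_len:
        # Clone before patching so any cached copy stays untouched; the
        # final segment simply absorbs the alignment padding.
        cu_seqlens_q = cu_seqlens_q.clone()
        cu_seqlens_q[-1] = actual_len

assert actual_len == 1008
assert cu_seqlens_q.tolist() == [0, 400, 1008]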
|