
Context Parallel mismatch: CP set to 2, runtime slices with CP=4 causing reshape error in get_batch_on_this_cp_rank #1982

@TedLiangHK

Description


Environment

  • Hardware: AWS EC2, multi-node (4 nodes × 8 GPUs = 32 ranks)
  • CUDA/NCCL: CUDA 13, NCCL 2.28.8+cuda13.0
  • PyTorch: 2.9.1 (container-provided)
  • Megatron-Core: pip package (from megatron-core), used via Megatron‑Bridge
  • Transformer Engine: installed and importable
  • Dataset: Energon WebDataset (no videos), Qwen‑VL finetune
  • Launcher: torch.distributed.run with nnodes=4, nproc_per_node=8
  • Networking: host network, NCCL_DEBUG=INFO, GLOO/NCCL IFNAME configured

Model/Config

  • Script: examples/recipes/qwen_vl/finetune_qwen_vl.py
  • Recipe: qwen3_vl_8b_finetune_config
  • Key overrides:
    • model.tensor_model_parallel_size=8
    • model.pipeline_model_parallel_size=1
    • model.context_parallel_size=2
    • model.seq_length=24576
    • ddp.use_distributed_optimizer=true
    • ddp.data_parallel_sharding_strategy=optim_grads_params
    • model.recompute_granularity=selective
    • model.attention_backend=flash
    • dataset-type=energon with dataset.micro_batch_size=… matching train.micro_batch_size
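For reference, a representative launch command assembling the overrides above (node rank, master address/port, and the elided Energon batch-size flags are placeholders, not the exact values used):

```shell
# Run on each node with the appropriate --node_rank (0..3).
# MASTER_ADDR/MASTER_PORT and the Energon dataset batch-size flags are placeholders.
python -m torch.distributed.run --nnodes=4 --nproc_per_node=8 \
    --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \
    examples/recipes/qwen_vl/finetune_qwen_vl.py \
    model.tensor_model_parallel_size=8 \
    model.pipeline_model_parallel_size=1 \
    model.context_parallel_size=2 \
    model.seq_length=24576 \
    model.recompute_granularity=selective \
    model.attention_backend=flash \
    ddp.use_distributed_optimizer=true \
    ddp.data_parallel_sharding_strategy=optim_grads_params
    # plus dataset-type=energon and dataset batch sizes matching train.*
```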

Symptoms

  • Config print on rank0 shows:
    • context_parallel_size: 2
    • hierarchical_context_parallel_sizes: null
  • Runtime error during training:
    • Stack points to Megatron Core CP slicing:
      • megatron.core.utils.get_batch_on_this_cp_rank()
    • Error example 1:
      • RuntimeError: shape '[1, 4, 1131]' is invalid for input of size 4526
    • Error example 2:
      • RuntimeError: shape '[1, 4, 1620]' is invalid for input of size 6483
    • This suggests the slicer is reshaping the sequence into 4 chunks (shape [B=1, 4, S/4]) even though the config sets CP=2.
  • In both error cases the per‑microbatch effective sequence length S (4526 and 6483) is not divisible by 4, causing the reshape failure.
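The failing reshape can be reproduced in isolation with a few lines (numpy stands in for torch here; the shape and size are taken verbatim from error example 1):

```python
import numpy as np

# Error example 1: 4526 elements cannot form a [1, 4, 1131] view,
# because 4 * 1131 = 4524 != 4526 (4526 % 4 == 2).
tokens = np.arange(4526).reshape(1, 4526)

try:
    tokens.reshape(1, 4, 1131)
except ValueError as e:
    print("reshape failed:", e)

# The same tensor *is* divisible by CP=2:
ok = tokens.reshape(1, 2, 2263)
print(ok.shape)  # (1, 2, 2263)
```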

Expected Behavior

  • With model.context_parallel_size=2 and PP=1, the batch tensors should be padded to model.seq_length (24576), which is divisible by 2. The CP slicer should reshape with CP=2 (not 4), i.e., [B, CP=2, S/2, …], and should not error.
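As a sanity check, a batch actually padded to the configured length reshapes cleanly (minimal numpy sketch; both the expected [B, CP, S/CP] view and the factor-4 view seen in the runtime errors succeed at S=24576):

```python
import numpy as np

B, S, CP = 1, 24576, 2
batch = np.zeros((B, S))

# Expected view with CP=2: [B, CP, S/CP]
assert S % CP == 0
view2 = batch.reshape(B, CP, S // CP)
print(view2.shape)  # (1, 2, 12288)

# Even the factor-4 view from the runtime errors would succeed
# at the padded length, since 24576 % 4 == 0:
view4 = batch.reshape(B, 4, S // 4)
print(view4.shape)  # (1, 4, 6144)
```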

Actual Behavior

  • CP appears to be treated as 4 by the slicer despite context_parallel_size=2 in config; reshape fails when S%4 != 0.
  • hierarchical_context_parallel_sizes is null; possibly a hierarchical CP split is inferred at runtime.

Repro Steps

  1. Launch (master on node 0) with:
    • --nnodes=4 --nproc_per_node=8 --node_rank set per node
    • model.tensor_model_parallel_size=8
    • model.pipeline_model_parallel_size=1
    • model.context_parallel_size=2
    • model.seq_length=24576
    • dataset-type=energon (dataset.global/micro batch sizes = train.*)
    • model.recompute_granularity=selective
    • model.attention_backend=flash
  2. Observe config printed on rank0:
    • context_parallel_size: 2
    • hierarchical_context_parallel_sizes: null
  3. Training crashes on some ranks with the reshape error above.

What we tried

  • Verified seq_length=24576 (multiple of 128 and divisible by CP=2).
  • Ensured PP=1 padding path (fixed length per microbatch).
  • Confirmed no packed/THD (remove-padding) inputs used with CP.
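The checks above can be collapsed into a quick assertion script. Notably, the configured seq_length passes even a factor-4 divisibility requirement, so the failing sizes (4526, 6483) must be unpadded per-sample lengths rather than the configured maximum:

```python
seq_length = 24576
observed = [4526, 6483]  # input sizes from the two runtime errors

# Checks from this section:
assert seq_length % 128 == 0  # multiple of 128
assert seq_length % 2 == 0    # divisible by CP=2
assert seq_length % 4 == 0    # also divisible by 4

# The failing sizes are far below seq_length, i.e. not padded to it:
for s in observed:
    assert s < seq_length
    print(s, "divisible by 4:", s % 4 == 0)
```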

Hypothesis

  • Despite context_parallel_size=2, hierarchical CP defaults to a deeper split on this topology (CP=4) when hierarchical_context_parallel_sizes is null. The batch tensors are then sliced using CP=4, but the microbatch length S is not divisible by 4, causing the reshape error.
  • Alternatively, CP is correctly 2 in the config, but some internal path initializes the CP process groups differently, leading to a mismatch between the CP world size used in slicing and the model config.
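One more possibility worth ruling out: in recent Megatron Core versions, get_batch_on_this_cp_rank uses a load-balanced split into 2*cp_size chunks along the sequence dimension (each rank keeps chunk cp_rank and its mirror), so a 4-chunk view can appear even when cp_size=2. A simplified numpy rendition of that pattern (a sketch, not the exact implementation in megatron.core.utils):

```python
import numpy as np

def cp_slice(val, cp_size, cp_rank):
    """Load-balanced CP slicing sketch: split the sequence into
    2*cp_size chunks and keep chunk cp_rank plus its mirror chunk."""
    B, S = val.shape
    # This is the reshape that fails when S % (2 * cp_size) != 0:
    chunks = val.reshape(B, 2 * cp_size, S // (2 * cp_size))
    index = [cp_rank, 2 * cp_size - 1 - cp_rank]
    return chunks[:, index, :].reshape(B, -1)

tokens = np.arange(24576).reshape(1, 24576)
shard = cp_slice(tokens, cp_size=2, cp_rank=0)
print(shard.shape)  # (1, 12288) -> each of the 2 CP ranks holds S/2 tokens
# Note: cp_size=2 already implies a reshape into 2*2 = 4 sequence chunks,
# matching the "[1, 4, ...]" shape in the error messages.
```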
