
Incorrect Tensor Parallel sharding and semantic ordering for GDN in_proj and conv1d weights in Qwen3Next #2013

@liyandong001

Description


Describe the bug

When exporting Qwen3Next models with GDN attention using Megatron-Bridge, we observed an incorrect semantic ordering of two GDN-related parameters under tensor parallelism (TP > 1): `decoder.layers.*.self_attention.in_proj.weight` and `decoder.layers.*.self_attention.conv1d.weight`. The issue does not affect training or forward correctness inside Megatron, but it becomes visible during checkpoint export or conversion (e.g. exporting to HuggingFace format). Specifically, loading the same Megatron torch distributed checkpoint with TP=1 produces a semantically correct full tensor, while loading it with TP=2 and reconstructing the full tensor by directly concatenating the TP shards along dim=0 produces a tensor that differs significantly from the TP=1 result, with large-scale row misalignment rather than numerical noise.

Steps / Code to reproduce bug
1. Train or load a Megatron model that includes Qwen3Next GDN attention
2. Save a torch distributed checkpoint with TP=2
3. Load the same checkpoint twice: once with TP=1 and once with TP=2
4. Compare `decoder.layers.0.self_attention.in_proj.weight` and `decoder.layers.0.self_attention.conv1d.weight` across the two loads
5. Observe that under TP=2, concatenating the rank 0 / rank 1 shards with `torch.cat(dim=0)` yields a tensor that does not match the TP=1-loaded tensor, with many mismatched rows

Example: TP=1 produces an `in_proj.weight` of shape (12352, 2048), while TP=2 produces two local shards of shape (6176, 2048), but `torch.cat([rank0, rank1], dim=0)` does not equal the TP=1 tensor.
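For reference, the shapes above are consistent with the GDN section layout described below. The config values here are assumptions for illustration (read the real ones from your model config), but the arithmetic reproduces the observed row counts:

```python
# Hypothetical Qwen3Next GDN config values (check your own config):
# linear_key_head_dim=128, linear_num_key_heads=16,
# linear_value_head_dim=128, linear_num_value_heads=32.
qk_dim = 128 * 16            # rows of the Q section (K has the same size)
v_dim = 128 * 32             # rows of the V section (Z has the same size)
num_v_heads = 32             # rows of Beta (Alpha has the same size)

full_rows = 2 * qk_dim + 2 * v_dim + 2 * num_v_heads
print(full_rows)             # 12352 -> in_proj.weight is (12352, 2048) at TP=1
print(full_rows // 2)        # 6176  -> each TP=2 shard is (6176, 2048)
```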

Expected behavior

For the same Megatron checkpoint, loading with TP=1 or TP>1 should yield identical full tensors after TP aggregation. In particular, GDN parameters should be reconstructed in a semantically consistent order regardless of TP size.

Root cause analysis

The issue appears to be caused by how GDN parameters are split in `sharded_state_dict`. `in_proj.weight` is split along dim=0 into multiple semantic sections (Q, K, V, Z, Beta, Alpha), and `conv1d.weight` into (Q, K, V). Each TP rank therefore owns a local slice of every semantic section, rather than one contiguous global semantic range. As a result, directly concatenating TP shards along dim=0 does not reconstruct the correct semantic order, even though the local shard shapes are correct.
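A toy, pure-Python illustration of the failure mode (section names and sizes are made up, and rows are modeled as string labels rather than tensors): each rank holds a slice of every section, so naive dim=0 concatenation interleaves the sections out of order.

```python
# Simulate the per-section TP sharding: every rank gets 1/tp_size of EACH
# semantic section, not a contiguous global range of rows.
def shard(sections, tp_size):
    ranks = [[] for _ in range(tp_size)]
    for name, rows in sections.items():
        per_rank = rows // tp_size
        for r in range(tp_size):
            ranks[r] += [f"{name}{i}" for i in range(r * per_rank, (r + 1) * per_rank)]
    return ranks

sections = {"Q": 4, "K": 4, "V": 2}          # toy section sizes
rank0, rank1 = shard(sections, tp_size=2)

semantic = [f"{n}{i}" for n, rows in sections.items() for i in range(rows)]
naive = rank0 + rank1                        # stand-in for torch.cat([rank0, rank1], dim=0)
print(naive)          # ['Q0', 'Q1', 'K0', 'K1', 'V0', 'Q2', 'Q3', 'K2', 'K3', 'V1']
print(naive == semantic)                     # False: rows land in the wrong order
```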

Correct reconstruction logic

For TP > 1, the correct way to reconstruct the full tensor is to: (1) split each local TP shard into its semantic chunks (Q/K/V/Z/…), (2) all-gather each semantic chunk independently across the TP group, (3) concatenate each gathered chunk along dim=0, and (4) concatenate the resulting chunks in semantic order to form the full tensor. With this reconstruction, the TP=2 result exactly matches the TP=1-loaded tensor.
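The four steps above can be sketched without any distributed machinery (rows as string labels, toy rank shards and section sizes; the list-gather stands in for the real `torch.distributed.all_gather`):

```python
def reconstruct(shards, section_sizes):
    """Rebuild the full tensor from per-rank shards, semantic-section-wise."""
    tp_size = len(shards)
    # (1) Split each local shard into its semantic chunks.
    per_rank_chunks = []
    for shard in shards:
        chunks, off = [], 0
        for rows in section_sizes:
            local = rows // tp_size
            chunks.append(shard[off:off + local])
            off += local
        per_rank_chunks.append(chunks)
    # (2)+(3) "Gather" each chunk across ranks and concatenate along dim=0,
    # (4) then concatenate the gathered chunks in semantic order.
    full = []
    for sec in range(len(section_sizes)):
        for chunks in per_rank_chunks:
            full += chunks[sec]
    return full

# Toy TP=2 shards of a fused tensor with sections Q=4, K=4, V=2 rows:
rank0 = ["Q0", "Q1", "K0", "K1", "V0"]
rank1 = ["Q2", "Q3", "K2", "K3", "V1"]
print(reconstruct([rank0, rank1], [4, 4, 2]))
# ['Q0', 'Q1', 'Q2', 'Q3', 'K0', 'K1', 'K2', 'K3', 'V0', 'V1']
```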

Additional context

This issue does not affect Megatron's internal computation correctness, but it breaks downstream checkpoint export and cross-framework conversion (e.g. exporting to HuggingFace). We implemented a local workaround in our export path that explicitly restores the semantic order, but this appears to be a general issue in Megatron-Bridge's handling of semantically split GDN parameters for Qwen3Next, and may require a framework-level fix in the TP aggregation or mapping logic.

```python
from typing import Dict, Optional

import torch
from torch import nn


def _gather_semantic_full_from_local(self, local: torch.Tensor, config, tp_group) -> torch.Tensor:
    """Reconstruct the full in_proj.weight from local TP shards, preserving
    the semantic section order (Q, K, V, Z, Beta, Alpha)."""
    tp_size = torch.distributed.get_world_size(group=tp_group)
    if tp_size == 1:
        return local

    qk_dim = config.linear_key_head_dim * config.linear_num_key_heads
    v_dim = config.linear_value_head_dim * config.linear_num_value_heads
    num_v_heads = config.linear_num_value_heads

    # Each rank holds 1/tp_size of every semantic section.
    split_local = [
        qk_dim // tp_size,       # Q
        qk_dim // tp_size,       # K
        v_dim // tp_size,        # V
        v_dim // tp_size,        # Z
        num_v_heads // tp_size,  # Beta
        num_v_heads // tp_size,  # Alpha
    ]

    chunks = list(torch.split(local, split_local, dim=0))

    # All-gather each semantic chunk independently across the TP group,
    # then reassemble the sections in semantic order.
    full_chunks = []
    for c in chunks:
        gathered = [torch.empty_like(c) for _ in range(tp_size)]
        torch.distributed.all_gather(gathered, c, group=tp_group)
        full_chunks.append(torch.cat(gathered, dim=0))

    return torch.cat(full_chunks, dim=0).contiguous()


def megatron_to_hf(
    self,
    megatron_weights: Optional[torch.Tensor],
    megatron_module: Optional[nn.Module],
) -> Dict[str, torch.Tensor]:
    assert self.megatron_param is not None
    megatron_weights = self.maybe_dequantize(megatron_weights)

    if megatron_module is None:
        config = self.broadcast_obj_from_pp_rank(None)
    else:
        config = self._get_config(megatron_module)
        config = remove_non_pickleables(config, max_depth=3)
        config = self.broadcast_obj_from_pp_rank(config)

    if megatron_weights is None:
        return {}

    tp_size = torch.distributed.get_world_size(group=self.tp_group)

    if tp_size == 1:
        return {str(self.hf_param): megatron_weights.contiguous()}

    qk_dim = config.linear_key_head_dim * config.linear_num_key_heads
    v_dim = config.linear_value_head_dim * config.linear_num_value_heads

    # conv1d.weight carries three semantic sections: Q, K, V.
    split_local = [
        qk_dim // tp_size,  # Q
        qk_dim // tp_size,  # K
        v_dim // tp_size,   # V
    ]

    chunks = list(torch.split(megatron_weights, split_local, dim=0))

    # Gather each section across TP ranks before concatenating, so the
    # semantic order survives the TP aggregation.
    full_chunks = []
    for c in chunks:
        gathered = [torch.empty_like(c) for _ in range(tp_size)]
        torch.distributed.all_gather(gathered, c, group=self.tp_group)
        full_chunks.append(torch.cat(gathered, dim=0))

    full = torch.cat(full_chunks, dim=0).contiguous()
    return {str(self.hf_param): full}
```
I plan to submit a PR to fix this properly when I have time, instead of relying on downstream ad-hoc logic.
