Wrong channel offsets inferred for unbind when outputs feed multiple downstream ops #524

@Te4P0t

Problem Description

When RMSNorm is used as qk_norm (or RoPE is applied to q and k) inside multi-head attention, the pruning dependency graph infers an incorrect channel split for qkv.unbind(0).

Specifically, although qkv has shape [3, B, num_heads, N, head_dim] and unbind(0) should produce 3 outputs with num_heads * head_dim channels each, the pruning graph incorrectly treats the unbind as 5 chunks, resulting in channel groups of size 2304 / 5 ≈ 461.

Minimal Reproduction Code

import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.head_dim = head_dim  # kept as an attribute; needed by the pruning code

        self.q_norm = RMSNorm(head_dim) if qk_norm else nn.Identity()
        self.k_norm = RMSNorm(head_dim) if qk_norm else nn.Identity()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rope, attn_mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        q = self.q_norm(q)
        k = self.k_norm(k)

        q = rope(q)
        k = rope(k)

        x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )

        x = x.transpose(1, 2).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
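
For completeness, a driver along the following lines can be used to build the dependency graph and print the pruning group shown below. This is a hypothetical sketch, not the exact script that produced the log: the Block/Model wrappers, the identity rope, the input shape (1, 196, 768), and the choice of idxs (all 2304 output channels, to match len(idxs)=2304 in the printout) are assumptions; the torch_pruning calls follow the library's standard DependencyGraph API.

import torch
import torch.nn as nn
import torch_pruning as tp

# Minimal wrapper so the Attention module above can be traced from a single
# tensor input; rope is fixed to an identity function for illustration.
class Block(nn.Module):
    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.attn = Attention(dim, num_heads=num_heads)

    def forward(self, x):
        return self.attn(x, rope=lambda t: t)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([Block()])

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x

model = Model().eval()
example_inputs = torch.randn(1, 196, 768)

DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)
group = DG.get_pruning_group(
    model.blocks[0].attn.qkv, tp.prune_linear_out_channels, idxs=list(range(2304))
)
print(group)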

Observed Behavior

Pruning dependency output:

--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on blocks.0.attn.qkv (Linear(in_features=768, out_features=2304, bias=True)) => prune_out_channels on blocks.0.attn.qkv (Linear(in_features=768, out_features=2304, bias=True)), len(idxs)=2304
[1] prune_out_channels on blocks.0.attn.qkv (Linear(in_features=768, out_features=2304, bias=True)) => prune_out_channels on _Reshape_751(), len(idxs)=2304
[2] prune_out_channels on _Reshape_751() => prune_out_channels on _ElementWiseOp_750(PermuteBackward0), len(idxs)=2304
[3] prune_out_channels on _ElementWiseOp_750(PermuteBackward0) => prune_out_channels on _UnbindOp_749([0, 460, 921, 1382, 1843, 2304]), len(idxs)=2304
[4] prune_out_channels on _UnbindOp_749([0, 460, 921, 1382, 1843, 2304]) => prune_out_channels on _ElementWiseOp_746(ScaledDotProductEfficientAttentionBackward0), len(idxs)=461
[5] prune_out_channels on _UnbindOp_749([0, 460, 921, 1382, 1843, 2304]) => prune_out_channels on _ElementWiseOp_775(MulBackward0), len(idxs)=461
[6] prune_out_channels on _UnbindOp_749([0, 460, 921, 1382, 1843, 2304]) => prune_out_channels on _ElementWiseOp_779(PowBackward0), len(idxs)=461
[7] prune_out_channels on _UnbindOp_749([0, 460, 921, 1382, 1843, 2304]) => prune_out_channels on _ElementWiseOp_788(MulBackward0), len(idxs)=461
[8] prune_out_channels on _UnbindOp_749([0, 460, 921, 1382, 1843, 2304]) => prune_out_channels on _ElementWiseOp_792(PowBackward0), len(idxs)=461

This indicates that num_chunks for the unbind op is inferred as 5 instead of 3.

Expected Behavior

For:

dim = 768
num_heads = 12
head_dim = 64
qkv.out_features = 2304

qkv.unbind(0) should produce 3 branches, each corresponding to 768 channels. Therefore, the pruning indices propagated from each branch should have len(idxs) = 768.
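
This can be sanity-checked with a few lines of plain Python that simply restate the numbers above:

dim, num_heads, head_dim = 768, 12, 64
out_features = dim * 3                      # 2304, i.e. qkv.out_features
channels_per_branch = num_heads * head_dim  # 768 channels for each of q, k, v

# Three equal chunks of 768 channels each are expected from unbind(0).
expected_offsets = [i * out_features // 3 for i in range(3 + 1)]
print(expected_offsets)   # [0, 768, 1536, 2304]
print(out_features // 3)  # 768, the expected len(idxs) per branch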

Root Cause Analysis

From inspection, this seems to be caused by how update_unbind_index_mapping (in dependency/index_mapping.py) infers the chunk offsets:

num_chunks = len(unbind_node.outputs)
offsets = [i * input_dims // num_chunks for i in range(num_chunks + 1)]

As a result, input_dims = 2304 is evenly split into 5 segments instead of 3. The extra segments appear because q and k each feed additional downstream ops inside RMSNorm: the pow in 'variance = hidden_states.pow(2).mean(-1, keepdim=True)' and the multiply in '(self.weight * hidden_states)'. The unbind node therefore records 5 outgoing branches even though the tensor is only unbound into 3 chunks.
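
Evaluating the quoted formula outside Torch-Pruning (for illustration only) reproduces exactly the offsets printed for _UnbindOp_749, and shows what the correct num_chunks would give:

input_dims = 2304

def unbind_offsets(num_chunks):
    # Same formula as the quoted lines above.
    return [i * input_dims // num_chunks for i in range(num_chunks + 1)]

# num_chunks taken from the 5 downstream edges (SDPA for v, plus pow and
# mul for each of q and k via RMSNorm) -> the offsets seen in the log.
print(unbind_offsets(5))  # [0, 460, 921, 1382, 1843, 2304]

# num_chunks matching the actual unbind(0) over a size-3 dimension.
print(unbind_offsets(3))  # [0, 768, 1536, 2304]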

Experimental Modification

As an experiment, I tried detaching the tensor used for variance computation in RMSNorm:

variance = hidden_states.detach().pow(2).mean(-1, keepdim=True)

With this change, the pruning dependency graph no longer produces incorrect channel offsets, and the inferred pruning indices become consistent with the expected Q/K/V split.

However, this modification intentionally breaks gradient flow through the variance term and is therefore not a valid solution.
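
Since detaching is not acceptable, a fix on the Torch-Pruning side seems needed. One possible direction, shown here only as a hypothetical sketch (the function name is made up, and how the tracer would obtain the size of the unbound dimension is an assumption), is to derive num_chunks from the size of the dimension being unbound rather than from the number of recorded downstream edges:

def unbind_offsets_from_dim(input_dims, unbind_dim_size):
    # unbind_dim_size would be qkv.shape[0] == 3 here, so the 2304 output
    # channels split into 3 groups of 768 no matter how many downstream
    # ops consume q and k.
    num_chunks = unbind_dim_size
    return [i * input_dims // num_chunks for i in range(num_chunks + 1)]

print(unbind_offsets_from_dim(2304, 3))  # [0, 768, 1536, 2304]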
