Problem Description
When using RMSNorm (or RoPE) as qk_norm inside multi-head attention, the pruning dependency graph infers an incorrect channel split for qkv.unbind(0).
Specifically, although qkv has shape [3, B, num_heads, N, head_dim] and unbind(0) should produce 3 outputs with head_dim * num_heads channels each, the pruning graph incorrectly treats the unbind as 5 chunks, resulting in channel groups of size 2304 / 5 ≈ 461.
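At the tensor level the split is unambiguous; a quick standalone check (with illustrative B and N values) shows exactly 3 outputs of num_heads * head_dim = 768 channels each:

import torch

B, N, num_heads, head_dim = 2, 16, 12, 64
qkv = torch.randn(3, B, num_heads, N, head_dim)  # same layout as after reshape/permute in the attention block
q, k, v = qkv.unbind(0)                          # exactly 3 outputs
print(q.shape)                                   # torch.Size([2, 12, 16, 64])
print(num_heads * head_dim)                      # 768 channels per branch, not 2304 // 5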
Minimal Reproduction Code
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.head_dim = head_dim  # needed for the prune method
        self.q_norm = RMSNorm(head_dim) if qk_norm else nn.Identity()
        self.k_norm = RMSNorm(head_dim) if qk_norm else nn.Identity()
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rope, attn_mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = rope(q)
        k = rope(k)
        x = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
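The classes above only define the model. The dependency output below was obtained with a driver along the following lines (a sketch assuming the issue targets Torch-Pruning's DependencyGraph API; the Block/Model wrappers, the identity rope, and the input shape are illustrative assumptions, not taken from the actual setup):

import torch
import torch.nn as nn
import torch_pruning as tp

class Block(nn.Module):
    """Minimal wrapper holding the Attention module above (illustrative only)."""
    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.attn = Attention(dim, num_heads=num_heads)

    def forward(self, x, rope=lambda t: t):  # identity rope, just for tracing
        return self.attn(x, rope)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([Block()])

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x

model = Model()
example_inputs = torch.randn(1, 196, 768)
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)
group = DG.get_pruning_group(model.blocks[0].attn.qkv, tp.prune_linear_out_channels, idxs=list(range(2304)))
print(group)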
Observed Behavior
Pruning dependency output:
--------------------------------
Pruning Group
--------------------------------
[0] prune_out_channels on blocks.0.attn.qkv (Linear(in_features=768, out_features=2304, bias=True)) => prune_out_channels on blocks.0.attn.qkv (Linear(in_features=768, out_features=2304, bias=True)), len(idxs)=2304
[1] prune_out_channels on blocks.0.attn.qkv (Linear(in_features=768, out_features=2304, bias=True)) => prune_out_channels on _Reshape_751(), len(idxs)=2304
[2] prune_out_channels on _Reshape_751() => prune_out_channels on _ElementWiseOp_750(PermuteBackward0), len(idxs)=2304
[3] prune_out_channels on _ElementWiseOp_750(PermuteBackward0) => prune_out_channels on _UnbindOp_749([0, 460, 921, 1382, 1843, 2304]), len(idxs)=2304
[4] prune_out_channels on _UnbindOp_749([0, 460, 921, 1382, 1843, 2304]) => prune_out_channels on _ElementWiseOp_746(ScaledDotProductEfficientAttentionBackward0), len(idxs)=461
[5] prune_out_channels on _UnbindOp_749([0, 460, 921, 1382, 1843, 2304]) => prune_out_channels on _ElementWiseOp_775(MulBackward0), len(idxs)=461
[6] prune_out_channels on _UnbindOp_749([0, 460, 921, 1382, 1843, 2304]) => prune_out_channels on _ElementWiseOp_779(PowBackward0), len(idxs)=461
[7] prune_out_channels on _UnbindOp_749([0, 460, 921, 1382, 1843, 2304]) => prune_out_channels on _ElementWiseOp_788(MulBackward0), len(idxs)=461
[8] prune_out_channels on _UnbindOp_749([0, 460, 921, 1382, 1843, 2304]) => prune_out_channels on _ElementWiseOp_792(PowBackward0), len(idxs)=461
This indicates that num_chunks for the unbind op is inferred as 5 instead of 3.
Expected Behavior
For:
dim = 768
num_heads = 12
head_dim = 64
qkv.out_features = 2304
qkv.unbind(0) should produce 3 branches, each corresponding to 768 channels.
Therefore, pruning indices propagated from each branch should have len(idxs) = 768.
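A quick standalone check of the expected boundaries along the 2304 output channels:

dim, num_heads = 768, 12
head_dim = dim // num_heads              # 64
out_features = 3 * dim                   # 2304, qkv.out_features
num_chunks = 3                           # q, k, v
offsets = [i * out_features // num_chunks for i in range(num_chunks + 1)]
print(offsets)                           # [0, 768, 1536, 2304] -> len(idxs) = 768 per branch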
Root Cause Analysis
From inspection, this seems to be caused by how update_unbind_index_mapping (in dependency/index_mapping.py) infers the number of chunks:
num_chunks = len(unbind_node.outputs)
offsets = [i * input_dims // num_chunks for i in range(num_chunks + 1)]
As a result, input_dims = 2304 is split evenly into 5 segments instead of 3: q and k each fan out into two elementwise ops inside RMSNorm, namely the pow in variance = hidden_states.pow(2).mean(-1, keepdim=True) and the multiply in (self.weight * hidden_states), while v feeds the attention op directly, so the unbind node is recorded with 5 outputs instead of 3.
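Plugging the traced values into that formula reproduces the boundaries printed for _UnbindOp_749 above (a standalone illustration; the value of num_chunks is my reading of the traced graph, with one consumer for v and two each for q and k):

input_dims = 2304
num_chunks = 5   # len(unbind_node.outputs): 1 (v -> attention) + 2 (q -> pow, mul) + 2 (k -> pow, mul)
offsets = [i * input_dims // num_chunks for i in range(num_chunks + 1)]
print(offsets)   # [0, 460, 921, 1382, 1843, 2304]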
Experimental Modification
As an experiment, I tried detaching the tensor used for variance computation in RMSNorm:
variance = hidden_states.detach().pow(2).mean(-1, keepdim=True)
With this change, the pruning dependency graph no longer produces incorrect channel offsets, and the inferred pruning indices become consistent with the expected Q/K/V split.
However, this modification intentionally breaks gradient flow through the variance term and is therefore not a valid solution.
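For reference, the full modified forward, identical to the original except for the detach call:

def forward(self, hidden_states):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    # detach() cuts pow/mean out of the autograd graph, so they are no longer
    # traced as extra consumers of the unbind outputs
    variance = hidden_states.detach().pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    return (self.weight * hidden_states).to(input_dtype)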