Describe the bug
tl;dr I think that the Sana autoencoder should use RMSNorm2d instead of RMSNorm.
Bug description
Sana's autoencoder uses the EfficientViT block, which contains a linear attention layer called SanaMultiScaleLinearAttention. This layer uses RMSNorm in 2D: each (b, c, h, w) feature map is normalized along the channel dimension.
HuggingFace's implementation instead uses the classical RMSNorm, which normalizes along the last dimension.
At this line, one can see that norm_type="rms_norm" leads to the norm defined here, which does:
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
With tensors of shape (b, c, h, w), the .mean(-1) averages along the last spatial dimension instead of along the channel dimension c.
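For illustration, here is a minimal sketch (not diffusers code; the tensor name and shapes are made up) contrasting the two normalization axes on a (b, c, h, w) feature map:

```python
import torch

hidden_states = torch.randn(1, 512, 8, 8)  # toy (b, c, h, w) feature map
eps = 1e-5

# What diffusers' RMSNorm currently does: statistics over the last (width) axis.
var_last = hidden_states.float().pow(2).mean(-1, keepdim=True)  # shape (1, 512, 8, 1)
out_last = hidden_states * torch.rsqrt(var_last + eps)

# What a 2D / channel-wise RMSNorm should do: statistics over the channel axis.
var_chan = hidden_states.float().pow(2).mean(1, keepdim=True)   # shape (1, 1, 8, 8)
out_chan = hidden_states * torch.rsqrt(var_chan + eps)

# The two results differ, which is exactly the mismatch described above.
print(torch.allclose(out_last, out_chan))  # False
```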
Solution
The authors of the original AutoencoderDC paper have an implementation of RMSNorm2d here in the SANA repo.
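For reference, a channel-wise RMSNorm for (b, c, h, w) tensors could look roughly like the sketch below. This is an illustrative approximation, not the SANA repo's actual code; in particular the affine handling is an assumption:

```python
import torch
import torch.nn as nn

class RMSNorm2d(nn.Module):
    """Illustrative channel-wise RMSNorm for (b, c, h, w) tensors (not the SANA code)."""

    def __init__(self, num_channels: int, eps: float = 1e-5, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_channels)) if elementwise_affine else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the RMS statistic over the channel dimension (dim=1), not the last dimension.
        variance = x.float().pow(2).mean(dim=1, keepdim=True)
        x = (x.float() * torch.rsqrt(variance + self.eps)).to(x.dtype)
        if self.weight is not None:
            # Broadcast the per-channel scale over the spatial dimensions.
            x = x * self.weight.view(1, -1, 1, 1)
        return x
```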
Reproduction
import torch
from diffusers import AutoencoderDC

vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32)
norm = vae.encoder.down_blocks[3][0].attn.norm_out  # extract one of these RMSNorm layers
x = norm(torch.ones(1, 512, 100, 100))
This computes the normalized version of the all-ones tensor. If the normalization were done along the channel dimension, the result would be spatially constant, e.g. x[0, 0, :, :] would be a constant tensor. It is not:
tensor([[[0.9461, 0.7190, 0.9163, ..., 0.8793, 0.8766, 0.8826],
[0.9461, 0.7190, 0.9163, ..., 0.8793, 0.8766, 0.8826],
[0.9461, 0.7190, 0.9163, ..., 0.8793, 0.8766, 0.8826],
...,
[0.9461, 0.7190, 0.9163, ..., 0.8793, 0.8766, 0.8826],
[0.9461, 0.7190, 0.9163, ..., 0.8793, 0.8766, 0.8826],
[0.9461, 0.7190, 0.9163, ..., 0.8793, 0.8766, 0.8826]]],
grad_fn=<SliceBackward0>)
In addition, if the RMSNorm were working as intended, one could normalize an input of any spatial size, say (1, 512, *, *). But that is not the case: norm(torch.ones(1, 512, 10, 10)) throws an error.
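For comparison, a norm that takes its statistics over the channel dimension behaves as expected on both counts. The helper below is a made-up, weight-free sketch, not diffusers or SANA code:

```python
import torch

def rms_norm_channelwise(x, eps=1e-5):
    # Illustrative channel-wise RMS normalization (no affine parameters).
    return x * torch.rsqrt(x.float().pow(2).mean(dim=1, keepdim=True) + eps)

y = rms_norm_channelwise(torch.ones(1, 512, 100, 100))
print(torch.unique(y[0, 0]))                      # a single value: spatially constant
rms_norm_channelwise(torch.ones(1, 512, 10, 10))  # works for any spatial size, no error
```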
Logs
System Info
- 🤗 Diffusers version: 0.33.1
- Platform: Linux-6.11.0-21-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.12.3
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.30.2
- Transformers version: 4.51.3
- Accelerate version: 1.6.0
- PEFT version: 0.15.2
- Bitsandbytes version: not installed
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA RTX A2000 8GB Laptop GPU, 8192 MiB
- Using GPU in script?: nope
- Using distributed or parallel set-up in script?: nope
Who can help?
No response