
RMSNorm in AutoencoderDC should be RMSNorm2d #11387

@SimonCoste


Describe the bug

tl;dr I think that the Sana autoencoder should use RMSNorm2d instead of RMSNorm.

Bug description

Sana's autoencoder uses the EfficientViT block, which contains a linear attention layer called SanaMultiScaleLinearAttention. This layer uses a 2D RMSNorm: for a tensor of shape (b, c, w, h), each spatial position is normalized along the channel dimension.

Hugging Face's implementation instead uses the standard RMSNorm.
At this line, one can see that norm_type="rms_norm" leads to the norm defined here, which does:

            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

With input tensors of shape (b, c, w, h), the .mean(-1) averages along the last dimension (h), whereas it should average along the channel dimension (c).
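
To make the difference concrete, here is a minimal shape check (illustrative only, not the actual diffusers code):

import torch

x = torch.randn(1, 512, 100, 100)          # (b, c, w, h)

# what the current RMSNorm computes: statistics over the last dimension
var_last = x.pow(2).mean(-1, keepdim=True)
print(var_last.shape)                      # torch.Size([1, 512, 100, 1])

# what a 2D RMSNorm should compute: statistics over the channel dimension
var_chan = x.pow(2).mean(1, keepdim=True)
print(var_chan.shape)                      # torch.Size([1, 1, 100, 100])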

Solution

The authors of the original AutoencoderDC paper have an implementation of RMSNorm2d here in the SANA repo.
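
For reference, a channel-wise RMS norm for (b, c, w, h) tensors could be sketched as follows. This is a simplified illustration, not the exact SANA implementation; the class name, eps default, and affine handling are assumptions:

import torch
import torch.nn as nn

class RMSNorm2d(nn.Module):
    """RMSNorm over the channel dimension of a (b, c, w, h) tensor."""

    def __init__(self, num_channels: int, eps: float = 1e-5, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_channels)) if elementwise_affine else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        # normalize each spatial position by the RMS over its channels (dim=1)
        variance = x.to(torch.float32).pow(2).mean(dim=1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        if self.weight is not None:
            # broadcast the per-channel scale over the spatial dimensions
            x = x * self.weight.view(1, -1, 1, 1)
        return x.to(input_dtype)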

Reproduction

import torch
from diffusers import AutoencoderDC

vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32)
norm = vae.encoder.down_blocks[3][0].attn.norm_out  # extract one of these RMSNorm layers
x = norm(torch.ones(1, 512, 100, 100))

This computes the normalized version of the all-ones tensor. Since normalization should happen along the channel dimension, the result should be constant within each channel, e.g. x[0, 0, :, :] should be a constant tensor. It is not:

tensor([[[0.9461, 0.7190, 0.9163,  ..., 0.8793, 0.8766, 0.8826],
         [0.9461, 0.7190, 0.9163,  ..., 0.8793, 0.8766, 0.8826],
         [0.9461, 0.7190, 0.9163,  ..., 0.8793, 0.8766, 0.8826],
         ...,
         [0.9461, 0.7190, 0.9163,  ..., 0.8793, 0.8766, 0.8826],
         [0.9461, 0.7190, 0.9163,  ..., 0.8793, 0.8766, 0.8826],
         [0.9461, 0.7190, 0.9163,  ..., 0.8793, 0.8766, 0.8826]]],
       grad_fn=<SliceBackward0>)

In addition, if the RMSNorm were working as intended, one could normalize a batch of any spatial size, say (1, 512, *, *). But that's not the case: norm(torch.ones(1, 512, 10, 10)) throws an error.
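
For comparison, a channel-wise normalization (using the hypothetical RMSNorm2d sketch above, without learned weights) of the same all-ones input yields a constant output and accepts any spatial size:

norm2d = RMSNorm2d(512, elementwise_affine=False)
y = norm2d(torch.ones(1, 512, 100, 100))
print(y.unique())                    # a single value: the output is constant
norm2d(torch.ones(1, 512, 10, 10))   # any spatial size works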

Logs

System Info

  • 🤗 Diffusers version: 0.33.1
  • Platform: Linux-6.11.0-21-generic-x86_64-with-glibc2.39
  • Running on Google Colab?: No
  • Python version: 3.12.3
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.30.2
  • Transformers version: 4.51.3
  • Accelerate version: 1.6.0
  • PEFT version: 0.15.2
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: NVIDIA RTX A2000 8GB Laptop GPU, 8192 MiB
  • Using GPU in script?: nope
  • Using distributed or parallel set-up in script?: nope

Who can help?

No response
