
Internal type conversion breaks DCAE on bfloat16 #10891

@SwayStar123

Description

Describe the bug

Trying to use DCAE, but I get the error

RuntimeError: expected scalar type Float but found BFloat16

despite both the model and the input tensor being the same dtype (bfloat16).
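For reference, a quick check (using the names from the reproduction below) should print torch.bfloat16 for both the model parameters and the latents:

print(next(ae.parameters()).dtype)  # torch.bfloat16
print(example_latents.dtype)        # torch.bfloat16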

Looking at the apply_quadratic_attention function where the error occurs:

    def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        scores = torch.matmul(key.transpose(-1, -2), query)
        scores = scores.to(dtype=torch.float32)
        scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
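        # at this point scores is float32, while value is still the model dtype (bfloat16)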
        hidden_states = torch.matmul(value, scores)
        return hidden_states

there's a type conversion to torch.float32, but no conversion back to the original dtype before the final matmul with value. I believe this causes the error.
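A minimal sketch of one possible fix (my untested suggestion, not the maintainers' patch; original_dtype is a name I'm introducing here): keep the normalization in float32 for numerical stability, but cast scores back to the input dtype so both matmul operands agree:

    def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        original_dtype = value.dtype
        scores = torch.matmul(key.transpose(-1, -2), query)
        # normalize in float32 for numerical stability, as the current code does
        scores = scores.to(dtype=torch.float32)
        scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
        # cast back so the final matmul sees matching dtypes under bfloat16
        hidden_states = torch.matmul(value, scores.to(dtype=original_dtype))
        return hidden_states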

Reproduction

import torch
from diffusers import AutoencoderDC

device = "cuda"
DTYPE = torch.bfloat16
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=DTYPE, cache_dir="dc_ae", revision="main").to(device).eval()
example_latents = torch.randn(4, 32, 3, 3).to(device, dtype=DTYPE)
with torch.no_grad():
    example_ground_truth = ae.decode(example_latents).sample

Logs

File "D:\MyStuff\Programming\Python\AI\projects\reimei\train.py", line 141, in <module>
    example_ground_truth = ae.decode(example_latents).sample
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\utils\accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\models\autoencoders\autoencoder_dc.py", line 603, in decode
    decoded = self._decode(z)
              ^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\models\autoencoders\autoencoder_dc.py", line 580, in _decode
    decoded = self.decoder(z)
              ^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\models\autoencoders\autoencoder_dc.py", line 370, in forward
    hidden_states = up_block(hidden_states)
                    ^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\container.py", line 250, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\models\autoencoders\autoencoder_dc.py", line 93, in forward
    x = self.attn(x)
        ^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\models\attention_processor.py", line 906, in forward
    return self.processor(self, hidden_states)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\models\attention_processor.py", line 5806, in __call__
    hidden_states = attn.apply_quadratic_attention(query, key, value)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\models\attention_processor.py", line 902, in apply_quadratic_attention
    hidden_states = torch.matmul(value, scores)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: expected scalar type Float but found BFloat16

System Info

  • 🤗 Diffusers version: 0.32.2
  • Platform: Windows-11-10.0.22631-SP0
  • Running on Google Colab?: No
  • Python version: 3.12.4
  • PyTorch version (GPU?): 2.6.0+cu126 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.27.1
  • Transformers version: 4.49.0
  • Accelerate version: 1.4.0
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.2
  • xFormers version: 0.0.29.post3
  • Accelerator: NVIDIA GeForce RTX 3090, 24576 MiB
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Yes, but only 1 GPU on my machine

Who can help?

No response

Labels: bug