Status: Closed
Labels: bug (Something isn't working)
Describe the bug
I'm trying to use DC-AE (`AutoencoderDC`), but decoding fails with the error
RuntimeError: expected scalar type Float but found BFloat16
even though the model and the input tensor have the same dtype (`torch.bfloat16`).
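To double-check that the weights and the input really share a dtype, a small check like the one below can be run. `parameter_dtypes` is a hypothetical helper, and the toy `nn.Conv2d` only stands in for the autoencoder; on the real model the call would be `parameter_dtypes(ae)`. If it reports a single dtype that matches the latents, the mismatch must be introduced inside the forward pass rather than by the inputs.

```python
import torch
from torch import nn


def parameter_dtypes(module: nn.Module) -> set[torch.dtype]:
    # Collect every distinct dtype among the module's parameters and buffers.
    dtypes = {p.dtype for p in module.parameters()}
    dtypes |= {b.dtype for b in module.buffers()}
    return dtypes


# A toy conv layer stands in for the autoencoder (hypothetical example).
toy = nn.Conv2d(32, 8, kernel_size=3).to(dtype=torch.bfloat16)
latents = torch.randn(4, 32, 3, 3).to(dtype=torch.bfloat16)

print(parameter_dtypes(toy), latents.dtype)
```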
Looking at `apply_quadratic_attention`, the function where the error is raised:
```python
def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    scores = torch.matmul(key.transpose(-1, -2), query)
    scores = scores.to(dtype=torch.float32)  # scores is upcast to float32 here
    scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
    hidden_states = torch.matmul(value, scores)  # value still has the model dtype (bfloat16)
    return hidden_states
```
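There is a cast of `scores` to `torch.float32`, but nothing casts it back to the dtype of `value` (bfloat16 here), so the final `torch.matmul(value, scores)` receives one bfloat16 operand and one float32 operand. Since `torch.matmul` does not promote mismatched dtypes, I believe this is what raises the error.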
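The mismatch can be reproduced without the model, using plain tensors. The sketch below copies the logic of the snippet above into two free functions: one unchanged, and one that casts `scores` back to `value.dtype` before the last matmul. The cast-back is only one possible fix, not necessarily what diffusers ships; the `EPS` value and the `(batch, heads, head_dim, tokens)` shapes are assumptions chosen so the matmuls line up.

```python
import torch

EPS = 1e-15  # assumed value; the real module reads this from `self.eps`


def quadratic_attention_original(query, key, value, eps=EPS):
    # Same logic as the snippet above: scores are upcast and never cast back.
    scores = torch.matmul(key.transpose(-1, -2), query)
    scores = scores.to(dtype=torch.float32)
    scores = scores / (torch.sum(scores, dim=2, keepdim=True) + eps)
    return torch.matmul(value, scores)  # bfloat16 @ float32 -> RuntimeError


def quadratic_attention_patched(query, key, value, eps=EPS):
    # Keep the float32 normalization, then cast back so both operands agree.
    scores = torch.matmul(key.transpose(-1, -2), query)
    scores = scores.to(dtype=torch.float32)
    scores = scores / (torch.sum(scores, dim=2, keepdim=True) + eps)
    return torch.matmul(value, scores.to(value.dtype))


# Hypothetical shapes (batch, heads, head_dim, tokens); non-negative inputs
# keep the normalizer away from zero.
q = torch.rand(1, 2, 8, 9, dtype=torch.bfloat16)
k = torch.rand(1, 2, 8, 9, dtype=torch.bfloat16)
v = torch.rand(1, 2, 8, 9, dtype=torch.bfloat16)

try:
    quadratic_attention_original(q, k, v)
    original_failed = False
except RuntimeError as err:
    original_failed = True
    print("original:", err)

out = quadratic_attention_patched(q, k, v)
print("patched:", out.dtype, tuple(out.shape))
```

The patched variant still normalizes in float32, so it keeps whatever precision benefit the upcast was meant to provide; only the final product is computed in the model's dtype.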
Reproduction
```python
import torch
from diffusers import AutoencoderDC
DTYPE = torch.bfloat16
device = "cuda"  # not defined in the original snippet; assumed to be the GPU
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=DTYPE, cache_dir="dc_ae", revision="main").to(device).eval()
example_latents = torch.randn(4, 32, 3, 3).to(device, dtype=DTYPE)
with torch.no_grad():
    example_ground_truth = ae.decode(example_latents).sample
```
Logs
```
  File "D:\MyStuff\Programming\Python\AI\projects\reimei\train.py", line 141, in <module>
    example_ground_truth = ae.decode(example_latents).sample
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\utils\accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\models\autoencoders\autoencoder_dc.py", line 603, in decode
    decoded = self._decode(z)
              ^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\models\autoencoders\autoencoder_dc.py", line 580, in _decode
    decoded = self.decoder(z)
              ^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\models\autoencoders\autoencoder_dc.py", line 370, in forward
    hidden_states = up_block(hidden_states)
                    ^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\container.py", line 250, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\models\autoencoders\autoencoder_dc.py", line 93, in forward
    x = self.attn(x)
        ^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\models\attention_processor.py", line 906, in forward
    return self.processor(self, hidden_states)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\models\attention_processor.py", line 5806, in __call__
    hidden_states = attn.apply_quadratic_attention(query, key, value)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\ASUS\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\models\attention_processor.py", line 902, in apply_quadratic_attention
    hidden_states = torch.matmul(value, scores)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: expected scalar type Float but found BFloat16
```
System Info
- 🤗 Diffusers version: 0.32.2
- Platform: Windows-11-10.0.22631-SP0
- Running on Google Colab?: No
- Python version: 3.12.4
- PyTorch version (GPU?): 2.6.0+cu126 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.27.1
- Transformers version: 4.49.0
- Accelerate version: 1.4.0
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.5.2
- xFormers version: 0.0.29.post3
- Accelerator: NVIDIA GeForce RTX 3090, 24576 MiB
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: Yes, but there is only one GPU on my machine
Who can help?
No response