Describe the bug
When compiling the decoder part of the VAE with torch.compile, a graph break is hit. This makes it impossible to compile the model as a single graph.
These are the exact lines from the diffusers library that cause the graph break:
diffusers/src/diffusers/models/autoencoders/vae.py:287-289
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
In particular, the next(iter()) call forces evaluation of the lazy tensors that precede this point in the graph, so the conv_in module is compiled independently. It would be ideal if upscale_dtype could be inferred in a different way; one possible alternative is sketched below.
Thanks!
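For illustration, one compile-friendly option might be to read the dtype from a parameter reached by plain attribute access instead of iterating over the parameters() generator. This is only a sketch, and the self.up_blocks[0].resnets[0].conv1.weight path is an assumption about the decoder's module layout, not a verified patch:

sample = self.conv_in(sample)
# Assumption (hypothetical): the first up block's first resnet weight shares
# its dtype with the rest of the up blocks, which holds when the model is
# cast uniformly, as in the repro below.
upscale_dtype = self.up_blocks[0].resnets[0].conv1.weight.dtype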
Reproduction
This is the repro code:
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="vae",
    torch_dtype=torch.float16,
)

# Take only the decoder half of the VAE and cast it to bfloat16.
model = vae.decoder
model = model.eval()
model = model.to(torch.bfloat16)

device = torch.device("cpu")
sample_img = torch.randn(1, 4, 64, 64, dtype=torch.bfloat16)

# Compile in place; Dynamo actually traces (and hits the graph break) on the
# first forward pass below.
model.compile()

sample_img = sample_img.to(device)
model = model.to(device)

with torch.no_grad():
    output = model(sample_img)
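If useful for triage, the break can be confirmed with Dynamo's explain utility. This is a diagnostic sketch, run on the eager (uncompiled) decoder with the same input as above; the ExplainOutput field names may vary across PyTorch versions:

import torch._dynamo as dynamo
# Counts the Dynamo graphs produced for the forward pass and reports why
# each break occurred.
explanation = dynamo.explain(model)(sample_img)
print(explanation.graph_break_count)
print(explanation.break_reasons)

Setting TORCH_LOGS="graph_breaks" before running the script prints the same information during compilation.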
Logs
System Info
- 🤗 Diffusers version: 0.35.1
- Platform: Linux-5.4.0-212-generic-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.11.13
- PyTorch version (GPU?): 2.7.0+cpu (False)
- Flax version (CPU?/GPU?/TPU?): 0.10.4 (cpu)
- Jax version: 0.7.1
- JaxLib version: 0.7.1
- Huggingface_hub version: 0.35.3
- Transformers version: 4.52.4
- Accelerate version: 1.10.1
- PEFT version: 0.17.1
- Bitsandbytes version: not installed
- Safetensors version: 0.6.2
- xFormers version: not installed
- Accelerator: NA
- Using GPU in script?: No, CPU only
- Using distributed or parallel set-up in script?: No, single-core CPU
Who can help?