
Commit f417c26

fix(vae): Fix dtype mismatch in FP32 VAE decode mode
The previous mixed-precision optimization for FP32 mode only converted some VAE decoder layers (post_quant_conv, conv_in, mid_block) to the latents dtype while leaving others (up_blocks, conv_norm_out) in float32. This caused "expected scalar type Half but found Float" errors after recent diffusers updates.

Simplify FP32 mode to consistently use float32 for both the VAE and the latents, removing the incomplete mixed-precision logic. This trades some VRAM usage for stability and correctness. Also removes the now-unused attention processor imports.
1 parent 4ce0ef5 commit f417c26
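
To make the failure mode concrete, here is a minimal standalone sketch (not InvokeAI code; the layer and tensor are invented for illustration): a module left in float32 that receives a float16 activation fails with a dtype-mismatch RuntimeError, which is what happened when only part of the VAE decoder followed the latents dtype.

```python
# Standalone illustration of the dtype mismatch described in the commit message.
# The exact error text varies by op and PyTorch version
# (e.g. "expected scalar type Half but found Float").
import torch

layer = torch.nn.Linear(8, 8)               # weights stay in float32
x = torch.randn(1, 8, dtype=torch.float16)  # activation in the latents dtype

try:
    layer(x)
except RuntimeError as err:
    print(f"dtype mismatch: {err}")
```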

File tree: 1 file changed (+2, −25 lines)

invokeai/app/invocations/latents_to_image.py

Lines changed: 2 additions & 25 deletions
@@ -2,12 +2,6 @@
 
 import torch
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
 from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
 
@@ -77,26 +71,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
             assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
             latents = latents.to(TorchDevice.choose_torch_device())
             if self.fp32:
+                # FP32 mode: convert everything to float32 for maximum precision
                 vae.to(dtype=torch.float32)
-
-                use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
-                    vae.decoder.mid_block.attentions[0].processor,
-                    (
-                        AttnProcessor2_0,
-                        XFormersAttnProcessor,
-                        LoRAXFormersAttnProcessor,
-                        LoRAAttnProcessor2_0,
-                    ),
-                )
-                # if xformers or torch_2_0 is used attention block does not need
-                # to be in float32 which can save lots of memory
-                if use_torch_2_0_or_xformers:
-                    vae.post_quant_conv.to(latents.dtype)
-                    vae.decoder.conv_in.to(latents.dtype)
-                    vae.decoder.mid_block.to(latents.dtype)
-                else:
-                    latents = latents.float()
-
+                latents = latents.float()
             else:
                 vae.to(dtype=torch.float16)
                 latents = latents.half()
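
For reference, a condensed sketch of how the decode path reads after this change (assumptions: `vae` is an already-loaded diffusers `AutoencoderKL`, `latents` is a latent tensor on the target device, and the `fp32` flag mirrors the invocation's `self.fp32`; the scaling-factor handling is simplified and `AutoencoderTiny` is ignored):

```python
import torch

def decode_latents(vae, latents, fp32: bool):
    # Simplified dtype handling from this commit: the VAE and the latents
    # always share a single dtype instead of a partial mixed-precision split.
    if fp32:
        vae.to(dtype=torch.float32)
        latents = latents.float()
    else:
        vae.to(dtype=torch.float16)
        latents = latents.half()

    with torch.no_grad():
        latents = latents / vae.config.scaling_factor      # undo SD latent scaling
        image = vae.decode(latents, return_dict=False)[0]  # image tensor roughly in [-1, 1]
    return image
```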
