diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py index 7b171961254a..9c6031a988f9 100644 --- a/src/diffusers/models/autoencoders/vae.py +++ b/src/diffusers/models/autoencoders/vae.py @@ -286,11 +286,9 @@ def forward( sample = self.conv_in(sample) - upscale_dtype = next(iter(self.up_blocks.parameters())).dtype if torch.is_grad_enabled() and self.gradient_checkpointing: # middle sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds) - sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: @@ -298,7 +296,6 @@ def forward( else: # middle sample = self.mid_block(sample, latent_embeds) - sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: