
Commit 9f3c0fd

Avoiding graph break by changing the way we infer dtype in vae.decoder (#12512)
* Changing the way we infer dtype to avoid force evaluation of lazy tensors
* Changing the way we infer dtype to ensure type consistency
* More robust inference of dtype
* Removing the upscale dtype entirely
1 parent 84e1657 commit 9f3c0fd
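
For context, a minimal sketch of the pattern this commit removes, under illustrative assumptions: the toy TinyDecoder class, its 4-channel conv layers, and the input shape below are not from the repository; only the dtype-handling pattern mirrors the change in vae.py. The old code read a parameter's dtype at runtime and cast the activation, which (per the commit message) can force evaluation of lazy tensors and introduce a graph break; the commit drops the cast entirely.

    # Minimal sketch (not the diffusers implementation). Shapes and layers are
    # illustrative assumptions; only the dtype handling mirrors the vae.py change.
    import torch
    import torch.nn as nn


    class TinyDecoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.mid_block = nn.Conv2d(4, 4, kernel_size=3, padding=1)
            self.up_blocks = nn.ModuleList([nn.Conv2d(4, 4, kernel_size=3, padding=1)])

        def forward(self, sample: torch.Tensor) -> torch.Tensor:
            # Removed pattern: query a parameter's dtype at runtime and cast the
            # activation. Under lazy/compiled execution, inspecting parameters this
            # way can force evaluation and break the graph:
            #
            #   upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
            #   sample = sample.to(upscale_dtype)
            #
            # New pattern (this commit): no explicit cast; the sample is assumed to
            # already match the module dtype, so no runtime dtype query is needed.
            sample = self.mid_block(sample)
            for up_block in self.up_blocks:
                sample = up_block(sample)
            return sample


    if __name__ == "__main__":
        decoder = TinyDecoder()
        out = decoder(torch.randn(1, 4, 8, 8))
        print(out.shape)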

File tree

1 file changed: +0 −3 lines
  • src/diffusers/models/autoencoders/vae.py


src/diffusers/models/autoencoders/vae.py

Lines changed: 0 additions & 3 deletions
@@ -286,19 +286,16 @@ def forward(
 
         sample = self.conv_in(sample)
 
-        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             # middle
             sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
-            sample = sample.to(upscale_dtype)
 
             # up
             for up_block in self.up_blocks:
                 sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)
         else:
             # middle
             sample = self.mid_block(sample, latent_embeds)
-            sample = sample.to(upscale_dtype)
 
             # up
             for up_block in self.up_blocks:

0 commit comments
