Skip to content

Commit e983608

Browse files
authored
Fix fp16/bf16 for pipeline_wan_vace.py
1 parent 51e8787 commit e983608

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/diffusers/pipelines/wan/pipeline_wan_vace.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -527,10 +527,16 @@ def prepare_video_latents(
527527
else:
528528
mask = mask.to(dtype=vae_dtype)
529529
mask = torch.where(mask > 0.5, 1.0, 0.0)
530-
inactive = video * (1 - mask)
531-
reactive = video * mask
530+
531+
inactive: torch.Tensor = video * (1 - mask)
532+
reactive: torch.Tensor = video * mask
533+
534+
inactive = inactive.to(dtype=vae_dtype)
535+
reactive = reactive.to(dtype=vae_dtype)
536+
532537
inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")
533538
reactive = retrieve_latents(self.vae.encode(reactive), generator, sample_mode="argmax")
539+
534540
inactive = ((inactive.float() - latents_mean) * latents_std).to(vae_dtype)
535541
reactive = ((reactive.float() - latents_mean) * latents_std).to(vae_dtype)
536542
latents = torch.cat([inactive, reactive], dim=1)

0 commit comments

Comments
 (0)