Skip to content

Commit fd06b6c

Browse files
authored
Fix bf16fp16 for pipeline_wan_vace.py
1 parent f5aa5fd commit fd06b6c

File tree

1 file changed

+3
-9
lines changed

1 file changed

+3
-9
lines changed

src/diffusers/pipelines/wan/pipeline_wan_vace.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -526,17 +526,11 @@ def prepare_video_latents(
526526
latents = ((latents.float() - latents_mean) * latents_std).to(vae_dtype)
527527
else:
528528
mask = mask.to(dtype=vae_dtype)
529-
mask = torch.where(mask > 0.5, 1.0, 0.0)
530-
531-
inactive: torch.Tensor = video * (1 - mask)
532-
reactive: torch.Tensor = video * mask
533-
534-
inactive = inactive.to(dtype=vae_dtype)
535-
reactive = reactive.to(dtype=vae_dtype)
536-
529+
mask = torch.where(mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype)
530+
inactive = video * (1 - mask)
531+
reactive = video * mask
537532
inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")
538533
reactive = retrieve_latents(self.vae.encode(reactive), generator, sample_mode="argmax")
539-
540534
inactive = ((inactive.float() - latents_mean) * latents_std).to(vae_dtype)
541535
reactive = ((reactive.float() - latents_mean) * latents_std).to(vae_dtype)
542536
latents = torch.cat([inactive, reactive], dim=1)

0 commit comments

Comments
 (0)