Skip to content

Commit 566a915

Browse files
committed
fix_wan_i2v_quality
1 parent 1001425 commit 566a915

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,11 @@ def retrieve_latents(
113113
latents_mean: torch.Tensor,
114114
latents_std: torch.Tensor,
115115
generator: Optional[torch.Generator] = None,
116-
sample_mode: str = "sample",
116+
sample_mode: str = "none",
117117
):
118-
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
118+
if hasattr(encoder_output, "latent_dist") and sample_mode == "none":
119+
return (encoder_output.latent_dist.mean - latents_mean) * latents_std
120+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
119121
encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
120122
encoder_output.latent_dist.logvar = torch.clamp(
121123
(encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0

0 commit comments

Comments
 (0)