Merged
6 changes: 4 additions & 2 deletions src/diffusers/pipelines/wan/pipeline_wan_i2v.py
```diff
@@ -113,9 +113,11 @@ def retrieve_latents(
     latents_mean: torch.Tensor,
     latents_std: torch.Tensor,
     generator: Optional[torch.Generator] = None,
-    sample_mode: str = "sample",
+    sample_mode: str = "none",
```
**Collaborator:**

I think this should actually be the same as `sample_mode == "argmax"`, no?

```python
def mode(self) -> torch.Tensor:
```

**Contributor (author):**

Yeah, I think it is the same, just without any extra computation cost.
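The equivalence being discussed can be sketched: for a diagonal Gaussian, the density's argmax (its mode) is the mean, so returning the mode costs nothing beyond returning the mean. A minimal illustration (not the actual diffusers `DiagonalGaussianDistribution`, just the shape of it):

```python
import torch

# Hypothetical minimal diagonal Gaussian, illustrating why
# sample_mode="argmax" (i.e. mode()) is equivalent to taking the mean:
# a Gaussian density peaks at its mean.
class DiagonalGaussian:
    def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
        self.mean = mean
        self.logvar = torch.clamp(logvar, -30.0, 20.0)
        self.std = torch.exp(0.5 * self.logvar)

    def sample(self, generator=None) -> torch.Tensor:
        # Reparameterized sample: mean + std * noise.
        noise = torch.randn(self.mean.shape, generator=generator)
        return self.mean + self.std * noise

    def mode(self) -> torch.Tensor:
        # The argmax of a Gaussian density is its mean; no extra computation.
        return self.mean

dist = DiagonalGaussian(torch.zeros(2, 4), torch.zeros(2, 4))
```

So taking `encoder_output.latent_dist.mean` directly and calling `mode()` produce the same tensor.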

**Collaborator:**

Got it, let's:

  1. still use `sample_mode="argmax"` so that this argument is consistent across all pipelines
  2. remove the code that's not needed for Wan (I think all of the other modes, lol)

```diff
 ):
-    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "none":
+        return (encoder_output.latent_dist.mean - latents_mean) * latents_std
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
         encoder_output.latent_dist.logvar = torch.clamp(
             (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
```
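What the reviewer is suggesting could look roughly like the sketch below. This is a hypothetical consolidation, not the merged code: it keeps the conventional `sample_mode="argmax"` name, drops the branches Wan does not need, and assumes `latents_mean` / `latents_std` are broadcastable normalization tensors (with `latents_std` acting as a reciprocal scale, as in the diff above).

```python
from typing import Optional

import torch

def retrieve_latents(
    encoder_output,
    latents_mean: torch.Tensor,
    latents_std: torch.Tensor,
    generator: Optional[torch.Generator] = None,
    sample_mode: str = "argmax",
):
    # Hypothetical sketch of the reviewer's suggestion for the Wan pipeline.
    if hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        # mode() of a diagonal Gaussian is its mean; normalize it directly.
        return (encoder_output.latent_dist.mean - latents_mean) * latents_std
    elif hasattr(encoder_output, "latents"):
        return (encoder_output.latents - latents_mean) * latents_std
    else:
        raise AttributeError("Could not access latents of provided encoder_output")

# Tiny stand-in for an encoder output, just to exercise the function.
from types import SimpleNamespace
demo = SimpleNamespace(latent_dist=SimpleNamespace(mean=torch.ones(1, 4)))
normalized = retrieve_latents(demo, torch.zeros(1, 4), torch.full((1, 4), 2.0))
```

With a zero mean and a scale of 2, the ones tensor normalizes to twos, confirming the branch taken is the `argmax` one.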