Skip to content

Commit 6ee6b53

Browse files
committed
fix: vae sampling mode
1 parent 6508da6 commit 6ee6b53

File tree

1 file changed

+1
-6
lines changed

1 file changed

+1
-6
lines changed

src/diffusers/pipelines/wan/pipeline_wan_video2video.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -419,12 +419,7 @@ def prepare_latents(
419419
)
420420

421421
if latents is None:
422-
if isinstance(generator, list):
423-
init_latents = [
424-
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
425-
]
426-
else:
427-
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
422+
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video]
428423

429424
init_latents = torch.cat(init_latents, dim=0).to(dtype)
430425

0 commit comments

Comments
 (0)