Skip to content

Commit a38dd79

Browse files
wuyushuwyssayakpaulDN6
authored
[Pipeline] Fix error of SVD pipeline when num_videos_per_prompt > 1 (#7786)
swap the order for do_classifier_free_guidance concat with repeat Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Dhruv Nair <[email protected]>
1 parent b1c5817 commit a38dd79

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,9 @@ def _encode_vae_image(
199199
image = image.to(device=device)
200200
image_latents = self.vae.encode(image).latent_dist.mode()
201201

202+
# duplicate image_latents for each generation per prompt, using mps friendly method
203+
image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
204+
202205
if do_classifier_free_guidance:
203206
negative_image_latents = torch.zeros_like(image_latents)
204207

@@ -207,9 +210,6 @@ def _encode_vae_image(
207210
# to avoid doing two forward passes
208211
image_latents = torch.cat([negative_image_latents, image_latents])
209212

210-
# duplicate image_latents for each generation per prompt, using mps friendly method
211-
image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
212-
213213
return image_latents
214214

215215
def _get_add_time_ids(

0 commit comments

Comments
 (0)