Skip to content

Commit 843e3f9

Browse files
authored
wan2.2 i2v FirstBlockCache fix (#12013)
* enable caching for WanImageToVideoPipeline
* ruff format
1 parent d8854b8 commit 843e3f9

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 15 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -750,25 +750,27 @@ def __call__(
750750
latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
751751
timestep = t.expand(latents.shape[0])
752752

753-
noise_pred = current_model(
754-
hidden_states=latent_model_input,
755-
timestep=timestep,
756-
encoder_hidden_states=prompt_embeds,
757-
encoder_hidden_states_image=image_embeds,
758-
attention_kwargs=attention_kwargs,
759-
return_dict=False,
760-
)[0]
761-
762-
if self.do_classifier_free_guidance:
763-
noise_uncond = current_model(
753+
with current_model.cache_context("cond"):
754+
noise_pred = current_model(
764755
hidden_states=latent_model_input,
765756
timestep=timestep,
766-
encoder_hidden_states=negative_prompt_embeds,
757+
encoder_hidden_states=prompt_embeds,
767758
encoder_hidden_states_image=image_embeds,
768759
attention_kwargs=attention_kwargs,
769760
return_dict=False,
770761
)[0]
771-
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
762+
763+
if self.do_classifier_free_guidance:
764+
with current_model.cache_context("uncond"):
765+
noise_uncond = current_model(
766+
hidden_states=latent_model_input,
767+
timestep=timestep,
768+
encoder_hidden_states=negative_prompt_embeds,
769+
encoder_hidden_states_image=image_embeds,
770+
attention_kwargs=attention_kwargs,
771+
return_dict=False,
772+
)[0]
773+
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
772774

773775
# compute the previous noisy sample x_t -> x_t-1
774776
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

0 commit comments

Comments
 (0)