
Commit 46c5e58

feat: Add transformer cache context for conditional and unconditional predictions in the SkyReels-V2 pipelines.
1 parent 17c0e79 commit 46c5e58

File tree

5 files changed: +72 -59 lines changed

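This commit tags the two classifier-free-guidance branches of each denoising loop with `self.transformer.cache_context("cond")` and `self.transformer.cache_context("uncond")`, so cache-based acceleration hooks on the transformer keep separate state for the conditional and unconditional forward passes instead of reusing intermediate results across branches. The sketch below shows how such a cache might be enabled on a SkyReels-V2 pipeline; it is not part of this commit, and the pipeline class, cache config, checkpoint id, and output handling are assumptions inferred from the file paths and the general diffusers caching API.

# Minimal usage sketch (not part of this commit). SkyReelsV2Pipeline,
# FirstBlockCacheConfig, and the checkpoint id are assumptions; import
# paths may differ by diffusers version.
import torch
from diffusers import SkyReelsV2Pipeline, FirstBlockCacheConfig

pipe = SkyReelsV2Pipeline.from_pretrained(
    "Skywork/SkyReels-V2-T2V-14B-540P-Diffusers",  # assumed Hub id
    torch_dtype=torch.bfloat16,
).to("cuda")

# Cache hooks store intermediate results on the transformer; with the new
# cache_context tags, the conditional and unconditional passes each get
# their own bucket instead of sharing one.
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.05))

video = pipe(
    prompt="A cinematic shot of waves crashing on a rocky shore",
    guidance_scale=6.0,  # CFG enabled, so both branches run each step
    num_inference_steps=30,
).frames[0]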

src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py

Lines changed: 13 additions & 11 deletions

@@ -545,22 +545,24 @@ def __call__(
                 latent_model_input = latents.to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])
 
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
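
The same pattern is applied to the remaining four pipelines below. As a plain-Python illustration of what the context tags buy us (this toy is not diffusers code), a context-keyed cache keeps whatever was stored under "cond" invisible to the "uncond" pass, so cached features computed from the positive prompt are never replayed for the negative prompt:

# Toy illustration only; names and behavior are assumptions, not the
# diffusers CacheMixin implementation.
from contextlib import contextmanager

class ToyCachedModel:
    """Context-keyed cache, mimicking the cond/uncond separation above."""

    def __init__(self):
        self._caches = {}   # one cache dict per context name
        self._active = None

    @contextmanager
    def cache_context(self, name):
        self._caches.setdefault(name, {})
        self._active = name
        try:
            yield
        finally:
            self._active = None

    def __call__(self, x, text):
        cache = self._caches[self._active]
        # The "expensive" part depends on the text embedding; if both CFG
        # branches shared one cache, the second pass would silently reuse
        # features computed from the other prompt.
        if "features" not in cache:
            cache["features"] = f"features({text})"
        return f"{cache['features']} + {x}"

model = ToyCachedModel()
with model.cache_context("cond"):
    print(model("x_t", "prompt"))     # features(prompt) + x_t
with model.cache_context("uncond"):
    print(model("x_t", "negative"))   # features(negative) + x_t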

src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py

Lines changed: 15 additions & 12 deletions

@@ -887,25 +887,28 @@ def __call__(
                     )
                     timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
 
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    enable_diffusion_forcing=True,
-                    fps=fps_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         enable_diffusion_forcing=True,
                         fps=fps_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            enable_diffusion_forcing=True,
+                            fps=fps_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
 
                 update_mask_i = step_update_mask[i]

src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py

Lines changed: 15 additions & 12 deletions

@@ -966,25 +966,28 @@ def __call__(
                     )
                     timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
 
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    enable_diffusion_forcing=True,
-                    fps=fps_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         enable_diffusion_forcing=True,
                         fps=fps_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            enable_diffusion_forcing=True,
+                            fps=fps_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
 
                 update_mask_i = step_update_mask[i]

src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py

Lines changed: 15 additions & 12 deletions

@@ -974,25 +974,28 @@ def __call__(
                     )
                     timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
 
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    enable_diffusion_forcing=True,
-                    fps=fps_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         enable_diffusion_forcing=True,
                         fps=fps_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            enable_diffusion_forcing=True,
+                            fps=fps_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
 
                 update_mask_i = step_update_mask[i]

src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py

Lines changed: 14 additions & 12 deletions

@@ -678,24 +678,26 @@ def __call__(
                 latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])
 
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_hidden_states_image=image_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         encoder_hidden_states_image=image_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            encoder_hidden_states_image=image_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
