Skip to content

Commit 185ed9a

Browse files
committed
add context for src pipelines
1 parent 17c0e79 commit 185ed9a

File tree

80 files changed

+983
-868
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

80 files changed

+983
-868
lines changed

src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -929,14 +929,15 @@ def __call__(
929929
timestep = t.expand(latent_model_input.shape[0])
930930

931931
# predict noise model_output
932-
noise_pred = self.transformer(
933-
hidden_states=latent_model_input,
934-
encoder_hidden_states=prompt_embeds,
935-
encoder_attention_mask=prompt_attention_mask,
936-
timestep=timestep,
937-
image_rotary_emb=image_rotary_emb,
938-
return_dict=False,
939-
)[0]
932+
with self.transformer.cache_context("cond"):
933+
noise_pred = self.transformer(
934+
hidden_states=latent_model_input,
935+
encoder_hidden_states=prompt_embeds,
936+
encoder_attention_mask=prompt_attention_mask,
937+
timestep=timestep,
938+
image_rotary_emb=image_rotary_emb,
939+
return_dict=False,
940+
)[0]
940941

941942
# perform guidance
942943
if do_classifier_free_guidance:

src/diffusers/pipelines/amused/pipeline_amused.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -281,13 +281,14 @@ def __call__(
281281
else:
282282
model_input = latents
283283

284-
model_output = self.transformer(
285-
model_input,
286-
micro_conds=micro_conds,
287-
pooled_text_emb=prompt_embeds,
288-
encoder_hidden_states=encoder_hidden_states,
289-
cross_attention_kwargs=cross_attention_kwargs,
290-
)
284+
with self.transformer.cache_context("cond"):
285+
model_output = self.transformer(
286+
model_input,
287+
micro_conds=micro_conds,
288+
pooled_text_emb=prompt_embeds,
289+
encoder_hidden_states=encoder_hidden_states,
290+
cross_attention_kwargs=cross_attention_kwargs,
291+
)
291292

292293
if guidance_scale > 1.0:
293294
uncond_logits, cond_logits = model_output.chunk(2)

src/diffusers/pipelines/amused/pipeline_amused_img2img.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -309,13 +309,14 @@ def __call__(
309309
else:
310310
model_input = latents
311311

312-
model_output = self.transformer(
313-
model_input,
314-
micro_conds=micro_conds,
315-
pooled_text_emb=prompt_embeds,
316-
encoder_hidden_states=encoder_hidden_states,
317-
cross_attention_kwargs=cross_attention_kwargs,
318-
)
312+
with self.transformer.cache_context("cond"):
313+
model_output = self.transformer(
314+
model_input,
315+
micro_conds=micro_conds,
316+
pooled_text_emb=prompt_embeds,
317+
encoder_hidden_states=encoder_hidden_states,
318+
cross_attention_kwargs=cross_attention_kwargs,
319+
)
319320

320321
if guidance_scale > 1.0:
321322
uncond_logits, cond_logits = model_output.chunk(2)

src/diffusers/pipelines/amused/pipeline_amused_inpaint.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -339,13 +339,14 @@ def __call__(
339339
else:
340340
model_input = latents
341341

342-
model_output = self.transformer(
343-
model_input,
344-
micro_conds=micro_conds,
345-
pooled_text_emb=prompt_embeds,
346-
encoder_hidden_states=encoder_hidden_states,
347-
cross_attention_kwargs=cross_attention_kwargs,
348-
)
342+
with self.transformer.cache_context("cond"):
343+
model_output = self.transformer(
344+
model_input,
345+
micro_conds=micro_conds,
346+
pooled_text_emb=prompt_embeds,
347+
encoder_hidden_states=encoder_hidden_states,
348+
cross_attention_kwargs=cross_attention_kwargs,
349+
)
349350

350351
if guidance_scale > 1.0:
351352
uncond_logits, cond_logits = model_output.chunk(2)

src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -615,13 +615,14 @@ def __call__(
615615
timestep = timestep.to(latents.device, dtype=latents.dtype)
616616

617617
# predict noise model_output
618-
noise_pred = self.transformer(
619-
latent_model_input,
620-
encoder_hidden_states=prompt_embeds,
621-
timestep=timestep,
622-
return_dict=False,
623-
attention_kwargs=self.attention_kwargs,
624-
)[0]
618+
with self.transformer.cache_context("cond"):
619+
noise_pred = self.transformer(
620+
latent_model_input,
621+
encoder_hidden_states=prompt_embeds,
622+
timestep=timestep,
623+
return_dict=False,
624+
attention_kwargs=self.attention_kwargs,
625+
)[0]
625626

626627
# perform guidance
627628
if do_classifier_free_guidance:

src/diffusers/pipelines/bria/pipeline_bria.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -662,15 +662,16 @@ def __call__(
662662
timestep = t.expand(latent_model_input.shape[0])
663663

664664
# This is predicts "v" from flow-matching or eps from diffusion
665-
noise_pred = self.transformer(
666-
hidden_states=latent_model_input,
667-
timestep=timestep,
668-
encoder_hidden_states=prompt_embeds,
669-
attention_kwargs=self.attention_kwargs,
670-
return_dict=False,
671-
txt_ids=text_ids,
672-
img_ids=latent_image_ids,
673-
)[0]
665+
with self.transformer.cache_context("cond"):
666+
noise_pred = self.transformer(
667+
hidden_states=latent_model_input,
668+
timestep=timestep,
669+
encoder_hidden_states=prompt_embeds,
670+
attention_kwargs=self.attention_kwargs,
671+
return_dict=False,
672+
txt_ids=text_ids,
673+
img_ids=latent_image_ids,
674+
)[0]
674675

675676
# perform guidance
676677
if self.do_classifier_free_guidance:

src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -705,16 +705,17 @@ def __call__(
705705
)
706706

707707
# This is predicts "v" from flow-matching or eps from diffusion
708-
noise_pred = self.transformer(
709-
hidden_states=latent_model_input,
710-
timestep=timestep,
711-
encoder_hidden_states=prompt_embeds,
712-
text_encoder_layers=prompt_layers,
713-
joint_attention_kwargs=self.joint_attention_kwargs,
714-
return_dict=False,
715-
txt_ids=text_ids,
716-
img_ids=latent_image_ids,
717-
)[0]
708+
with self.transformer.cache_context("cond"):
709+
noise_pred = self.transformer(
710+
hidden_states=latent_model_input,
711+
timestep=timestep,
712+
encoder_hidden_states=prompt_embeds,
713+
text_encoder_layers=prompt_layers,
714+
joint_attention_kwargs=self.joint_attention_kwargs,
715+
return_dict=False,
716+
txt_ids=text_ids,
717+
img_ids=latent_image_ids,
718+
)[0]
718719

719720
# perform guidance
720721
if guidance_scale > 1:

src/diffusers/pipelines/chroma/pipeline_chroma.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -906,30 +906,33 @@ def __call__(
906906
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
907907
timestep = t.expand(latents.shape[0]).to(latents.dtype)
908908

909-
noise_pred = self.transformer(
910-
hidden_states=latents,
911-
timestep=timestep / 1000,
912-
encoder_hidden_states=prompt_embeds,
913-
txt_ids=text_ids,
914-
img_ids=latent_image_ids,
915-
attention_mask=attention_mask,
916-
joint_attention_kwargs=self.joint_attention_kwargs,
917-
return_dict=False,
918-
)[0]
919-
920-
if self.do_classifier_free_guidance:
921-
if negative_image_embeds is not None:
922-
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
923-
neg_noise_pred = self.transformer(
909+
with self.transformer.cache_context("cond"):
910+
noise_pred = self.transformer(
924911
hidden_states=latents,
925912
timestep=timestep / 1000,
926-
encoder_hidden_states=negative_prompt_embeds,
927-
txt_ids=negative_text_ids,
913+
encoder_hidden_states=prompt_embeds,
914+
txt_ids=text_ids,
928915
img_ids=latent_image_ids,
929-
attention_mask=negative_attention_mask,
916+
attention_mask=attention_mask,
930917
joint_attention_kwargs=self.joint_attention_kwargs,
931918
return_dict=False,
932919
)[0]
920+
921+
if self.do_classifier_free_guidance:
922+
if negative_image_embeds is not None:
923+
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
924+
925+
with self.transformer.cache_context("uncond"):
926+
neg_noise_pred = self.transformer(
927+
hidden_states=latents,
928+
timestep=timestep / 1000,
929+
encoder_hidden_states=negative_prompt_embeds,
930+
txt_ids=negative_text_ids,
931+
img_ids=latent_image_ids,
932+
attention_mask=negative_attention_mask,
933+
joint_attention_kwargs=self.joint_attention_kwargs,
934+
return_dict=False,
935+
)[0]
933936
noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
934937

935938
# compute the previous noisy sample x_t -> x_t-1

src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -989,31 +989,33 @@ def __call__(
989989
if image_embeds is not None:
990990
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
991991

992-
noise_pred = self.transformer(
993-
hidden_states=latents,
994-
timestep=timestep / 1000,
995-
encoder_hidden_states=prompt_embeds,
996-
txt_ids=text_ids,
997-
img_ids=latent_image_ids,
998-
attention_mask=attention_mask,
999-
joint_attention_kwargs=self.joint_attention_kwargs,
1000-
return_dict=False,
1001-
)[0]
1002-
1003-
if self.do_classifier_free_guidance:
1004-
if negative_image_embeds is not None:
1005-
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
1006-
1007-
noise_pred_uncond = self.transformer(
992+
with self.transformer.cache_context("cond"):
993+
noise_pred = self.transformer(
1008994
hidden_states=latents,
1009995
timestep=timestep / 1000,
1010-
encoder_hidden_states=negative_prompt_embeds,
1011-
txt_ids=negative_text_ids,
996+
encoder_hidden_states=prompt_embeds,
997+
txt_ids=text_ids,
1012998
img_ids=latent_image_ids,
1013-
attention_mask=negative_attention_mask,
999+
attention_mask=attention_mask,
10141000
joint_attention_kwargs=self.joint_attention_kwargs,
10151001
return_dict=False,
10161002
)[0]
1003+
1004+
if self.do_classifier_free_guidance:
1005+
if negative_image_embeds is not None:
1006+
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
1007+
1008+
with self.transformer.cache_context("uncond"):
1009+
noise_pred_uncond = self.transformer(
1010+
hidden_states=latents,
1011+
timestep=timestep / 1000,
1012+
encoder_hidden_states=negative_prompt_embeds,
1013+
txt_ids=negative_text_ids,
1014+
img_ids=latent_image_ids,
1015+
attention_mask=negative_attention_mask,
1016+
joint_attention_kwargs=self.joint_attention_kwargs,
1017+
return_dict=False,
1018+
)[0]
10171019
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
10181020

10191021
# compute the previous noisy sample x_t -> x_t-1

src/diffusers/pipelines/chronoedit/pipeline_chronoedit.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -680,24 +680,26 @@ def __call__(
680680
latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
681681
timestep = t.expand(latents.shape[0])
682682

683-
noise_pred = self.transformer(
684-
hidden_states=latent_model_input,
685-
timestep=timestep,
686-
encoder_hidden_states=prompt_embeds,
687-
encoder_hidden_states_image=image_embeds,
688-
attention_kwargs=attention_kwargs,
689-
return_dict=False,
690-
)[0]
691-
692-
if self.do_classifier_free_guidance:
693-
noise_uncond = self.transformer(
683+
with self.transformer.cache_context("cond"):
684+
noise_pred = self.transformer(
694685
hidden_states=latent_model_input,
695686
timestep=timestep,
696-
encoder_hidden_states=negative_prompt_embeds,
687+
encoder_hidden_states=prompt_embeds,
697688
encoder_hidden_states_image=image_embeds,
698689
attention_kwargs=attention_kwargs,
699690
return_dict=False,
700691
)[0]
692+
693+
if self.do_classifier_free_guidance:
694+
with self.transformer.cache_context("uncond"):
695+
noise_uncond = self.transformer(
696+
hidden_states=latent_model_input,
697+
timestep=timestep,
698+
encoder_hidden_states=negative_prompt_embeds,
699+
encoder_hidden_states_image=image_embeds,
700+
attention_kwargs=attention_kwargs,
701+
return_dict=False,
702+
)[0]
701703
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
702704

703705
# compute the previous noisy sample x_t -> x_t-1

0 commit comments

Comments
 (0)