Skip to content

Commit b67ba2c

Browse files
committed
add context for examples
1 parent 185ed9a commit b67ba2c

21 files changed

+246
-216
lines changed

examples/community/cogvideox_ddim_inversion.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -522,14 +522,15 @@ def sample(
522522
timestep = t.expand(latent_model_input.shape[0])
523523

524524
# predict noise model_output
525-
noise_pred = self.transformer(
526-
hidden_states=latent_model_input,
527-
encoder_hidden_states=prompt_embeds,
528-
timestep=timestep,
529-
image_rotary_emb=image_rotary_emb,
530-
attention_kwargs=attention_kwargs,
531-
return_dict=False,
532-
)[0]
525+
with self.transformer.cache_context("cond"):
526+
noise_pred = self.transformer(
527+
hidden_states=latent_model_input,
528+
encoder_hidden_states=prompt_embeds,
529+
timestep=timestep,
530+
image_rotary_emb=image_rotary_emb,
531+
attention_kwargs=attention_kwargs,
532+
return_dict=False,
533+
)[0]
533534
noise_pred = noise_pred.float()
534535

535536
if reference_latents is not None: # Recover the original batch size

examples/community/pipeline_flux_differential_img2img.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -954,17 +954,18 @@ def __call__(
954954

955955
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
956956
timestep = t.expand(latents.shape[0]).to(latents.dtype)
957-
noise_pred = self.transformer(
958-
hidden_states=latents,
959-
timestep=timestep / 1000,
960-
guidance=guidance,
961-
pooled_projections=pooled_prompt_embeds,
962-
encoder_hidden_states=prompt_embeds,
963-
txt_ids=text_ids,
964-
img_ids=latent_image_ids,
965-
joint_attention_kwargs=self.joint_attention_kwargs,
966-
return_dict=False,
967-
)[0]
957+
with self.transformer.cache_context("cond"):
958+
noise_pred = self.transformer(
959+
hidden_states=latents,
960+
timestep=timestep / 1000,
961+
guidance=guidance,
962+
pooled_projections=pooled_prompt_embeds,
963+
encoder_hidden_states=prompt_embeds,
964+
txt_ids=text_ids,
965+
img_ids=latent_image_ids,
966+
joint_attention_kwargs=self.joint_attention_kwargs,
967+
return_dict=False,
968+
)[0]
968969

969970
# compute the previous noisy sample x_t -> x_t-1
970971
latents_dtype = latents.dtype

examples/community/pipeline_flux_kontext_multiple_images.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,33 +1150,35 @@ def __call__(
11501150
latent_model_input = torch.cat([latents, image_latents], dim=1)
11511151
timestep = t.expand(latents.shape[0]).to(latents.dtype)
11521152

1153-
noise_pred = self.transformer(
1154-
hidden_states=latent_model_input,
1155-
timestep=timestep / 1000,
1156-
guidance=guidance,
1157-
pooled_projections=pooled_prompt_embeds,
1158-
encoder_hidden_states=prompt_embeds,
1159-
txt_ids=text_ids,
1160-
img_ids=latent_ids,
1161-
joint_attention_kwargs=self.joint_attention_kwargs,
1162-
return_dict=False,
1163-
)[0]
1164-
noise_pred = noise_pred[:, : latents.size(1)]
1165-
1166-
if do_true_cfg:
1167-
if negative_image_embeds is not None:
1168-
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
1169-
neg_noise_pred = self.transformer(
1153+
with self.transformer.cache_context("cond"):
1154+
noise_pred = self.transformer(
11701155
hidden_states=latent_model_input,
11711156
timestep=timestep / 1000,
11721157
guidance=guidance,
1173-
pooled_projections=negative_pooled_prompt_embeds,
1174-
encoder_hidden_states=negative_prompt_embeds,
1175-
txt_ids=negative_text_ids,
1158+
pooled_projections=pooled_prompt_embeds,
1159+
encoder_hidden_states=prompt_embeds,
1160+
txt_ids=text_ids,
11761161
img_ids=latent_ids,
11771162
joint_attention_kwargs=self.joint_attention_kwargs,
11781163
return_dict=False,
11791164
)[0]
1165+
noise_pred = noise_pred[:, : latents.size(1)]
1166+
1167+
if do_true_cfg:
1168+
if negative_image_embeds is not None:
1169+
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
1170+
with self.transformer.cache_context("uncond"):
1171+
neg_noise_pred = self.transformer(
1172+
hidden_states=latent_model_input,
1173+
timestep=timestep / 1000,
1174+
guidance=guidance,
1175+
pooled_projections=negative_pooled_prompt_embeds,
1176+
encoder_hidden_states=negative_prompt_embeds,
1177+
txt_ids=negative_text_ids,
1178+
img_ids=latent_ids,
1179+
joint_attention_kwargs=self.joint_attention_kwargs,
1180+
return_dict=False,
1181+
)[0]
11801182
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
11811183
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
11821184

examples/community/pipeline_flux_rf_inversion.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -888,17 +888,18 @@ def __call__(
888888
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
889889
timestep = t.expand(latents.shape[0]).to(latents.dtype)
890890

891-
noise_pred = self.transformer(
892-
hidden_states=latents,
893-
timestep=timestep / 1000,
894-
guidance=guidance,
895-
pooled_projections=pooled_prompt_embeds,
896-
encoder_hidden_states=prompt_embeds,
897-
txt_ids=text_ids,
898-
img_ids=latent_image_ids,
899-
joint_attention_kwargs=self.joint_attention_kwargs,
900-
return_dict=False,
901-
)[0]
891+
with self.transformer.cache_context("cond"):
892+
noise_pred = self.transformer(
893+
hidden_states=latents,
894+
timestep=timestep / 1000,
895+
guidance=guidance,
896+
pooled_projections=pooled_prompt_embeds,
897+
encoder_hidden_states=prompt_embeds,
898+
txt_ids=text_ids,
899+
img_ids=latent_image_ids,
900+
joint_attention_kwargs=self.joint_attention_kwargs,
901+
return_dict=False,
902+
)[0]
902903

903904
latents_dtype = latents.dtype
904905
if do_rf_inversion:
@@ -1058,17 +1059,18 @@ def invert(
10581059
timestep = torch.tensor(t_i, dtype=Y_t.dtype, device=device).repeat(batch_size)
10591060

10601061
# get the unconditional vector field
1061-
u_t_i = self.transformer(
1062-
hidden_states=Y_t,
1063-
timestep=timestep,
1064-
guidance=guidance,
1065-
pooled_projections=pooled_prompt_embeds,
1066-
encoder_hidden_states=prompt_embeds,
1067-
txt_ids=text_ids,
1068-
img_ids=latent_image_ids,
1069-
joint_attention_kwargs=self.joint_attention_kwargs,
1070-
return_dict=False,
1071-
)[0]
1062+
with self.transformer.cache_context("uncond"):
1063+
u_t_i = self.transformer(
1064+
hidden_states=Y_t,
1065+
timestep=timestep,
1066+
guidance=guidance,
1067+
pooled_projections=pooled_prompt_embeds,
1068+
encoder_hidden_states=prompt_embeds,
1069+
txt_ids=text_ids,
1070+
img_ids=latent_image_ids,
1071+
joint_attention_kwargs=self.joint_attention_kwargs,
1072+
return_dict=False,
1073+
)[0]
10721074

10731075
# get the conditional vector field
10741076
u_t_i_cond = (y_1 - Y_t) / (1 - t_i)

examples/community/pipeline_flux_semantic_guidance.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,23 +1135,25 @@ def __call__(
11351135
else:
11361136
guidance = None
11371137

1138-
noise_pred = self.transformer(
1139-
hidden_states=latents,
1140-
timestep=timestep / 1000,
1141-
guidance=guidance,
1142-
pooled_projections=pooled_prompt_embeds,
1143-
encoder_hidden_states=prompt_embeds,
1144-
txt_ids=text_ids,
1145-
img_ids=latent_image_ids,
1146-
joint_attention_kwargs=self.joint_attention_kwargs,
1147-
return_dict=False,
1148-
)[0]
1138+
with self.transformer.cache_context("cond"):
1139+
noise_pred = self.transformer(
1140+
hidden_states=latents,
1141+
timestep=timestep / 1000,
1142+
guidance=guidance,
1143+
pooled_projections=pooled_prompt_embeds,
1144+
encoder_hidden_states=prompt_embeds,
1145+
txt_ids=text_ids,
1146+
img_ids=latent_image_ids,
1147+
joint_attention_kwargs=self.joint_attention_kwargs,
1148+
return_dict=False,
1149+
)[0]
11491150

11501151
if enable_edit_guidance and max_edit_cooldown_steps >= i >= min_edit_warmup_steps:
11511152
noise_pred_edit_concepts = []
11521153
for e_embed, pooled_e_embed, e_text_id in zip(
11531154
editing_prompts_embeds, pooled_editing_prompt_embeds, edit_text_ids
11541155
):
1156+
# TODO-context
11551157
noise_pred_edit = self.transformer(
11561158
hidden_states=latents,
11571159
timestep=timestep / 1000,
@@ -1168,17 +1170,18 @@ def __call__(
11681170
if do_true_cfg:
11691171
if negative_image_embeds is not None:
11701172
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
1171-
noise_pred_uncond = self.transformer(
1172-
hidden_states=latents,
1173-
timestep=timestep / 1000,
1174-
guidance=guidance,
1175-
pooled_projections=negative_pooled_prompt_embeds,
1176-
encoder_hidden_states=negative_prompt_embeds,
1177-
txt_ids=text_ids,
1178-
img_ids=latent_image_ids,
1179-
joint_attention_kwargs=self.joint_attention_kwargs,
1180-
return_dict=False,
1181-
)[0]
1173+
with self.transformer.cache_context("uncond"):
1174+
noise_pred_uncond = self.transformer(
1175+
hidden_states=latents,
1176+
timestep=timestep / 1000,
1177+
guidance=guidance,
1178+
pooled_projections=negative_pooled_prompt_embeds,
1179+
encoder_hidden_states=negative_prompt_embeds,
1180+
txt_ids=text_ids,
1181+
img_ids=latent_image_ids,
1182+
joint_attention_kwargs=self.joint_attention_kwargs,
1183+
return_dict=False,
1184+
)[0]
11821185
noise_guidance = true_cfg_scale * (noise_pred - noise_pred_uncond)
11831186
else:
11841187
noise_pred_uncond = noise_pred

examples/community/pipeline_flux_with_cfg.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -815,17 +815,18 @@ def __call__(
815815
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
816816
timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
817817

818-
noise_pred = self.transformer(
819-
hidden_states=latent_model_input,
820-
timestep=timestep / 1000,
821-
guidance=guidance,
822-
pooled_projections=pooled_prompt_embeds,
823-
encoder_hidden_states=prompt_embeds,
824-
txt_ids=text_ids,
825-
img_ids=latent_image_ids,
826-
joint_attention_kwargs=self.joint_attention_kwargs,
827-
return_dict=False,
828-
)[0]
818+
with self.transformer.cache_context("cond"):
819+
noise_pred = self.transformer(
820+
hidden_states=latent_model_input,
821+
timestep=timestep / 1000,
822+
guidance=guidance,
823+
pooled_projections=pooled_prompt_embeds,
824+
encoder_hidden_states=prompt_embeds,
825+
txt_ids=text_ids,
826+
img_ids=latent_image_ids,
827+
joint_attention_kwargs=self.joint_attention_kwargs,
828+
return_dict=False,
829+
)[0]
829830

830831
if do_true_cfg:
831832
neg_noise_pred, noise_pred = noise_pred.chunk(2)

examples/community/pipeline_hunyuandit_differential_img2img.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,18 +1074,19 @@ def __call__(
10741074
)
10751075

10761076
# predict the noise residual
1077-
noise_pred = self.transformer(
1078-
latent_model_input,
1079-
t_expand,
1080-
encoder_hidden_states=prompt_embeds,
1081-
text_embedding_mask=prompt_attention_mask,
1082-
encoder_hidden_states_t5=prompt_embeds_2,
1083-
text_embedding_mask_t5=prompt_attention_mask_2,
1084-
image_meta_size=add_time_ids,
1085-
style=style,
1086-
image_rotary_emb=image_rotary_emb,
1087-
return_dict=False,
1088-
)[0]
1077+
with self.transformer.cache_context("cond"):
1078+
noise_pred = self.transformer(
1079+
latent_model_input,
1080+
t_expand,
1081+
encoder_hidden_states=prompt_embeds,
1082+
text_embedding_mask=prompt_attention_mask,
1083+
encoder_hidden_states_t5=prompt_embeds_2,
1084+
text_embedding_mask_t5=prompt_attention_mask_2,
1085+
image_meta_size=add_time_ids,
1086+
style=style,
1087+
image_rotary_emb=image_rotary_emb,
1088+
return_dict=False,
1089+
)[0]
10891090

10901091
noise_pred, _ = noise_pred.chunk(2, dim=1)
10911092

examples/community/pipeline_stable_diffusion_3_differential_img2img.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -918,13 +918,14 @@ def __call__(
918918
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
919919
timestep = t.expand(latent_model_input.shape[0])
920920

921-
noise_pred = self.transformer(
922-
hidden_states=latent_model_input,
923-
timestep=timestep,
924-
encoder_hidden_states=prompt_embeds,
925-
pooled_projections=pooled_prompt_embeds,
926-
return_dict=False,
927-
)[0]
921+
with self.transformer.cache_context("cond"):
922+
noise_pred = self.transformer(
923+
hidden_states=latent_model_input,
924+
timestep=timestep,
925+
encoder_hidden_states=prompt_embeds,
926+
pooled_projections=pooled_prompt_embeds,
927+
return_dict=False,
928+
)[0]
928929

929930
# perform guidance
930931
if self.do_classifier_free_guidance:

examples/community/pipeline_stable_diffusion_3_instruct_pix2pix.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,14 +1178,15 @@ def __call__(
11781178
timestep = t.expand(latent_model_input.shape[0])
11791179
scaled_latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
11801180

1181-
noise_pred = self.transformer(
1182-
hidden_states=scaled_latent_model_input,
1183-
timestep=timestep,
1184-
encoder_hidden_states=prompt_embeds,
1185-
pooled_projections=pooled_prompt_embeds,
1186-
joint_attention_kwargs=self.joint_attention_kwargs,
1187-
return_dict=False,
1188-
)[0]
1181+
with self.transformer.cache_context("cond"):
1182+
noise_pred = self.transformer(
1183+
hidden_states=scaled_latent_model_input,
1184+
timestep=timestep,
1185+
encoder_hidden_states=prompt_embeds,
1186+
pooled_projections=pooled_prompt_embeds,
1187+
joint_attention_kwargs=self.joint_attention_kwargs,
1188+
return_dict=False,
1189+
)[0]
11891190

11901191
# perform guidance
11911192
if self.do_classifier_free_guidance:
@@ -1204,6 +1205,7 @@ def __call__(
12041205
if skip_guidance_layers is not None and should_skip_layers:
12051206
timestep = t.expand(latents.shape[0])
12061207
latent_model_input = latents
1208+
# TODO-context
12071209
noise_pred_skip_layers = self.transformer(
12081210
hidden_states=latent_model_input,
12091211
timestep=timestep,

examples/community/pipeline_stg_cogvideox.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -793,14 +793,15 @@ def __call__(
793793
timestep = t.expand(latent_model_input.shape[0])
794794

795795
# predict noise model_output
796-
noise_pred = self.transformer(
797-
hidden_states=latent_model_input,
798-
encoder_hidden_states=prompt_embeds,
799-
timestep=timestep,
800-
image_rotary_emb=image_rotary_emb,
801-
attention_kwargs=attention_kwargs,
802-
return_dict=False,
803-
)[0]
796+
with self.transformer.cache_context("cond"):
797+
noise_pred = self.transformer(
798+
hidden_states=latent_model_input,
799+
encoder_hidden_states=prompt_embeds,
800+
timestep=timestep,
801+
image_rotary_emb=image_rotary_emb,
802+
attention_kwargs=attention_kwargs,
803+
return_dict=False,
804+
)[0]
804805
noise_pred = noise_pred.float()
805806

806807
# perform guidance

0 commit comments

Comments
 (0)