@@ -1135,23 +1135,25 @@ def __call__(
11351135 else :
11361136 guidance = None
11371137
1138- noise_pred = self .transformer (
1139- hidden_states = latents ,
1140- timestep = timestep / 1000 ,
1141- guidance = guidance ,
1142- pooled_projections = pooled_prompt_embeds ,
1143- encoder_hidden_states = prompt_embeds ,
1144- txt_ids = text_ids ,
1145- img_ids = latent_image_ids ,
1146- joint_attention_kwargs = self .joint_attention_kwargs ,
1147- return_dict = False ,
1148- )[0 ]
1138+ with self .transformer .cache_context ("cond" ):
1139+ noise_pred = self .transformer (
1140+ hidden_states = latents ,
1141+ timestep = timestep / 1000 ,
1142+ guidance = guidance ,
1143+ pooled_projections = pooled_prompt_embeds ,
1144+ encoder_hidden_states = prompt_embeds ,
1145+ txt_ids = text_ids ,
1146+ img_ids = latent_image_ids ,
1147+ joint_attention_kwargs = self .joint_attention_kwargs ,
1148+ return_dict = False ,
1149+ )[0 ]
11491150
11501151 if enable_edit_guidance and max_edit_cooldown_steps >= i >= min_edit_warmup_steps :
11511152 noise_pred_edit_concepts = []
11521153 for e_embed , pooled_e_embed , e_text_id in zip (
11531154 editing_prompts_embeds , pooled_editing_prompt_embeds , edit_text_ids
11541155 ):
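1156+ # TODO: this per-editing-concept transformer call is not yet wrapped in a `self.transformer.cache_context(...)` block, unlike the "cond" and "uncond" calls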
11551157 noise_pred_edit = self .transformer (
11561158 hidden_states = latents ,
11571159 timestep = timestep / 1000 ,
@@ -1168,17 +1170,18 @@ def __call__(
11681170 if do_true_cfg :
11691171 if negative_image_embeds is not None :
11701172 self ._joint_attention_kwargs ["ip_adapter_image_embeds" ] = negative_image_embeds
1171- noise_pred_uncond = self .transformer (
1172- hidden_states = latents ,
1173- timestep = timestep / 1000 ,
1174- guidance = guidance ,
1175- pooled_projections = negative_pooled_prompt_embeds ,
1176- encoder_hidden_states = negative_prompt_embeds ,
1177- txt_ids = text_ids ,
1178- img_ids = latent_image_ids ,
1179- joint_attention_kwargs = self .joint_attention_kwargs ,
1180- return_dict = False ,
1181- )[0 ]
1173+ with self .transformer .cache_context ("uncond" ):
1174+ noise_pred_uncond = self .transformer (
1175+ hidden_states = latents ,
1176+ timestep = timestep / 1000 ,
1177+ guidance = guidance ,
1178+ pooled_projections = negative_pooled_prompt_embeds ,
1179+ encoder_hidden_states = negative_prompt_embeds ,
1180+ txt_ids = text_ids ,
1181+ img_ids = latent_image_ids ,
1182+ joint_attention_kwargs = self .joint_attention_kwargs ,
1183+ return_dict = False ,
1184+ )[0 ]
11821185 noise_guidance = true_cfg_scale * (noise_pred - noise_pred_uncond )
11831186 else :
11841187 noise_pred_uncond = noise_pred
0 commit comments