Skip to content

Commit 9844613

Browse files
committed
Remove negative_* from SDXL callback
1 parent c002724 commit 9844613

File tree

3 files changed

+0
-24
lines changed

3 files changed

+0
-24
lines changed

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,8 @@ class StableDiffusionXLPipeline(
237237
_callback_tensor_inputs = [
238238
"latents",
239239
"prompt_embeds",
240-
"negative_prompt_embeds",
241240
"add_text_embeds",
242241
"add_time_ids",
243-
"negative_pooled_prompt_embeds",
244-
"negative_add_time_ids",
245242
]
246243

247244
def __init__(
@@ -1243,13 +1240,8 @@ def __call__(
12431240

12441241
latents = callback_outputs.pop("latents", latents)
12451242
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1246-
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
12471243
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1248-
negative_pooled_prompt_embeds = callback_outputs.pop(
1249-
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1250-
)
12511244
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1252-
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
12531245

12541246
# call the callback, if provided
12551247
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,8 @@ class StableDiffusionXLImg2ImgPipeline(
257257
_callback_tensor_inputs = [
258258
"latents",
259259
"prompt_embeds",
260-
"negative_prompt_embeds",
261260
"add_text_embeds",
262261
"add_time_ids",
263-
"negative_pooled_prompt_embeds",
264-
"add_neg_time_ids",
265262
]
266263

267264
def __init__(
@@ -1438,13 +1435,8 @@ def denoising_value_valid(dnv):
14381435

14391436
latents = callback_outputs.pop("latents", latents)
14401437
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1441-
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
14421438
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1443-
negative_pooled_prompt_embeds = callback_outputs.pop(
1444-
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1445-
)
14461439
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1447-
add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
14481440

14491441
# call the callback, if provided
14501442
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -285,11 +285,8 @@ class StableDiffusionXLInpaintPipeline(
285285
_callback_tensor_inputs = [
286286
"latents",
287287
"prompt_embeds",
288-
"negative_prompt_embeds",
289288
"add_text_embeds",
290289
"add_time_ids",
291-
"negative_pooled_prompt_embeds",
292-
"add_neg_time_ids",
293290
"mask",
294291
"masked_image_latents",
295292
]
@@ -1671,13 +1668,8 @@ def denoising_value_valid(dnv):
16711668

16721669
latents = callback_outputs.pop("latents", latents)
16731670
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1674-
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
16751671
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1676-
negative_pooled_prompt_embeds = callback_outputs.pop(
1677-
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1678-
)
16791672
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1680-
add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
16811673
mask = callback_outputs.pop("mask", mask)
16821674
masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
16831675

0 commit comments

Comments
 (0)