Skip to content

Commit 8eddfc5

Browse files
committed
fix implementation
1 parent 559660b commit 8eddfc5

19 files changed

+98
-38
lines changed

src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,8 @@ def encode_prompt(
438438
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
439439

440440
# We are only ALWAYS interested in the pooled output of the final text encoder
441-
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None else pooled_prompt_embeds
441+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
442+
pooled_prompt_embeds = prompt_embeds[0]
442443

443444
if clip_skip is None:
444445
prompt_embeds = prompt_embeds.hidden_states[-2]
@@ -498,8 +499,10 @@ def encode_prompt(
498499
uncond_input.input_ids.to(device),
499500
output_hidden_states=True,
500501
)
502+
501503
# We are only ALWAYS interested in the pooled output of the final text encoder
502-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
504+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
505+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
503506
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
504507

505508
negative_prompt_embeds_list.append(negative_prompt_embeds)

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,8 @@ def encode_prompt(
406406
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
407407

408408
# We are only ALWAYS interested in the pooled output of the final text encoder
409-
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None else pooled_prompt_embeds
409+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
410+
pooled_prompt_embeds = prompt_embeds[0]
410411

411412
if clip_skip is None:
412413
prompt_embeds = prompt_embeds.hidden_states[-2]
@@ -466,8 +467,10 @@ def encode_prompt(
466467
uncond_input.input_ids.to(device),
467468
output_hidden_states=True,
468469
)
470+
469471
# We are only ALWAYS interested in the pooled output of the final text encoder
470-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
472+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
473+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
471474
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
472475

473476
negative_prompt_embeds_list.append(negative_prompt_embeds)

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,8 @@ def encode_prompt(
415415
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
416416

417417
# We are only ALWAYS interested in the pooled output of the final text encoder
418-
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None else pooled_prompt_embeds
418+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
419+
pooled_prompt_embeds = prompt_embeds[0]
419420

420421
if clip_skip is None:
421422
prompt_embeds = prompt_embeds.hidden_states[-2]
@@ -475,8 +476,10 @@ def encode_prompt(
475476
uncond_input.input_ids.to(device),
476477
output_hidden_states=True,
477478
)
479+
478480
# We are only ALWAYS interested in the pooled output of the final text encoder
479-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
481+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
482+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
480483
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
481484

482485
negative_prompt_embeds_list.append(negative_prompt_embeds)

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,8 @@ def encode_prompt(
408408
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
409409

410410
# We are only ALWAYS interested in the pooled output of the final text encoder
411-
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None else pooled_prompt_embeds
411+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
412+
pooled_prompt_embeds = prompt_embeds[0]
412413

413414
if clip_skip is None:
414415
prompt_embeds = prompt_embeds.hidden_states[-2]
@@ -468,8 +469,10 @@ def encode_prompt(
468469
uncond_input.input_ids.to(device),
469470
output_hidden_states=True,
470471
)
472+
471473
# We are only ALWAYS interested in the pooled output of the final text encoder
472-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
474+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
475+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
473476
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
474477

475478
negative_prompt_embeds_list.append(negative_prompt_embeds)

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,9 @@ def encode_prompt(
388388
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
389389

390390
# We are only ALWAYS interested in the pooled output of the final text encoder
391-
pooled_prompt_embeds = prompt_embeds[0]
391+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
392+
pooled_prompt_embeds = prompt_embeds[0]
393+
392394
if clip_skip is None:
393395
prompt_embeds = prompt_embeds.hidden_states[-2]
394396
else:
@@ -447,8 +449,10 @@ def encode_prompt(
447449
uncond_input.input_ids.to(device),
448450
output_hidden_states=True,
449451
)
452+
450453
# We are only ALWAYS interested in the pooled output of the final text encoder
451-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
454+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
455+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
452456
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
453457

454458
negative_prompt_embeds_list.append(negative_prompt_embeds)

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,9 @@ def encode_prompt(
397397
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
398398

399399
# We are only ALWAYS interested in the pooled output of the final text encoder
400-
pooled_prompt_embeds = prompt_embeds[0]
400+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
401+
pooled_prompt_embeds = prompt_embeds[0]
402+
401403
if clip_skip is None:
402404
prompt_embeds = prompt_embeds.hidden_states[-2]
403405
else:
@@ -456,8 +458,10 @@ def encode_prompt(
456458
uncond_input.input_ids.to(device),
457459
output_hidden_states=True,
458460
)
461+
459462
# We are only ALWAYS interested in the pooled output of the final text encoder
460-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
463+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
464+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
461465
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
462466

463467
negative_prompt_embeds_list.append(negative_prompt_embeds)

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,9 @@ def encode_prompt(
422422
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
423423

424424
# We are only ALWAYS interested in the pooled output of the final text encoder
425-
pooled_prompt_embeds = prompt_embeds[0]
425+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
426+
pooled_prompt_embeds = prompt_embeds[0]
427+
426428
if clip_skip is None:
427429
prompt_embeds = prompt_embeds.hidden_states[-2]
428430
else:
@@ -481,8 +483,10 @@ def encode_prompt(
481483
uncond_input.input_ids.to(device),
482484
output_hidden_states=True,
483485
)
486+
484487
# We are only ALWAYS interested in the pooled output of the final text encoder
485-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
488+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
489+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
486490
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
487491

488492
negative_prompt_embeds_list.append(negative_prompt_embeds)

src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,8 @@ def encode_prompt(
336336
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
337337

338338
# We are only ALWAYS interested in the pooled output of the final text encoder
339-
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None else pooled_prompt_embeds
339+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
340+
pooled_prompt_embeds = prompt_embeds[0]
340341

341342
if clip_skip is None:
342343
prompt_embeds = prompt_embeds.hidden_states[-2]
@@ -396,8 +397,10 @@ def encode_prompt(
396397
uncond_input.input_ids.to(device),
397398
output_hidden_states=True,
398399
)
400+
399401
# We are only ALWAYS interested in the pooled output of the final text encoder
400-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
402+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
403+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
401404
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
402405

403406
negative_prompt_embeds_list.append(negative_prompt_embeds)

src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,8 @@ def encode_prompt(
421421
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
422422

423423
# We are only ALWAYS interested in the pooled output of the final text encoder
424-
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None else pooled_prompt_embeds
424+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
425+
pooled_prompt_embeds = prompt_embeds[0]
425426

426427
if clip_skip is None:
427428
prompt_embeds = prompt_embeds.hidden_states[-2]
@@ -481,8 +482,10 @@ def encode_prompt(
481482
uncond_input.input_ids.to(device),
482483
output_hidden_states=True,
483484
)
485+
484486
# We are only ALWAYS interested in the pooled output of the final text encoder
485-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
487+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
488+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
486489
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
487490

488491
negative_prompt_embeds_list.append(negative_prompt_embeds)

src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,8 @@ def encode_prompt(
413413
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
414414

415415
# We are only ALWAYS interested in the pooled output of the final text encoder
416-
pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None else pooled_prompt_embeds
416+
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
417+
pooled_prompt_embeds = prompt_embeds[0]
417418

418419
if clip_skip is None:
419420
prompt_embeds = prompt_embeds.hidden_states[-2]
@@ -473,8 +474,10 @@ def encode_prompt(
473474
uncond_input.input_ids.to(device),
474475
output_hidden_states=True,
475476
)
477+
476478
# We are only ALWAYS interested in the pooled output of the final text encoder
477-
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
479+
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
480+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
478481
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
479482

480483
negative_prompt_embeds_list.append(negative_prompt_embeds)

0 commit comments

Comments
 (0)