Skip to content

Commit ab3a895

Browse files
committed
up
1 parent 905c215 commit ab3a895

File tree

1 file changed

+26
-19
lines changed

1 file changed

+26
-19
lines changed

src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -351,10 +351,10 @@ def encode_prompt(
351351
if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
352352
negative_prompt = negative_prompt or ""
353353
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
354-
354+
355355
if len(negative_prompt) > 1 and len(negative_prompt) != batch_size:
356356
raise ValueError(f"negative_prompt must be of length 1 or {batch_size}")
357-
357+
358358
negative_pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
359359
self.tokenizer, self.text_encoder, negative_prompt, max_sequence_length, device, dtype
360360
)
@@ -406,7 +406,7 @@ def encode_prompt(
406406
raise ValueError(f"prompt_3 must be of length 1 or {batch_size}")
407407

408408
t5_prompt_embeds = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype)
409-
409+
410410
if t5_prompt_embeds.shape[0] == 1 and batch_size > 1:
411411
t5_prompt_embeds = t5_prompt_embeds.repeat(batch_size, 1, 1)
412412

@@ -420,7 +420,7 @@ def encode_prompt(
420420
negative_t5_prompt_embeds = self._get_t5_prompt_embeds(
421421
negative_prompt_3, max_sequence_length, device, dtype
422422
)
423-
423+
424424
if negative_t5_prompt_embeds.shape[0] == 1 and batch_size > 1:
425425
negative_t5_prompt_embeds = negative_t5_prompt_embeds.repeat(batch_size, 1, 1)
426426

@@ -432,7 +432,7 @@ def encode_prompt(
432432
raise ValueError(f"prompt_4 must be of length 1 or {batch_size}")
433433

434434
llama3_prompt_embeds = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype)
435-
435+
436436
if llama3_prompt_embeds.shape[0] == 1 and batch_size > 1:
437437
llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, batch_size, 1, 1)
438438

@@ -446,10 +446,10 @@ def encode_prompt(
446446
negative_llama3_prompt_embeds = self._get_llama3_prompt_embeds(
447447
negative_prompt_4, max_sequence_length, device, dtype
448448
)
449-
449+
450450
if negative_llama3_prompt_embeds.shape[0] == 1 and batch_size > 1:
451451
negative_llama3_prompt_embeds = negative_llama3_prompt_embeds.repeat(1, batch_size, 1, 1)
452-
452+
453453
# duplicate pooled_prompt_embeds for each generation per prompt
454454
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
455455
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
@@ -472,7 +472,7 @@ def encode_prompt(
472472
llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, 1, num_images_per_prompt, 1)
473473
llama3_prompt_embeds = llama3_prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
474474

475-
if do_classifier_free_guidance:
475+
if do_classifier_free_guidance:
476476
# duplicate negative_pooled_prompt_embeds for batch_size and num_images_per_prompt
477477
bs_embed, seq_len = negative_pooled_prompt_embeds.shape
478478
if bs_embed == 1 and batch_size > 1:
@@ -502,7 +502,14 @@ def encode_prompt(
502502
-1, batch_size * num_images_per_prompt, seq_len, dim
503503
)
504504

505-
return t5_prompt_embeds, llama3_prompt_embeds, negative_t5_prompt_embeds, negative_llama3_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
505+
return (
506+
t5_prompt_embeds,
507+
llama3_prompt_embeds,
508+
negative_t5_prompt_embeds,
509+
negative_llama3_prompt_embeds,
510+
pooled_prompt_embeds,
511+
negative_pooled_prompt_embeds,
512+
)
506513

507514
def enable_vae_slicing(self):
508515
r"""
@@ -583,9 +590,13 @@ def check_inputs(
583590
"Provide either `prompt` or `pooled_prompt_embeds`. Cannot leave both `prompt` and `pooled_prompt_embeds` undefined."
584591
)
585592
elif prompt is None and t5_prompt_embeds is None:
586-
raise ValueError("Provide either `prompt` or `t5_prompt_embeds`. Cannot leave both `prompt` and `t5_prompt_embeds` undefined.")
593+
raise ValueError(
594+
"Provide either `prompt` or `t5_prompt_embeds`. Cannot leave both `prompt` and `t5_prompt_embeds` undefined."
595+
)
587596
elif prompt is None and llama3_prompt_embeds is None:
588-
raise ValueError("Provide either `prompt` or `llama3_prompt_embeds`. Cannot leave both `prompt` and `llama3_prompt_embeds` undefined.")
597+
raise ValueError(
598+
"Provide either `prompt` or `llama3_prompt_embeds`. Cannot leave both `prompt` and `llama3_prompt_embeds` undefined."
599+
)
589600
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
590601
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
591602
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
@@ -602,8 +613,8 @@ def check_inputs(
602613
)
603614
elif negative_prompt_2 is not None and negative_pooled_prompt_embeds is not None:
604615
raise ValueError(
605-
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
606-
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
616+
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_pooled_prompt_embeds`:"
617+
f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two."
607618
)
608619
elif negative_prompt_3 is not None and negative_t5_prompt_embeds is not None:
609620
raise ValueError(
@@ -638,8 +649,6 @@ def check_inputs(
638649
f" {negative_llama3_prompt_embeds.shape}."
639650
)
640651

641-
642-
643652
def prepare_latents(
644653
self,
645654
batch_size,
@@ -755,10 +764,8 @@ def __call__(
755764
batch_size = 1
756765
elif prompt is not None and isinstance(prompt, list):
757766
batch_size = len(prompt)
758-
elif prompt_embeds is not None:
759-
batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
760-
else:
761-
batch_size = 1
767+
elif pooled_prompt_embeds is not None:
768+
batch_size = pooled_prompt_embeds.shape[0]
762769

763770
device = self._execution_device
764771

0 commit comments

Comments (0)