Merged (changes from 1 commit)
164 changes: 75 additions & 89 deletions in src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py

@@ -315,9 +315,9 @@ def _get_llama3_prompt_embeds(
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
-        prompt_3: Union[str, List[str]],
-        prompt_4: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        prompt_3: Optional[Union[str, List[str]]] = None,
+        prompt_4: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         num_images_per_prompt: int = 1,
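The hunk above makes prompt_2, prompt_3, and prompt_4 optional; each one falls back to the primary prompt inside the method (the `prompt_n = prompt_n or prompt` lines in the next hunk). A minimal usage sketch of the new signature follows. The checkpoint id is an assumption for illustration, and arguments not visible in this hunk (negative prompts, guidance flags) are left at their defaults:

    import torch
    from diffusers import HiDreamImagePipeline

    # Hypothetical checkpoint id, used here for illustration only.
    pipe = HiDreamImagePipeline.from_pretrained(
        "HiDream-ai/HiDream-I1-Full", torch_dtype=torch.bfloat16
    )

    # With prompt_2/3/4 now Optional[...] = None, a single prompt drives all
    # four text encoders (two CLIP variants, T5, and Llama-3).
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt="a photo of an astronaut riding a horse",
        num_images_per_prompt=1,
    )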
@@ -339,118 +339,104 @@ def encode_prompt(
         else:
             batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
 
-        prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
-            prompt=prompt,
-            prompt_2=prompt_2,
-            prompt_3=prompt_3,
-            prompt_4=prompt_4,
-            device=device,
-            dtype=dtype,
-            num_images_per_prompt=num_images_per_prompt,
-            prompt_embeds=prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            max_sequence_length=max_sequence_length,
-        )
-
-        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            negative_prompt = negative_prompt or ""
-            negative_prompt_2 = negative_prompt_2 or negative_prompt
-            negative_prompt_3 = negative_prompt_3 or negative_prompt
-            negative_prompt_4 = negative_prompt_4 or negative_prompt
-            device = device or self._execution_device
-
-            # normalize str to list
-            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-            negative_prompt_2 = (
-                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
-            )
-            negative_prompt_3 = (
-                batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
-            )
-            negative_prompt_4 = (
-                batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
+        if pooled_prompt_embeds is None:
+            pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
+                self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype
             )
 
-            if prompt is not None and type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-
-            negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt(
-                prompt=negative_prompt,
-                prompt_2=negative_prompt_2,
-                prompt_3=negative_prompt_3,
-                prompt_4=negative_prompt_4,
-                device=device,
-                dtype=dtype,
-                num_images_per_prompt=num_images_per_prompt,
-                prompt_embeds=negative_prompt_embeds,
-                pooled_prompt_embeds=negative_pooled_prompt_embeds,
-                max_sequence_length=max_sequence_length,
+        if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
+                self.tokenizer, self.text_encoder, negative_prompt, max_sequence_length, device, dtype
             )
-        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
-    def _encode_prompt(
-        self,
-        prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
-        prompt_3: Union[str, List[str]],
-        prompt_4: Union[str, List[str]],
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        num_images_per_prompt: int = 1,
-        prompt_embeds: Optional[List[torch.FloatTensor]] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        max_sequence_length: int = 128,
-    ):
-        device = device or self._execution_device
-        if prompt is not None:
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
 
         if pooled_prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
-
-            pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
-                self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype
-            )
             pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
                 self.tokenizer_2, self.text_encoder_2, prompt_2, max_sequence_length, device, dtype
             )
 
+        if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
+            negative_prompt_2 = negative_prompt_2 or ""
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+            negative_pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
+                self.tokenizer_2, self.text_encoder_2, negative_prompt_2, max_sequence_length, device, dtype
+            )
+
+        if pooled_prompt_embeds is None:
             pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1)
 
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
-        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+        if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
+            negative_pooled_prompt_embeds = torch.cat(
+                [negative_pooled_prompt_embeds_1, negative_pooled_prompt_embeds_2], dim=-1
+            )
+
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
 
         if prompt_embeds is None:
             prompt_3 = prompt_3 or prompt
             prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
+            t5_prompt_embeds = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype)
 
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt_3 = negative_prompt_3 or ""
+            negative_prompt_3 = (
+                batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
+            )
+            negative_t5_prompt_embeds = self._get_t5_prompt_embeds(
+                negative_prompt_3, max_sequence_length, device, dtype
+            )
+
+        if prompt_embeds is None:
             prompt_4 = prompt_4 or prompt
             prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4
-
-            t5_prompt_embeds = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype)
             llama3_prompt_embeds = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype)
 
-            _, seq_len, _ = t5_prompt_embeds.shape
-            t5_prompt_embeds = t5_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            t5_prompt_embeds = t5_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-            _, _, seq_len, dim = llama3_prompt_embeds.shape
-            llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, 1, num_images_per_prompt, 1)
-            llama3_prompt_embeds = llama3_prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt_4 = negative_prompt_4 or ""
+            negative_prompt_4 = (
+                batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
+            )
+            negative_llama3_prompt_embeds = self._get_llama3_prompt_embeds(
+                negative_prompt_4, max_sequence_length, device, dtype
+            )
 
+        if prompt_embeds is None:
             prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]
 
-        return prompt_embeds, pooled_prompt_embeds
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt_embeds = [negative_t5_prompt_embeds, negative_llama3_prompt_embeds]
+
+        _, seq_len, _ = prompt_embeds[0].shape
+        prompt_embeds[0] = prompt_embeds[0].repeat(1, num_images_per_prompt, 1)
+        prompt_embeds[0] = prompt_embeds[0].view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        _, _, seq_len, dim = prompt_embeds[1].shape
+        prompt_embeds[1] = prompt_embeds[1].repeat(1, 1, num_images_per_prompt, 1)
+        prompt_embeds[1] = prompt_embeds[1].view(-1, batch_size * num_images_per_prompt, seq_len, dim)
+
+        if do_classifier_free_guidance:
+            _, seq_len, _ = negative_prompt_embeds[0].shape
+            negative_prompt_embeds[0] = negative_prompt_embeds[0].repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds[0] = negative_prompt_embeds[0].view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            _, _, seq_len, dim = negative_prompt_embeds[1].shape
+            negative_prompt_embeds[1] = negative_prompt_embeds[1].repeat(1, 1, num_images_per_prompt, 1)
+            negative_prompt_embeds[1] = negative_prompt_embeds[1].view(
+                -1, batch_size * num_images_per_prompt, seq_len, dim
+            )
+
+        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

     def enable_vae_slicing(self):
         r"""