@@ -371,32 +371,54 @@ def encode_prompt(
         batch_size = prompt_embeds.shape[0]

         # Define tokenizers and text encoders
-        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
-        text_encoders = (
-            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
-        )
+        tokenizers = []
+        if self.tokenizer is not None:
+            tokenizers.append(self.tokenizer)
+        if self.tokenizer_2 is not None:
+            tokenizers.append(self.tokenizer_2)
+        if not tokenizers:
+            raise ValueError(
+                "Cannot encode prompt since no tokenizer is defined. Make sure that either `tokenizer` or `tokenizer_2` is defined."
+            )
+
+        text_encoders = []
+        if self.text_encoder is not None:
+            text_encoders.append(self.text_encoder)
+        if self.text_encoder_2 is not None:
+            text_encoders.append(self.text_encoder_2)
+        if not text_encoders:
+            raise ValueError(
+                "Cannot encode prompt since no text encoder is defined. Make sure that either `text_encoder` or `text_encoder_2` is defined."
+            )

         if prompt_embeds is None:
-            prompt_2 = prompt_2 or prompt
+            prompt_2 = prompt_2 or prompt  # Ensure prompt_2 is set if only one prompt is provided initially
             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

             # textual inversion: process multi-vector tokens if necessary
+            prompts_list = []
+            # Determine which prompts to use based on available tokenizers/encoders
+            if self.tokenizer is not None and self.text_encoder is not None:
+                prompts_list.append(prompt)
+            if self.tokenizer_2 is not None and self.text_encoder_2 is not None:
+                # If only tokenizer_2 is available, it should use the first prompt
+                prompts_list.append(prompt_2 if (self.tokenizer is not None and self.text_encoder is not None) else prompt)
+
+
             prompt_embeds_list = []
-            prompts = [prompt, prompt_2]
-            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+            for current_prompt, tokenizer, text_encoder in zip(prompts_list, tokenizers, text_encoders):
                 if isinstance(self, TextualInversionLoaderMixin):
-                    prompt = self.maybe_convert_prompt(prompt, tokenizer)
+                    current_prompt = self.maybe_convert_prompt(current_prompt, tokenizer)

                 text_inputs = tokenizer(
-                    prompt,
+                    current_prompt,
                     padding="max_length",
                     max_length=tokenizer.model_max_length,
                     truncation=True,
                     return_tensors="pt",
                 )

                 text_input_ids = text_inputs.input_ids
-                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+                untruncated_ids = tokenizer(current_prompt, padding="longest", return_tensors="pt").input_ids

                 if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                     text_input_ids, untruncated_ids
@@ -407,84 +429,100 @@
f" {tokenizer.model_max_length} tokens: {removed_text}"
)

prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
current_encoder_output = text_encoder(text_input_ids.to(device), output_hidden_states=True)

# Pooled output taken from the last text encoder
if text_encoder == text_encoders[-1]: # Check if current encoder is the last one
# Ensure current_encoder_output[0] is the pooled output, typically ndim == 2
if pooled_prompt_embeds is None and hasattr(current_encoder_output, "pooler_output") and current_encoder_output.pooler_output is not None:
pooled_prompt_embeds = current_encoder_output.pooler_output
elif pooled_prompt_embeds is None and current_encoder_output[0].ndim == 2: # Fallback for models not returning explicit pooler_output
pooled_prompt_embeds = current_encoder_output[0]

# We are only ALWAYS interested in the pooled output of the final text encoder
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]

current_hidden_states: torch.Tensor
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
current_hidden_states = current_encoder_output.hidden_states[-2]
else:
# "2" because SDXL always indexes from the penultimate layer.
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
current_hidden_states = current_encoder_output.hidden_states[-(clip_skip + 2)]

prompt_embeds_list.append(prompt_embeds)
prompt_embeds_list.append(current_hidden_states)

prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)

# get unconditional embeddings for classifier free guidance
zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
if pooled_prompt_embeds is not None:
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
# If pooled_prompt_embeds is None, negative_pooled_prompt_embeds will be handled later or error
elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
negative_prompt_2 = negative_prompt_2 or negative_prompt # Ensure negative_prompt_2 is set

# normalize str to list
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_2 = (
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
)
negative_prompt_2 = batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2

uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = [negative_prompt, negative_prompt_2]

uncond_tokens_list = []
if self.tokenizer is not None and self.text_encoder is not None:
uncond_tokens_list.append(negative_prompt)
if self.tokenizer_2 is not None and self.text_encoder_2 is not None:
uncond_tokens_list.append(negative_prompt_2 if (self.tokenizer is not None and self.text_encoder is not None) else negative_prompt)


if prompt is not None: # Check if prompt is not None before type comparison
for neg_prompt_tokens_segment in uncond_tokens_list:
if batch_size != len(neg_prompt_tokens_segment):
raise ValueError(
f"`negative_prompt` segment has batch size {len(neg_prompt_tokens_segment)}, but `prompt`:"
f" has batch size {batch_size}. Please make sure that passed `negative_prompt` segments match"
" the batch size of `prompt`."
)

negative_prompt_embeds_list = []
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
for current_negative_prompt, tokenizer, text_encoder in zip(uncond_tokens_list, tokenizers, text_encoders):
if isinstance(self, TextualInversionLoaderMixin):
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
current_negative_prompt = self.maybe_convert_prompt(current_negative_prompt, tokenizer)

max_length = prompt_embeds.shape[1]
max_length = prompt_embeds.shape[1] # Use shape from already processed prompt_embeds
uncond_input = tokenizer(
negative_prompt,
current_negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)

negative_prompt_embeds = text_encoder(
current_neg_encoder_output = text_encoder(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)

# We are only ALWAYS interested in the pooled output of the final text encoder
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
# Pooled output taken from the last text encoder
if text_encoder == text_encoders[-1]: # Check if current encoder is the last one
if negative_pooled_prompt_embeds is None and hasattr(current_neg_encoder_output, "pooler_output") and current_neg_encoder_output.pooler_output is not None:
negative_pooled_prompt_embeds = current_neg_encoder_output.pooler_output
elif negative_pooled_prompt_embeds is None and current_neg_encoder_output[0].ndim == 2: # Fallback
negative_pooled_prompt_embeds = current_neg_encoder_output[0]

current_negative_hidden_states = current_neg_encoder_output.hidden_states[-2]
negative_prompt_embeds_list.append(current_negative_hidden_states)

negative_prompt_embeds_list.append(negative_prompt_embeds)

negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
negative_prompt_embeds = torch.cat(negative_prompt_embeds_list, dim=-1)

# Determine dtype for prompt_embeds
if self.text_encoder_2 is not None:
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
target_dtype = self.text_encoder_2.dtype
elif self.text_encoder is not None:
target_dtype = self.text_encoder.dtype
else:
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
# This case should ideally be prevented by the checks at the start of the function
target_dtype = self.unet.dtype

prompt_embeds = prompt_embeds.to(dtype=target_dtype, device=device)

bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -495,21 +533,61 @@ def encode_prompt(
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]

-            if self.text_encoder_2 is not None:
-                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
-            else:
-                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+            # Determine dtype for negative_prompt_embeds (should be same as prompt_embeds)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=target_dtype, device=device)

             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
-            bs_embed * num_images_per_prompt, -1
-        )
-        if do_classifier_free_guidance:
-            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+        # Check if pooled_prompt_embeds were generated, especially if prompts were provided.
+        if prompt_embeds is not None and pooled_prompt_embeds is None:  # prompt_embeds is the final concatenated embeddings
+            raise ValueError(
+                "Pooled prompt embeddings were not generated. Make sure the model has a pooling layer or outputs pooler_output."
+            )

+        if pooled_prompt_embeds is not None:
+            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+                bs_embed * num_images_per_prompt, -1
+            )

+        if do_classifier_free_guidance:
+            # Similar check for negative_pooled_prompt_embeds
+            if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None and not zero_out_negative_prompt:  # negative_prompt_embeds is the final concatenated embeddings
+                raise ValueError(
+                    "Negative pooled prompt embeddings were not generated but were expected. Make sure the model has a pooling layer or outputs pooler_output."
+                )

+            if negative_pooled_prompt_embeds is not None:
+                negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+                    bs_embed * num_images_per_prompt, -1
+                )
+            elif zero_out_negative_prompt and pooled_prompt_embeds is not None:  # If it was meant to be zeros and pooled_prompt_embeds exists
+                negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+            elif zero_out_negative_prompt and pooled_prompt_embeds is None:  # If zero_out and main pooled is also None (should not happen if prompt is given)
+                # This case implies an issue earlier, as pooled_prompt_embeds should exist if prompt_embeds does.
+                # However, as a safeguard if negative_prompt was empty and we need to create zero pooled embeds:
+                if prompt_embeds is not None:  # We need a shape reference
+                    # Attempt to get a reference shape for text_encoder_2's projection if available
+                    ref_shape_dim = self.text_encoder_2.config.projection_dim if self.text_encoder_2 else \
+                        (self.text_encoder.config.projection_dim if self.text_encoder and hasattr(self.text_encoder.config, "projection_dim") else None)
+                    if ref_shape_dim is None and hasattr(self.text_encoder, "text_projection") and self.text_encoder.text_projection is not None:  # For OpenCLIP-H/14
+                        ref_shape_dim = self.text_encoder.text_projection.shape[-1]
+                    if ref_shape_dim is None and hasattr(self.text_encoder, "projection_dim") and self.text_encoder.projection_dim is not None:  # For CLIP-G/14
+                        ref_shape_dim = self.text_encoder.projection_dim

+                    if ref_shape_dim is not None:
+                        negative_pooled_prompt_embeds = torch.zeros(bs_embed, ref_shape_dim, device=device, dtype=target_dtype)
+                        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1)
+                    else:  # Fallback if no specific projection_dim found, less ideal
+                        logger.warning("Cannot determine the projection dimension for zeroed negative_pooled_prompt_embeds. This might lead to errors.")
+                        # Create with a default size, though this might be incorrect.
+                        # This path indicates a configuration problem or an unexpected model structure.
+                        # Consider last hidden state size if projection_dim is not available
+                        last_encoder = text_encoders[-1]
+                        fallback_dim = last_encoder.config.hidden_size
+                        negative_pooled_prompt_embeds = torch.zeros(bs_embed, fallback_dim, device=device, dtype=target_dtype)
+                        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1)


         if self.text_encoder is not None:
             if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
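
Usage sketch for the refactor above (illustrative only: it assumes the refactored encode_prompt keeps the current diffusers signature, and it uses the SDXL refiner checkpoint, which loads with tokenizer and text_encoder set to None, exactly the case the conditional list-building guards against):

import torch
from diffusers import StableDiffusionXLImg2ImgPipeline

pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
).to("cuda")

# The refiner has no first text encoder, so encode_prompt must fall back to
# tokenizer_2/text_encoder_2 and still return pooled embeddings taken from
# the last (here: only) text encoder.
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
    prompt="a photo of an astronaut",
    device=pipe.device,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)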
@@ -1131,10 +1209,24 @@ def __call__(

         # 7. Prepare added time ids & embeddings
         add_text_embeds = pooled_prompt_embeds
-        if self.text_encoder_2 is None:
+
+        if self.text_encoder_2 is not None and hasattr(self.text_encoder_2.config, "projection_dim"):
+            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+        elif self.text_encoder is not None and hasattr(self.text_encoder.config, "projection_dim"):
+            text_encoder_projection_dim = self.text_encoder.config.projection_dim
+        elif pooled_prompt_embeds is not None:  # Fallback to the shape of pooled_prompt_embeds
             text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
         else:
-            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+            # This should not happen if encode_prompt ran correctly and returned pooled_prompt_embeds
+            # Setting to a common default or raising an error might be options.
+            # For now, let's assume pooled_prompt_embeds will be available if we reach here.
+            # If not, _get_add_time_ids will likely fail due to pooled_prompt_embeds being None for add_text_embeds
+            # or due to text_encoder_projection_dim not being set.
+            # Consider raising an error if pooled_prompt_embeds is None.
+            raise ValueError(
+                "pooled_prompt_embeds is None, and text_encoder_projection_dim cannot be determined."
+                " Ensure that encode_prompt returns valid pooled_prompt_embeds."
+            )

         add_time_ids = self._get_add_time_ids(
             original_size,
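
The projection-dim fallback above sets the width used for the add_text_embeds portion of SDXL's added conditioning. Restated as a standalone helper for clarity (a sketch; the helper name is hypothetical, the attribute access mirrors the diff, and 1280 is the projection width of SDXL's OpenCLIP ViT-bigG encoder):

def resolve_text_encoder_projection_dim(text_encoder, text_encoder_2, pooled_prompt_embeds):
    # Prefer the second text encoder's configured projection dim (1280 for SDXL).
    if text_encoder_2 is not None and hasattr(text_encoder_2.config, "projection_dim"):
        return text_encoder_2.config.projection_dim
    # Otherwise try the first encoder's configured projection dim.
    if text_encoder is not None and hasattr(text_encoder.config, "projection_dim"):
        return text_encoder.config.projection_dim
    # Last resort: infer the width from the pooled embeddings themselves.
    if pooled_prompt_embeds is not None:
        return int(pooled_prompt_embeds.shape[-1])
    raise ValueError("Cannot determine text_encoder_projection_dim without pooled_prompt_embeds.")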