Merged
Changes from 16 commits
88 changes: 83 additions & 5 deletions src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -142,6 +142,45 @@ def __init__(
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        )

    def check_inputs(
        self,
        image,
        prompt,
        prompt_2,
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        prompt_embeds_scale=1.0,
        pooled_prompt_embeds_scale=1.0,
    ):
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt_2 is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
        if prompt is not None and (isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image)):
            raise ValueError(
                f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {len(image)} images"
            )
        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )
        if isinstance(prompt_embeds_scale, list) and (
            isinstance(image, list) and len(prompt_embeds_scale) != len(image)
        ):
            raise ValueError(
                f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {len(image)} images"
            )
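For illustration, a hedged sketch of how these checks surface to a caller (the repo id is the public Redux checkpoint; the blank PIL image and embedding shape are placeholders):

```python
import torch
from PIL import Image
from diffusers import FluxPriorReduxPipeline

pipe = FluxPriorReduxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
)
image = Image.new("RGB", (512, 512))  # stand-in reference image

# `prompt` and `prompt_embeds` are mutually exclusive, so check_inputs raises
# before any encoding happens.
try:
    pipe(image=image, prompt="a robot", prompt_embeds=torch.zeros(1, 512, 4096))
except ValueError as err:
    print(err)
```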

    def encode_image(self, image, device, num_images_per_prompt):
        dtype = next(self.image_encoder.parameters()).dtype
        image = self.feature_extractor.preprocess(
@@ -334,6 +373,12 @@ def encode_prompt(
    def __call__(
        self,
        image: PipelineImageInput,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
        pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
        return_dict: bool = True,
    ):
        r"""
@@ -345,6 +390,14 @@ def __call__(
                numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
                list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
                a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`
            prompt (`str` or `List[str]`, *optional*):
> **Collaborator:** let's make it clear that it is an experimental feature, and that if you pass `prompt`, you will need to load the text encoders explicitly

> **Collaborator (author):** done
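A hedged sketch of the explicit loading the thread refers to (component names and repo ids follow the public FLUX checkpoints; adjust to your setup):

```python
import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline

dtype = torch.bfloat16

# The base pipeline owns the text encoders; pass them into the Redux prior so
# that `prompt` inputs can actually be encoded instead of being ignored.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype)
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Redux-dev",
    text_encoder=pipe.text_encoder,
    tokenizer=pipe.tokenizer,
    text_encoder_2=pipe.text_encoder_2,
    tokenizer_2=pipe.tokenizer_2,
    torch_dtype=dtype,
)
```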

                The prompt or prompts to guide the image generation.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple.

@@ -356,13 +409,29 @@ def __call__(
            returning a tuple, the first element is a list with the generated images.
        """

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            image,
            prompt,
            prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
        )

        # 2. Define call parameters
        if image is not None and isinstance(image, Image.Image):
            batch_size = 1
        elif image is not None and isinstance(image, list):
            batch_size = len(image)
        else:
            batch_size = image.shape[0]
        if prompt is not None and isinstance(prompt, str):
            prompt = batch_size * [prompt]
        if isinstance(prompt_embeds_scale, float):
            prompt_embeds_scale = batch_size * [prompt_embeds_scale]
        if isinstance(pooled_prompt_embeds_scale, float):
            pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale]

        device = self._execution_device

        # 3. Prepare image embeddings
@@ -378,10 +447,10 @@ def __call__(
            pooled_prompt_embeds,
            _,
        ) = self.encode_prompt(
-            prompt=[""] * batch_size,
-            prompt_2=None,
-            prompt_embeds=None,
-            pooled_prompt_embeds=None,
+            prompt=prompt,
> **Collaborator:** let's throw a warning here: if prompt inputs are passed but the pipeline does not have a text_encoder/tokenizer, the text inputs will be ignored. This is a bit different from our regular pipelines: normally, if you pass a prompt and do not have a text_encoder, you get an error from encode_prompt; here we just use zero prompt embeds instead, so let's make that explicit with a warning.

> **Collaborator (author):** agree, added one now
+            prompt_2=prompt_2,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=1,
            max_sequence_length=512,
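Below is a hedged sketch of the warning the author added in response (the exact message in the final diff may differ; `logger` is the module-level diffusers logger):

```python
# Sketch: when the text encoders are not loaded, encode_prompt is skipped and
# zero embeddings are used, so a passed `prompt` would be silently ignored.
# Unlike regular pipelines, which error inside encode_prompt, warn explicitly.
if prompt is not None and (not hasattr(self, "text_encoder") or self.text_encoder is None):
    logger.warning(
        "`prompt` input is ignored when text encoders are not loaded to the pipeline. "
        "Make sure to explicitly load the text encoders to enable prompt input."
    )
```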
@@ -393,9 +462,18 @@ def __call__(
            # pooled_prompt_embeds is 768, clip text encoder hidden size
            pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype)

-        # Concatenate image and text embeddings
+        # scale & concatenate image and text embeddings
        prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)

        prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None]
        pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[
            :, None
        ]

        # weighted sum
        prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True)
        pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0, keepdim=True)

        # Offload all models
        self.maybe_free_model_hooks()

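Finally, a hedged end-to-end sketch of what the new per-image scales enable (repo ids are the public FLUX checkpoints; the image paths are placeholders): since the scaled embeddings are summed over the batch dimension, several reference images passed with `prompt_embeds_scale` weights blend into a single conditioning.

```python
import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image

dtype = torch.bfloat16
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=dtype
).to("cuda")
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", text_encoder=None, text_encoder_2=None, torch_dtype=dtype
).to("cuda")

images = [load_image("style.png"), load_image("subject.png")]  # placeholder paths

# Weight the two references 0.3 / 0.7; the pipeline scales each image's
# embeddings and sums them into one prompt embedding.
prior_out = pipe_prior_redux(image=images, prompt_embeds_scale=[0.3, 0.7])
result = pipe(guidance_scale=2.5, num_inference_steps=50, **prior_out).images[0]
result.save("blend.png")
```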