-
Couldn't load subscription status.
- Fork 6.5k
[Flux Redux] add prompt & multiple image input #10056
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
ac266fb
bf2e149
27acef8
6fbf290
e6f26b9
7198ec3
382e556
ef9ec65
7d13a41
b8dfdf7
8bc5f7a
012a0ec
5af6811
6275586
bf68f2e
34715b1
d2b4881
a9e893e
df49440
7c93dd0
a350d0c
971b376
f42fe8c
1ab7060
0520ca5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -142,6 +142,45 @@ def __init__( | |
| self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 | ||
| ) | ||
|
|
||
| def check_inputs( | ||
| self, | ||
| image, | ||
| prompt, | ||
| prompt_2, | ||
| prompt_embeds=None, | ||
| pooled_prompt_embeds=None, | ||
| prompt_embeds_scale=1.0, | ||
| pooled_prompt_embeds_scale=1.0, | ||
| ): | ||
| if prompt is not None and prompt_embeds is not None: | ||
| raise ValueError( | ||
| f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" | ||
| " only forward one of the two." | ||
| ) | ||
| elif prompt_2 is not None and prompt_embeds is not None: | ||
| raise ValueError( | ||
| f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" | ||
| " only forward one of the two." | ||
| ) | ||
| elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): | ||
| raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") | ||
| elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): | ||
| raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") | ||
| if prompt is not None and (isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image)): | ||
| raise ValueError( | ||
| f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {len(image)} images" | ||
| ) | ||
| if prompt_embeds is not None and pooled_prompt_embeds is None: | ||
| raise ValueError( | ||
| "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." | ||
| ) | ||
| if isinstance(prompt_embeds_scale, list) and ( | ||
| isinstance(image, list) and len(prompt_embeds_scale) != len(image) | ||
| ): | ||
| raise ValueError( | ||
| f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {len(image)} images" | ||
| ) | ||
|
|
||
| def encode_image(self, image, device, num_images_per_prompt): | ||
| dtype = next(self.image_encoder.parameters()).dtype | ||
| image = self.feature_extractor.preprocess( | ||
|
|
@@ -334,6 +373,12 @@ def encode_prompt( | |
| def __call__( | ||
| self, | ||
| image: PipelineImageInput, | ||
| prompt: Union[str, List[str]] = None, | ||
| prompt_2: Optional[Union[str, List[str]]] = None, | ||
| prompt_embeds: Optional[torch.FloatTensor] = None, | ||
| pooled_prompt_embeds: Optional[torch.FloatTensor] = None, | ||
| prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, | ||
| pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, | ||
| return_dict: bool = True, | ||
| ): | ||
| r""" | ||
|
|
@@ -345,6 +390,14 @@ def __call__( | |
| numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list | ||
| or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a | ||
| list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` | ||
| prompt (`str` or `List[str]`, *optional*): | ||
| The prompt or prompts to guide the image generation. | ||
| prompt_2 (`str` or `List[str]`, *optional*): | ||
| The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. | ||
| prompt_embeds (`torch.FloatTensor`, *optional*): | ||
| Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. | ||
| pooled_prompt_embeds (`torch.FloatTensor`, *optional*): | ||
| Pre-generated pooled text embeddings. | ||
| return_dict (`bool`, *optional*, defaults to `True`): | ||
| Whether or not to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple. | ||
|
|
||
|
|
@@ -356,13 +409,31 @@ def __call__( | |
| returning a tuple, the first element is a list with the generated images. | ||
| """ | ||
|
|
||
| # 1. Check inputs. Raise error if not correct | ||
| self.check_inputs( | ||
| image, | ||
| prompt, | ||
| prompt_2, | ||
| prompt_embeds=prompt_embeds, | ||
| pooled_prompt_embeds=pooled_prompt_embeds, | ||
| prompt_embeds_scale=prompt_embeds_scale, | ||
| pooled_prompt_embeds_scale=pooled_prompt_embeds_scale, | ||
| ) | ||
linoytsaban marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # 2. Define call parameters | ||
| if image is not None and isinstance(image, Image.Image): | ||
| batch_size = 1 | ||
| elif image is not None and isinstance(image, list): | ||
| batch_size = len(image) | ||
| else: | ||
| batch_size = image.shape[0] | ||
| if prompt is not None and isinstance(prompt, str): | ||
| prompt = batch_size * [prompt] | ||
| if isinstance(prompt_embeds_scale, float): | ||
| prompt_embeds_scale = batch_size * [prompt_embeds_scale] | ||
| if isinstance(pooled_prompt_embeds_scale, float): | ||
| pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale] | ||
|
|
||
| device = self._execution_device | ||
|
|
||
| # 3. Prepare image embeddings | ||
|
|
@@ -378,10 +449,10 @@ def __call__( | |
| pooled_prompt_embeds, | ||
| _, | ||
| ) = self.encode_prompt( | ||
| prompt=[""] * batch_size, | ||
| prompt_2=None, | ||
| prompt_embeds=None, | ||
| pooled_prompt_embeds=None, | ||
| prompt=prompt, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's throw out a warning here: it is a bit of different from our regular pipelines, normally, if you pass a prompt and do not have a text_encoder, you will get an error says like from encode_prompt; here we will just use zero prompt embeds instead, so let's be make an explicit warning about that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. agree, added one now |
||
| prompt_2=prompt_2, | ||
| prompt_embeds=prompt_embeds, | ||
| pooled_prompt_embeds=pooled_prompt_embeds, | ||
| device=device, | ||
| num_images_per_prompt=1, | ||
| max_sequence_length=512, | ||
|
|
@@ -393,9 +464,18 @@ def __call__( | |
| # pooled_prompt_embeds is 768, clip text encoder hidden size | ||
| pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype) | ||
|
|
||
| # Concatenate image and text embeddings | ||
| # scale & concatenate image and text embeddings | ||
| prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) | ||
|
|
||
| prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None] | ||
| pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[ | ||
| :, None | ||
| ] | ||
|
|
||
| # weighted sum | ||
| prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True) | ||
| pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0, keepdim=True) | ||
|
|
||
| # Offload all models | ||
| self.maybe_free_model_hooks() | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's make it clear that it is an experimental feature, and if you pass
prompt, you will need to load text_encoders explicitlyThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done