diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index cf50e89ca5ae..f53958df2ed0 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -142,6 +142,45 @@ def __init__(
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
 
+    def check_inputs(
+        self,
+        image,
+        prompt,
+        prompt_2,
+        prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        prompt_embeds_scale=1.0,
+        pooled_prompt_embeds_scale=1.0,
+    ):
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_2 is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+        if prompt is not None and (isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image)):
+            raise ValueError(
+                f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {len(image)} images"
+            )
+        if prompt_embeds is not None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
+        if isinstance(prompt_embeds_scale, list) and (
+            isinstance(image, list) and len(prompt_embeds_scale) != len(image)
+        ):
+            raise ValueError(
+                f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {len(image)} images"
+            )
+
     def encode_image(self, image, device, num_images_per_prompt):
         dtype = next(self.image_encoder.parameters()).dtype
         image = self.feature_extractor.preprocess(
@@ -334,6 +373,12 @@ def encode_prompt(
     def __call__(
         self,
         image: PipelineImageInput,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
+        pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
         return_dict: bool = True,
     ):
         r"""
@@ -345,6 +390,16 @@ def __call__(
                 numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
                 or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
                 list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. **experimental feature**: to use this feature,
+                make sure to explicitly load text encoders to the pipeline. Prompts will be ignored if text encoders
+                are not loaded.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple.
 
@@ -356,6 +411,17 @@ def __call__(
             returning a tuple, the first element is a list with the generated images.
         """
 
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            image,
+            prompt,
+            prompt_2,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            prompt_embeds_scale=prompt_embeds_scale,
+            pooled_prompt_embeds_scale=pooled_prompt_embeds_scale,
+        )
+
         # 2. Define call parameters
         if image is not None and isinstance(image, Image.Image):
             batch_size = 1
@@ -363,6 +429,13 @@ def __call__(
             batch_size = len(image)
         else:
             batch_size = image.shape[0]
+        if prompt is not None and isinstance(prompt, str):
+            prompt = batch_size * [prompt]
+        if isinstance(prompt_embeds_scale, float):
+            prompt_embeds_scale = batch_size * [prompt_embeds_scale]
+        if isinstance(pooled_prompt_embeds_scale, float):
+            pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale]
+
         device = self._execution_device
 
         # 3. Prepare image embeddings
@@ -378,24 +451,38 @@ def __call__(
                 pooled_prompt_embeds,
                 _,
             ) = self.encode_prompt(
-                prompt=[""] * batch_size,
-                prompt_2=None,
-                prompt_embeds=None,
-                pooled_prompt_embeds=None,
+                prompt=prompt,
+                prompt_2=prompt_2,
+                prompt_embeds=prompt_embeds,
+                pooled_prompt_embeds=pooled_prompt_embeds,
                 device=device,
                 num_images_per_prompt=1,
                 max_sequence_length=512,
                 lora_scale=None,
             )
         else:
+            if prompt is not None:
+                logger.warning(
+                    "prompt input is ignored when text encoders are not loaded to the pipeline. "
+                    "Make sure to explicitly load the text encoders to enable prompt input. "
+                )
             # max_sequence_length is 512, t5 encoder hidden size is 4096
             prompt_embeds = torch.zeros((batch_size, 512, 4096), device=device, dtype=image_embeds.dtype)
             # pooled_prompt_embeds is 768, clip text encoder hidden size
             pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype)
 
-        # Concatenate image and text embeddings
+        # scale & concatenate image and text embeddings
         prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)
 
+        prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None]
+        pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[
+            :, None
+        ]
+
+        # weighted sum
+        prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True)
+        pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0, keepdim=True)
+
         # Offload all models
         self.maybe_free_model_hooks()
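For reference, a minimal usage sketch of the new inputs (not part of the diff). The model IDs, image paths, and prompts below are illustrative assumptions; only the pipeline classes and the new `prompt` / `*_scale` arguments come from the change above. The text encoders must be loaded explicitly onto the Redux pipeline, otherwise `prompt` is ignored and the pipeline falls back to zero text embeddings with a warning:

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from diffusers import FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image

base_repo = "black-forest-labs/FLUX.1-dev"  # assumed model id
redux_repo = "black-forest-labs/FLUX.1-Redux-dev"  # assumed model id

# Load the text encoders explicitly so the new `prompt` input takes effect;
# without them, check the pipeline log for the "prompt input is ignored" warning.
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
    redux_repo,
    text_encoder=CLIPTextModel.from_pretrained(base_repo, subfolder="text_encoder", torch_dtype=torch.bfloat16),
    tokenizer=CLIPTokenizer.from_pretrained(base_repo, subfolder="tokenizer"),
    text_encoder_2=T5EncoderModel.from_pretrained(base_repo, subfolder="text_encoder_2", torch_dtype=torch.bfloat16),
    tokenizer_2=T5TokenizerFast.from_pretrained(base_repo, subfolder="tokenizer_2"),
    torch_dtype=torch.bfloat16,
).to("cuda")

# Two reference images with per-image weights: after scaling, the embeddings
# are summed over the batch dimension into a single conditioning, so the
# `*_scale` values act as blend weights between the image/prompt pairs.
images = [load_image("style_a.png"), load_image("style_b.png")]  # placeholder paths
prior_output = pipe_prior_redux(
    image=images,
    prompt=["a sleeping cat", "a sleeping cat"],  # one prompt per image
    prompt_embeds_scale=[1.0, 0.3],
    pooled_prompt_embeds_scale=[1.0, 0.3],
)

pipe = FluxPipeline.from_pretrained(base_repo, torch_dtype=torch.bfloat16).to("cuda")
image = pipe(**prior_output, guidance_scale=2.5, num_inference_steps=50).images[0]
```

Because the scaled embeddings are summed over `dim=0` with `keepdim=True`, the call always returns a single conditioning regardless of how many images are passed in; the scale arguments control each pair's contribution to that sum.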