@@ -142,6 +142,45 @@ def __init__(
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
 
+    def check_inputs(
+        self,
+        image,
+        prompt,
+        prompt_2,
+        prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        prompt_embeds_scale=1.0,
+        pooled_prompt_embeds_scale=1.0,
+    ):
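+        # Reject ambiguous prompt/embedding combinations and mismatched list lengths up front.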
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_2 is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is not None and not isinstance(prompt, (str, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and not isinstance(prompt_2, (str, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+        if prompt is not None and isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image):
+            raise ValueError(
+                f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {len(image)} images"
+            )
+        if prompt_embeds is not None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
+        if isinstance(prompt_embeds_scale, list) and isinstance(image, list) and len(prompt_embeds_scale) != len(image):
+            raise ValueError(
+                f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {len(image)} images"
+            )
+
     def encode_image(self, image, device, num_images_per_prompt):
         dtype = next(self.image_encoder.parameters()).dtype
         image = self.feature_extractor.preprocess(
@@ -334,6 +373,12 @@ def encode_prompt(
     def __call__(
         self,
         image: PipelineImageInput,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
+        pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
         return_dict: bool = True,
     ):
         r"""
@@ -345,6 +390,16 @@ def __call__(
                 numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
                 list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array
                 or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. **Experimental feature**: to use this feature,
+                make sure to explicitly load the text encoders into the pipeline. Prompts will be ignored if the
+                text encoders are not loaded.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.*, prompt weighting.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings.
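+            prompt_embeds_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+                Per-image weight applied to the text embeddings before they are summed across the batch; a scalar
+                is broadcast to all images.
+            pooled_prompt_embeds_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+                Per-image weight applied to the pooled text embeddings before they are summed across the batch.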
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple.
 
@@ -356,13 +411,31 @@ def __call__(
             returning a tuple, the first element is a list with the generated images.
         """
 
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            image,
+            prompt,
+            prompt_2,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            prompt_embeds_scale=prompt_embeds_scale,
+            pooled_prompt_embeds_scale=pooled_prompt_embeds_scale,
+        )
+
         # 2. Define call parameters
         if image is not None and isinstance(image, Image.Image):
             batch_size = 1
         elif image is not None and isinstance(image, list):
             batch_size = len(image)
         else:
             batch_size = image.shape[0]
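+        # Expand a single prompt or scalar scale to one entry per image so that each image
+        # in the batch can carry its own prompt and weight.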
+        if prompt is not None and isinstance(prompt, str):
+            prompt = batch_size * [prompt]
+        if isinstance(prompt_embeds_scale, float):
+            prompt_embeds_scale = batch_size * [prompt_embeds_scale]
+        if isinstance(pooled_prompt_embeds_scale, float):
+            pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale]
+
         device = self._execution_device
 
         # 3. Prepare image embeddings
@@ -378,24 +451,38 @@ def __call__(
             pooled_prompt_embeds,
             _,
         ) = self.encode_prompt(
-            prompt=[""] * batch_size,
-            prompt_2=None,
-            prompt_embeds=None,
-            pooled_prompt_embeds=None,
+            prompt=prompt,
+            prompt_2=prompt_2,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
             device=device,
             num_images_per_prompt=1,
             max_sequence_length=512,
             lora_scale=None,
         )
         else:
+            if prompt is not None:
+                logger.warning(
+                    "prompt input is ignored when text encoders are not loaded into the pipeline. "
+                    "Make sure to explicitly load the text encoders to enable prompt input."
+                )
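+                # A minimal loading sketch (component names taken from this class; the
+                # checkpoint id is an assumption):
+                # FluxPriorReduxPipeline.from_pretrained(
+                #     "black-forest-labs/FLUX.1-Redux-dev",
+                #     text_encoder=..., tokenizer=..., text_encoder_2=..., tokenizer_2=...,
+                # )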
             # max_sequence_length is 512, t5 encoder hidden size is 4096
             prompt_embeds = torch.zeros((batch_size, 512, 4096), device=device, dtype=image_embeds.dtype)
             # pooled_prompt_embeds is 768, clip text encoder hidden size
             pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype)
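+            # With all-zero text embeddings, the prior is effectively conditioned on the image alone.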
 
-        # Concatenate image and text embeddings
+        # scale & concatenate image and text embeddings
         prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)
 
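+        # The per-image scales have shape (batch,); reshaping to (batch, 1, 1) and (batch, 1)
+        # lets them broadcast over the sequence and hidden dimensions of the embeddings.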
+        prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None]
+        pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[
+            :, None
+        ]
+
+        # weighted sum: collapse the batch dimension so the per-image weighted embeddings
+        # blend into a single prompt
+        prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True)
+        pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0, keepdim=True)
+
         # Offload all models
         self.maybe_free_model_hooks()
 