@@ -292,8 +292,10 @@ def encode_prompt(
         negative_prompt: Union[str, List[str]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         do_classifier_free_guidance: bool = True,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
@@ -310,7 +312,7 @@ def encode_prompt(
                 torch device
             num_images_per_prompt (`int`):
                 number of images that should be generated per prompt
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
             lora_scale (`float`, *optional*):
@@ -335,7 +337,7 @@ def encode_prompt(
             batch_size = prompt_embeds.shape[0]
 
         if prompt_embeds is None:
-            prompt_embeds = self._get_t5_prompt_embeds(
+            prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
                 prompt=prompt,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
@@ -365,20 +367,28 @@ def encode_prompt(
                     " the batch size of `prompt`."
                 )
 
-            negative_prompt_embeds = self._get_t5_prompt_embeds(
+            negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
                 prompt=negative_prompt,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
                 device=device,
             )
+
             negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
         if self.text_encoder is not None:
             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder, lora_scale)
 
-        return prompt_embeds, text_ids, negative_prompt_embeds, negative_text_ids
+        return (
+            prompt_embeds,
+            text_ids,
+            prompt_attention_mask,
+            negative_prompt_embeds,
+            negative_text_ids,
+            negative_prompt_attention_mask,
+        )
 
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
     def encode_image(self, image, device, num_images_per_prompt):
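For downstream callers, the widened six-tuple unpacks in the order returned above. A minimal sketch, assuming a constructed pipeline instance `pipe` with its T5 text encoder loaded:

```python
# Hypothetical caller of the updated encode_prompt.
(
    prompt_embeds,
    text_ids,
    prompt_attention_mask,
    negative_prompt_embeds,
    negative_text_ids,
    negative_prompt_attention_mask,
) = pipe.encode_prompt(
    prompt="a photo of a cat",
    negative_prompt="blurry, low quality",
    do_classifier_free_guidance=True,
    max_sequence_length=512,
)
```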
@@ -392,52 +402,44 @@ def encode_image(self, image, device, num_images_per_prompt):
         image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
         return image_embeds
 
-    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
-        self,
-        ip_adapter_image: Optional[PipelineImageInput] = None,
-        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
-        device: Optional[torch.device] = None,
-        num_images_per_prompt: int = 1,
-        do_classifier_free_guidance: bool = True,
-    ) -> torch.Tensor:
-        """Prepares image embeddings for use in the IP-Adapter.
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        device = device or self._execution_device
 
-        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+        image_embeds = []
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
 
-        Args:
-            ip_adapter_image (`PipelineImageInput`, *optional*):
-                The input image to extract features from for IP-Adapter.
-            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
-                Precomputed image embeddings.
-            device: (`torch.device`, *optional*):
-                Torch device.
-            num_images_per_prompt (`int`, defaults to 1):
-                Number of images that should be generated per prompt.
-            do_classifier_free_guidance (`bool`, defaults to True):
-                Whether to use classifier free guidance or not.
-        """
-        device = device or self._execution_device
+            if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+                )
 
-        if ip_adapter_image_embeds is not None:
-            if do_classifier_free_guidance:
-                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
-            else:
-                single_image_embeds = ip_adapter_image_embeds
-        elif ip_adapter_image is not None:
-            single_image_embeds = self.encode_image(ip_adapter_image, device)
-            if do_classifier_free_guidance:
-                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+            for single_ip_adapter_image in ip_adapter_image:
+                single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+                image_embeds.append(single_image_embeds[None, :])
         else:
-            raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
+            if not isinstance(ip_adapter_image_embeds, list):
+                ip_adapter_image_embeds = [ip_adapter_image_embeds]
 
-        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+                )
 
-        if do_classifier_free_guidance:
-            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
-            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+            for single_image_embeds in ip_adapter_image_embeds:
+                image_embeds.append(single_image_embeds)
+
+        ip_adapter_image_embeds = []
+        for single_image_embeds in image_embeds:
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
 
-        return image_embeds.to(device=device)
+        return ip_adapter_image_embeds
 
     def check_inputs(
         self,
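The rewritten `prepare_ip_adapter_image_embeds` above now returns a list with one tensor per IP adapter instead of a single CFG-concatenated tensor. A minimal sketch of the expected call, assuming `pipe` has one IP adapter loaded and `pil_image` is a PIL image:

```python
# Hypothetical usage of the per-adapter return value.
embeds = pipe.prepare_ip_adapter_image_embeds(
    ip_adapter_image=[pil_image],      # one entry per loaded adapter
    ip_adapter_image_embeds=None,
    device=pipe._execution_device,
    num_images_per_prompt=2,
)
assert isinstance(embeds, list)        # one tensor per adapter
assert embeds[0].shape[0] == 2         # repeated along the batch dimension
```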
@@ -448,6 +450,8 @@ def check_inputs(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        prompt_attention_mask=None,
+        negative_prompt_attention_mask=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
@@ -483,6 +487,15 @@ def check_inputs(
                 f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )
+        if prompt_attention_mask is not None and negative_prompt_attention_mask is None:
+            raise ValueError(
+                "Cannot provide `prompt_attention_mask` without also providing `negative_prompt_attention_mask`"
+            )
+
+        if negative_prompt_attention_mask is not None and prompt_attention_mask is None:
+            raise ValueError(
+                "Cannot provide `negative_prompt_attention_mask` without also providing `prompt_attention_mask`"
+            )
 
         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
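The two checks added above enforce that the masks travel as a pair; condensed into a standalone sketch:

```python
# Equivalent pairing rule, outside the pipeline (illustrative only).
def validate_masks(prompt_attention_mask=None, negative_prompt_attention_mask=None):
    if (prompt_attention_mask is None) != (negative_prompt_attention_mask is None):
        raise ValueError(
            "`prompt_attention_mask` and `negative_prompt_attention_mask` must be provided together."
        )
```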
@@ -591,7 +604,7 @@ def prepare_latents(
         height = 2 * (int(height) // (self.vae_scale_factor * 2))
         width = 2 * (int(width) // (self.vae_scale_factor * 2))
         shape = (batch_size, num_channels_latents, height, width)
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(height // 2, width // 2, device, dtype)
 
         if latents is not None:
             return latents.to(device=device, dtype=dtype), latent_image_ids
@@ -617,6 +630,25 @@ def prepare_latents(
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
         return latents, latent_image_ids
 
+    def _prepare_attention_mask(
+        self,
+        batch_size,
+        sequence_length,
+        dtype,
+        attention_mask=None,
+    ):
+        if attention_mask is None:
+            return attention_mask
+
+        # Extend the prompt attention mask to account for image tokens in the final sequence
+        attention_mask = torch.cat(
+            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+            dim=1,
+        )
+        attention_mask = attention_mask.to(dtype)
+
+        return attention_mask
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
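`_prepare_attention_mask` pads the text mask with ones so every image token stays attended to, then casts to the model dtype. A self-contained sketch of the same logic, with sequence lengths assumed for illustration:

```python
import torch

def prepare_attention_mask(batch_size, sequence_length, dtype, attention_mask=None):
    # Mirrors the new pipeline method, usable outside the class.
    if attention_mask is None:
        return attention_mask
    attention_mask = torch.cat(
        [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
        dim=1,
    )
    return attention_mask.to(dtype)

text_mask = torch.ones(1, 512)  # assumed T5 sequence length of 512
extended = prepare_attention_mask(1, 4096, torch.bfloat16, text_mask)  # 4096 image tokens assumed
assert extended.shape == (1, 512 + 4096)
```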
@@ -656,13 +688,15 @@ def __call__(
         strength: float = 0.8,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         negative_ip_adapter_image: Optional[PipelineImageInput] = None,
         negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -703,11 +737,11 @@ def __call__(
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will be generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
@@ -721,7 +755,7 @@ def __call__(
                 Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                 IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                 provided, embeddings are computed from the `ip_adapter_image` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
@@ -765,6 +799,8 @@ def __call__(
             negative_prompt=negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -794,13 +830,17 @@ def __call__(
         (
             prompt_embeds,
             text_ids,
+            prompt_attention_mask,
             negative_prompt_embeds,
             negative_text_ids,
+            negative_prompt_attention_mask,
         ) = self.encode_prompt(
             prompt=prompt,
             negative_prompt=negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
             do_classifier_free_guidance=self.do_classifier_free_guidance,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
@@ -856,20 +896,55 @@ def __call__(
             latents,
         )
 
+        attention_mask = self._prepare_attention_mask(
+            batch_size=latents.shape[0],
+            sequence_length=image_seq_len,
+            dtype=latents.dtype,
+            attention_mask=prompt_attention_mask,
+        )
+        if self.do_classifier_free_guidance and negative_prompt_attention_mask is not None:
+            negative_attention_mask = self._prepare_attention_mask(
+                batch_size=latents.shape[0],
+                sequence_length=image_seq_len,
+                dtype=latents.dtype,
+                attention_mask=negative_prompt_attention_mask,
+            )
+            attention_mask = torch.cat([negative_attention_mask, attention_mask], dim=0)
+
         # 6. Prepare image embeddings
-        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
-            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+        ):
+            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+            negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+        ):
+            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+            ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+        image_embeds = None
+        negative_image_embeds = None
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
                 ip_adapter_image,
                 ip_adapter_image_embeds,
                 device,
                 batch_size * num_images_per_prompt,
-                self.do_classifier_free_guidance,
             )
+        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+                negative_ip_adapter_image,
+                negative_ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+        if self.do_classifier_free_guidance and image_embeds is not None and negative_image_embeds is not None:
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
 
-        if self.joint_attention_kwargs is None:
-            self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
-        else:
-            self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+        if image_embeds is not None:
+            self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
 
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
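Note the pairing fallback above: if only one side of the IP-adapter inputs is provided, a black placeholder image fills the other CFG branch, and the final embeds are concatenated negative-first to line up with `torch.cat([latents] * 2)`. A standalone sketch, with adapter count and embedding shape assumed:

```python
import numpy as np
import torch

num_ip_adapters = 1                                  # assumed
negative_ip_adapter_image = None

# A black image stands in for the missing negative branch.
if negative_ip_adapter_image is None:
    negative_ip_adapter_image = [np.zeros((512, 512, 3), dtype=np.uint8)] * num_ip_adapters

# With CFG active, negative embeds lead, matching the doubled latent batch.
image_embeds = torch.randn(1, 4, 768)                # assumed embed shape
negative_image_embeds = torch.zeros_like(image_embeds)
cfg_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
assert cfg_embeds.shape[0] == 2 * image_embeds.shape[0]
```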
@@ -878,9 +953,6 @@ def __call__(
                     continue
 
                 self._current_timestep = t
-                if ip_adapter_image_embeds is not None:
-                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = ip_adapter_image_embeds
-
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -892,6 +964,7 @@ def __call__(
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
+                    attention_mask=attention_mask,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]
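End to end, the new mask arguments thread from `__call__` through `encode_prompt` into the transformer. A hedged sketch of calling the pipeline with precomputed embeddings and masks (the pipeline class, checkpoint name, and the img2img `image`/`strength` arguments are assumed placeholders):

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

# Placeholder checkpoint; this diff targets a Flux-style img2img pipeline
# with true classifier-free guidance.
pipe = DiffusionPipeline.from_pretrained("org/flux-style-model", torch_dtype=torch.bfloat16).to("cuda")
init_image = load_image("input.png")

# Embeddings and masks come back together, as in the encode_prompt sketch above.
(
    prompt_embeds,
    _text_ids,
    prompt_attention_mask,
    negative_prompt_embeds,
    _negative_text_ids,
    negative_prompt_attention_mask,
) = pipe.encode_prompt(prompt="a watercolor cat", negative_prompt="blurry, low quality")

image = pipe(
    image=init_image,
    strength=0.8,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
).images[0]
```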