@@ -225,10 +225,9 @@ def __init__(
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be
         # divisible by the patch size, so the vae scale factor is multiplied by the patch size to account for this.
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
-        latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor * 2,
-            vae_latent_channels=latent_channels,
+            vae_latent_channels=self.vae.config.latent_channels,
             do_normalize=False,
             do_binarize=True,
             do_convert_grayscale=True,
@@ -493,10 +492,40 @@ def encode_prompt(

         return prompt_embeds, pooled_prompt_embeds, text_ids

+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        if isinstance(generator, list):
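+            # one generator per batch element: encode each image slice separately so every
+            # sample draws from its own RNG stream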
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
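+        # shift and scale the raw VAE latents into the normalized latent space the transformer
+        # was trained on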
+        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+        return image_latents
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
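+        # skip the first part of the schedule so denoising starts from a partially noised image;
+        # e.g. strength=0.6 with 50 steps keeps only the last 30 timesteps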
+        t_start = int(max(num_inference_steps - init_timestep, 0))
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+        return timesteps, num_inference_steps - t_start
+
     def check_inputs(
         self,
         prompt,
         prompt_2,
+        strength,
         height,
         width,
         prompt_embeds=None,
@@ -507,6 +536,9 @@ def check_inputs(
         mask_image=None,
         masked_image_latents=None,
     ):
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
         if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
             logger.warning(
                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
@@ -627,6 +659,8 @@ def disable_vae_tiling(self):
-    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
     def prepare_latents(
         self,
+        image,
+        timestep,
         batch_size,
         num_channels_latents,
         height,
@@ -643,22 +677,37 @@ def prepare_latents(

         shape = (batch_size, num_channels_latents, height, width)

-        if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-            return latents.to(device=device, dtype=dtype), latent_image_ids
+        image = image.to(device=device, dtype=dtype)
+        image_latents = self._encode_vae_image(image=image, generator=generator)

-        if isinstance(generator, list) and len(generator) != batch_size:
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+            # expand image_latents to match the requested batch size
+            additional_image_per_prompt = batch_size // image_latents.shape[0]
+            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
             raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
             )
+        else:
+            image_latents = torch.cat([image_latents], dim=0)

-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
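+        # scale_noise interpolates between the clean image latents and fresh noise at the given
+        # timestep, so a lower strength preserves more of the original image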
+        if latents is None:
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+        else:
+            noise = latents.to(device)
+            latents = noise

-        return latents, latent_image_ids
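+        # pack latents, noise, and image latents into Flux's flattened 2x2-patch sequence layout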
+        noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
+        image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
+        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+        return latents, noise, image_latents, latent_image_ids

     @property
     def guidance_scale(self):
@@ -687,6 +736,7 @@ def __call__(
         masked_image_latents: Optional[torch.FloatTensor] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        strength: float = 1.0,
         num_inference_steps: int = 50,
         sigmas: Optional[List[float]] = None,
         guidance_scale: float = 30.0,
@@ -731,6 +781,12 @@ def __call__(
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            strength (`float`, *optional*, defaults to 1.0):
+                Indicates the extent to transform the reference `image`. Must be between 0 and 1. `image` is used as
+                a starting point and more noise is added the higher the `strength`. The number of denoising steps
+                depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and the
+                denoising process runs for the full number of iterations specified in `num_inference_steps`; a value
+                of 1 therefore essentially ignores `image`.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
@@ -794,6 +850,7 @@ def __call__(
         self.check_inputs(
             prompt,
             prompt_2,
+            strength,
             height,
             width,
             prompt_embeds=prompt_embeds,
@@ -809,6 +866,10 @@ def __call__(
         self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False

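+        # preprocess the init image once up front; the masked-image branch below reuses it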
+        original_image = image
+        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        init_image = init_image.to(dtype=torch.float32)
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -821,7 +882,9 @@ def __call__(

         # 3. Prepare prompt embeddings
         lora_scale = (
-            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+            self.joint_attention_kwargs.get("scale", None)
+            if self.joint_attention_kwargs is not None
+            else None
         )
         (
             prompt_embeds,
@@ -838,9 +901,43 @@ def __call__(
             lora_scale=lora_scale,
         )

+        # 6. Prepare timesteps
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+        mu = calculate_shift(
+            image_seq_len,
+            self.scheduler.config.get("base_image_seq_len", 256),
+            self.scheduler.config.get("max_image_seq_len", 4096),
+            self.scheduler.config.get("base_shift", 0.5),
+            self.scheduler.config.get("max_shift", 1.16),
+        )
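+        # Flux uses dynamic shifting: mu stretches the sigma schedule toward higher noise levels
+        # for longer image token sequences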
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            sigmas=sigmas,
+            mu=mu,
+        )
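+        # clip the schedule to the final `strength` fraction of steps (img2img-style start)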
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+        if num_inference_steps < 1:
+            raise ValueError(
+                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+                f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+            )
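+        # the first retained timestep sets how much noise prepare_latents adds to the image latents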
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
         # 4. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
-        latents, latent_image_ids = self.prepare_latents(
+        latents, noise, image_latents, latent_image_ids = self.prepare_latents(
+            init_image,
+            latent_timestep,
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
@@ -855,13 +952,13 @@ def __call__(
         if masked_image_latents is not None:
             masked_image_latents = masked_image_latents.to(latents.device)
         else:
-            image = self.image_processor.preprocess(image, height=height, width=width)
             mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)

-            masked_image = image * (1 - mask_image)
+            masked_image = init_image * (1 - mask_image)
             masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)

-            height, width = image.shape[-2:]
+            height, width = init_image.shape[-2:]
             mask, masked_image_latents = self.prepare_mask_latents(
                 mask_image,
                 masked_image,
@@ -876,23 +973,6 @@ def __call__(
             )
             masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)

-        # 6. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
-        image_seq_len = latents.shape[1]
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.get("base_image_seq_len", 256),
-            self.scheduler.config.get("max_image_seq_len", 4096),
-            self.scheduler.config.get("base_shift", 0.5),
-            self.scheduler.config.get("max_shift", 1.16),
-        )
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            sigmas=sigmas,
-            mu=mu,
-        )
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
