@@ -660,8 +660,8 @@ def prepare_image(

     def prepare_mask_latents(
         self,
-        mask,
-        masked_image,
+        image,
+        mask_image,
         batch_size,
         num_channels_latents,
         num_images_per_prompt,
@@ -673,34 +673,40 @@ def prepare_mask_latents(
     ):
         # VAE applies 8x compression on images but we must also account for packing which requires
         # latent height and width to be divisible by 2.
+        image = self.image_processor.preprocess(image, height=height, width=width)
+        mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
+
+        masked_image = image * (1 - mask_image)
+        masked_image = masked_image.to(device=device, dtype=dtype)
+
         height = 2 * (int(height) // (self.vae_scale_factor * 2))
         width = 2 * (int(width) // (self.vae_scale_factor * 2))
         # resize the mask to latents shape as we concatenate the mask to the latents
         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
         # and half precision
-        mask = torch.nn.functional.interpolate(mask, size=(height, width))
-        mask = mask.to(device=device, dtype=dtype)
+        mask_image = torch.nn.functional.interpolate(mask_image, size=(height, width))
+        mask_image = mask_image.to(device=device, dtype=dtype)

         batch_size = batch_size * num_images_per_prompt

         masked_image = masked_image.to(device=device, dtype=dtype)

-        if masked_image.shape[1] == 16:
+        if masked_image.shape[1] == num_channels_latents:
             masked_image_latents = masked_image
         else:
             masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)

         masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
-        if mask.shape[0] < batch_size:
-            if not batch_size % mask.shape[0] == 0:
+        if mask_image.shape[0] < batch_size:
+            if not batch_size % mask_image.shape[0] == 0:
                 raise ValueError(
                     "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
-                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                    f" a total batch size of {batch_size}, but {mask_image.shape[0]} mask images were passed. Make sure the number"
                     " of masks that you pass is divisible by the total requested batch size."
                 )
-            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+            mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1)
         if masked_image_latents.shape[0] < batch_size:
             if not batch_size % masked_image_latents.shape[0] == 0:
                 raise ValueError(
@@ -719,15 +725,16 @@ def prepare_mask_latents(
             height,
             width,
         )
-        mask = self._pack_latents(
-            mask.repeat(1, num_channels_latents, 1, 1),
+        mask_image = self._pack_latents(
+            mask_image.repeat(1, num_channels_latents, 1, 1),
             batch_size,
             num_channels_latents,
             height,
             width,
         )
+        masked_image_latents = torch.cat((masked_image_latents, mask_image), dim=-1)

-        return mask, masked_image_latents
+        return mask_image, masked_image_latents

     @property
     def guidance_scale(self):
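
For reference, a minimal sketch of what the relocated preprocessing computes; shapes are illustrative and nothing here is tied to a specific checkpoint. The mask is 1.0 where pixels should be inpainted, so multiplying by `1 - mask_image` zeroes out exactly the region the model must fill before the result is VAE-encoded:

```python
import torch

# Illustrative stand-ins for the outputs of image_processor / mask_processor.
image = torch.randn(1, 3, 64, 64)                      # preprocessed image in [-1, 1]
mask_image = (torch.rand(1, 1, 64, 64) > 0.5).float()  # 1.0 marks the inpaint region
masked_image = image * (1 - mask_image)                # zero out the region to fill
# The pipeline then VAE-encodes `masked_image` and separately downsamples
# `mask_image` to latent resolution with torch.nn.functional.interpolate.
```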
@@ -759,7 +766,7 @@ def __call__(
         width: Optional[int] = None,
         strength: float = 0.6,
         num_inference_steps: int = 28,
-        timesteps: List[int] = None,
+        sigmas: Optional[List[float]] = None,
         guidance_scale: float = 7.0,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -820,10 +827,10 @@ def __call__(
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
-                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
-                passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
             guidance_scale (`float`, *optional*, defaults to 7.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
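
A hedged usage sketch of the new argument (only the `sigmas` parameter itself comes from this diff; the pipeline call is hypothetical and model loading is omitted). Custom sigmas replace the default linearly spaced schedule that `__call__` builds when `sigmas is None`:

```python
import numpy as np

num_inference_steps = 28
# Default schedule, as constructed later in __call__ when sigmas is None:
default_sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
# A caller could instead bias steps toward low noise, e.g. with a squared ramp
# (purely illustrative; any descending schedule in (0, 1] is passed through):
custom_sigmas = (default_sigmas ** 2).tolist()
# pipe(prompt=..., image=..., mask_image=..., sigmas=custom_sigmas)  # hypothetical call
```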
@@ -927,18 +934,13 @@ def __call__(
         # 3. Preprocess mask and image
         num_channels_latents = self.vae.config.latent_channels
         if masked_image_latents is not None:
+            # pre-computed masked_image_latents and mask_image
             masked_image_latents = masked_image_latents.to(latents.device)
+            mask = mask_image.to(latents.device)
         else:
-            image = self.image_processor.preprocess(image, height=height, width=width)
-            mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
-
-            masked_image = image * (1 - mask_image)
-            masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)
-
-            height, width = image.shape[-2:]
             mask, masked_image_latents = self.prepare_mask_latents(
+                image,
                 mask_image,
-                masked_image,
                 batch_size,
                 num_channels_latents,
                 num_images_per_prompt,
@@ -948,13 +950,12 @@ def __call__(
                 device,
                 generator,
             )
-            masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)

         init_image = self.image_processor.preprocess(image, height=height, width=width)
         init_image = init_image.to(dtype=torch.float32)

         # 4.Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
         image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
         mu = calculate_shift(
             image_seq_len,
@@ -967,8 +968,7 @@ def __call__(
             self.scheduler,
             num_inference_steps,
             device,
-            timesteps,
-            sigmas,
+            sigmas=sigmas,
             mu=mu,
         )
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
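
For context, the `mu` passed above comes from `calculate_shift`, which in Flux-style pipelines linearly interpolates the scheduler shift from the image sequence length. A standalone sketch under that assumption (the constants are the commonly used defaults, not values read from this diff):

```python
def calculate_shift_sketch(
    image_seq_len: int,
    base_seq_len: int = 256,   # assumed default
    max_seq_len: int = 4096,   # assumed default
    base_shift: float = 0.5,   # assumed default
    max_shift: float = 1.16,   # assumed default
) -> float:
    # Linear interpolation mu = m * image_seq_len + b through the two anchor points.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

# e.g. a 1024x1024 image: (1024 // vae_scale_factor // 2) ** 2 = 4096 tokens at scale 8
print(calculate_shift_sketch(4096))  # 1.16
```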
@@ -1062,11 +1062,12 @@ def __call__(
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                 # for 64 channel transformer only.
-                init_latents_proper = image_latents
                 init_mask = mask
                 if i < len(timesteps) - 1:
                     noise_timestep = timesteps[i + 1]
-                    init_latents_proper = self.scheduler.scale_noise(init_latents_proper, torch.tensor([noise_timestep]), noise)
+                    init_latents_proper = self.scheduler.scale_noise(image_latents, torch.tensor([noise_timestep]), noise)
+                else:
+                    init_latents_proper = image_latents
                 init_latents_proper = self._pack_latents(init_latents_proper, batch_size * num_images_per_prompt, num_channels_latents, height_8, width_8)

                 latents = (1 - init_mask) * init_latents_proper + init_mask * latents
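
The final blend deserves a note: outside the mask the latents are reset to a re-noised copy of the original image latents matched to the *next* timestep, while inside the mask the freshly denoised prediction is kept; on the last step the clean `image_latents` are used directly, which matches the old behavior but makes the branch explicit. A minimal, self-contained sketch of the blend (packed-latent shapes are illustrative):

```python
import torch

init_mask = (torch.rand(1, 1024, 64) > 0.5).float()  # packed mask, 1.0 = inpaint region
init_latents_proper = torch.randn(1, 1024, 64)       # image latents noised to timesteps[i + 1]
latents = torch.randn(1, 1024, 64)                   # current denoised latents
# Keep the original content where the mask is 0, the model's output where it is 1.
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
```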