6464""" 
6565
6666
67+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift 
6768def  calculate_shift (
6869    image_seq_len ,
6970    base_seq_len : int  =  256 ,
@@ -136,6 +137,7 @@ def retrieve_timesteps(
136137        timesteps  =  scheduler .timesteps 
137138    return  timesteps , num_inference_steps 
138139
140+ 
139141# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 
140142def  retrieve_latents (
141143    encoder_output : torch .Tensor , generator : Optional [torch .Generator ] =  None , sample_mode : str  =  "sample" 
@@ -226,6 +228,7 @@ def __init__(
226228        )
227229        self .default_sample_size  =  128 
228230
231+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds 
229232    def  _get_t5_prompt_embeds (
230233        self ,
231234        prompt : Union [str , List [str ]] =  None ,
@@ -275,6 +278,7 @@ def _get_t5_prompt_embeds(
275278
276279        return  prompt_embeds 
277280
281+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds 
278282    def  _get_clip_prompt_embeds (
279283        self ,
280284        prompt : Union [str , List [str ]],
@@ -318,7 +322,7 @@ def _get_clip_prompt_embeds(
318322        prompt_embeds  =  prompt_embeds .view (batch_size  *  num_images_per_prompt , - 1 )
319323
320324        return  prompt_embeds 
321-      
325+ 
322326    def  prepare_mask_latents (
323327        self ,
324328        mask ,
@@ -364,7 +368,7 @@ def prepare_mask_latents(
364368            masked_image_latents  =  masked_image_latents .repeat (batch_size  //  masked_image_latents .shape [0 ], 1 , 1 , 1 )
365369
366370        # prepare mask for latents 
367-         mask  =  mask [:,0 ,:, :]
371+         mask  =  mask [:,  0 , :,  :]
368372        mask  =  mask .view (batch_size , height , self .vae_scale_factor , width , self .vae_scale_factor )
369373        mask  =  mask .permute (0 , 2 , 4 , 1 , 3 )
370374        mask  =  mask .reshape (batch_size , self .vae_scale_factor  *  self .vae_scale_factor , height , width )
@@ -390,7 +394,7 @@ def prepare_mask_latents(
390394
391395        return  mask , masked_image_latents 
392396
393- 
397+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt 
394398    def  encode_prompt (
395399        self ,
396400        prompt : Union [str , List [str ]],
@@ -470,6 +474,7 @@ def encode_prompt(
470474
471475        return  prompt_embeds , pooled_prompt_embeds , text_ids 
472476
477+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs 
473478    def  check_inputs (
474479        self ,
475480        prompt ,
@@ -521,6 +526,7 @@ def check_inputs(
521526            raise  ValueError (f"`max_sequence_length` cannot be greater than 512 but is { max_sequence_length }  )
522527
523528    @staticmethod  
529+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids 
524530    def  _prepare_latent_image_ids (batch_size , height , width , device , dtype ):
525531        latent_image_ids  =  torch .zeros (height , width , 3 )
526532        latent_image_ids [..., 1 ] =  latent_image_ids [..., 1 ] +  torch .arange (height )[:, None ]
@@ -535,6 +541,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
535541        return  latent_image_ids .to (device = device , dtype = dtype )
536542
537543    @staticmethod  
544+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents 
538545    def  _pack_latents (latents , batch_size , num_channels_latents , height , width ):
539546        latents  =  latents .view (batch_size , num_channels_latents , height  //  2 , 2 , width  //  2 , 2 )
540547        latents  =  latents .permute (0 , 2 , 4 , 1 , 3 , 5 )
@@ -543,6 +550,7 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
543550        return  latents 
544551
545552    @staticmethod  
553+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents 
546554    def  _unpack_latents (latents , height , width , vae_scale_factor ):
547555        batch_size , num_patches , channels  =  latents .shape 
548556
@@ -587,6 +595,7 @@ def disable_vae_tiling(self):
587595        """ 
588596        self .vae .disable_tiling ()
589597
598+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents 
590599    def  prepare_latents (
591600        self ,
592601        batch_size ,
@@ -644,6 +653,9 @@ def __call__(
644653        self ,
645654        prompt : Union [str , List [str ]] =  None ,
646655        prompt_2 : Optional [Union [str , List [str ]]] =  None ,
656+         image : Optional [torch .FloatTensor ] =  None ,
657+         mask_image : Optional [torch .FloatTensor ] =  None ,
658+         masked_image_latents : Optional [torch .FloatTensor ] =  None ,
647659        height : Optional [int ] =  None ,
648660        width : Optional [int ] =  None ,
649661        num_inference_steps : int  =  28 ,
@@ -660,9 +672,6 @@ def __call__(
660672        callback_on_step_end : Optional [Callable [[int , int , Dict ], None ]] =  None ,
661673        callback_on_step_end_tensor_inputs : List [str ] =  ["latents" ],
662674        max_sequence_length : int  =  512 ,
663-         img_cond : Optional [torch .FloatTensor ] =  None ,
664-         image : Optional [torch .FloatTensor ] =  None ,
665-         mask_image : Optional [torch .FloatTensor ] =  None ,
666675    ):
667676        r""" 
668677        Function invoked when calling the pipeline for generation. 
@@ -674,6 +683,22 @@ def __call__(
674683            prompt_2 (`str` or `List[str]`, *optional*): 
675684                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
676685                will be used instead.
686+             image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): 
687+                 `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both 
688+                 numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
689+                 of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
690+                 list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
691+                 latents as `image`, but if passing latents directly it is not encoded again. 
692+             mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): 
693+                 `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask 
694+                 are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a 
695+                 single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one 
696+                 color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, 
697+                 H, W)`, `(1, H, W)`, `(H, W)`. For a numpy array it would be `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
698+                 1)`, or `(H, W)`. 
699+             masked_image_latents (`torch.Tensor`, `List[torch.Tensor]`, *optional*):
700+                 `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
701+                 latents tensor will be generated by `mask_image`.
677702            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 
678703                The height in pixels of the generated image. This is set to 1024 by default for the best results. 
679704            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 
@@ -794,18 +819,17 @@ def __call__(
794819            latents ,
795820        )
796821
797-         if  img_cond  is  not None :
798-             img_cond  =  img_cond .to (latents .device )
822+         if  masked_image_latents  is  not None :
823+             masked_image_latents  =  masked_image_latents .to (latents .device )
799824        else :
800- 
801825            if  image  is  not None  and  mask_image  is  not None :
802826                image  =  self .image_processor .preprocess (image )
803827                mask_image  =  self .mask_processor .preprocess (mask_image )
804828                masked_image  =  image  *  (1  -  mask_image )
805829                masked_image  =  masked_image .to (device = device , dtype = prompt_embeds .dtype )
806830
807831                height , width  =  image .shape [- 2 :]
808-              
832+ 
809833                mask , masked_image_latents  =  self .prepare_mask_latents (
810834                    mask_image ,
811835                    masked_image ,
@@ -818,7 +842,7 @@ def __call__(
818842                    device ,
819843                    generator ,
820844                )
821-                 img_cond  =  torch .cat ((masked_image_latents , mask ), dim = - 1 )
845+                 masked_image_latents  =  torch .cat ((masked_image_latents , mask ), dim = - 1 )
822846
823847        # 5. Prepare timesteps 
824848        sigmas  =  np .linspace (1.0 , 1  /  num_inference_steps , num_inference_steps )
@@ -858,7 +882,7 @@ def __call__(
858882                timestep  =  t .expand (latents .shape [0 ]).to (latents .dtype )
859883
860884                noise_pred  =  self .transformer (
861-                     hidden_states = torch .cat ((latents , img_cond ), dim = 2 )  if   img_cond   is   not   None   else   latents ,
885+                     hidden_states = torch .cat ((latents , masked_image_latents ), dim = 2 ),
862886                    timestep = timestep  /  1000 ,
863887                    guidance = guidance ,
864888                    pooled_projections = pooled_prompt_embeds ,
0 commit comments