55from  typing  import  Any , Callable , Dict , List , Optional , Union 
66
77import  numpy  as  np 
8+ import  PIL .Image 
89import  torch 
910from  transformers  import  (
1011    CLIPImageProcessor ,
4445
4546EXAMPLE_DOC_STRING  =  """ 
4647    Examples: 
48+         # Inpainting with text only 
4749        ```py 
4850        >>> import torch 
49-         >>> from diffusers import FluxKontextPipeline  
51+         >>> from diffusers import FluxKontextInpaintPipeline  
5052        >>> from diffusers.utils import load_image 
5153
52-         >>> pipe = FluxKontextPipeline.from_pretrained( 
53-         ...     "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16 
54-         ... ) 
54+         >>> prompt = "Change the yellow dinosaur to green one" 
55+         >>> img_url = "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_input.jpeg?raw=true" 
56+         >>> mask_url = "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_mask.png?raw=true" 
57+ 
58+         >>> source = load_image(img_url) 
59+         >>> mask = load_image(mask_url) 
60+ 
61+         >>> pipe = FluxKontextInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16) 
62+         >>> pipe.to("cuda") 
63+ 
64+         >>> image = pipe(prompt=prompt, image=source, mask_image=mask, strength=1.0).images[0] 
65+         >>> image.save("kontext_inpainting_normal.png") 
66+         ``` 
67+ 
68+         # Inpainting with image conditioning 
69+         ```py 
70+         >>> import torch 
71+         >>> from diffusers import FluxKontextInpaintPipeline 
72+         >>> from diffusers.utils import load_image 
73+ 
74+         >>> pipe = FluxKontextInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16) 
5575        >>> pipe.to("cuda") 
5676
57-         >>> image = load_image( 
58-         ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png" 
59-         ... ).convert("RGB") 
60-         >>> prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors" 
77+         >>> prompt = "Replace this ball" 
78+         >>> img_url = "https://images.pexels.com/photos/39362/the-ball-stadion-football-the-pitch-39362.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500" 
79+         >>> mask_url = "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true" 
80+         >>> image_reference_url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s" 
81+ 
82+         >>> source = load_image(img_url) 
83+         >>> mask = load_image(mask_url) 
84+         >>> image_reference = load_image(image_reference_url) 
85+ 
86+         >>> mask = pipe.mask_processor.blur(mask, blur_factor=12) 
6187        >>> image = pipe( 
62-         ...     image=image, 
6388        ...     prompt=prompt, 
64-         ...     guidance_scale=2.5, 
65-         ...     generator=torch.Generator().manual_seed(42), 
89+         ...     image=source, 
90+         ...     mask_image=mask, 
91+         ...     image_reference=image_reference, 
92+         ...     strength=1.0 
6693        ... ).images[0] 
67-         >>> image.save("output .png") 
94+         >>> image.save("kontext_inpainting_ref .png") 
6895        ``` 
6996""" 
7097
@@ -250,7 +277,7 @@ def __init__(
250277            do_normalize = False ,
251278            do_binarize = True ,
252279            do_convert_grayscale = True ,
253-         )  
280+         )
254281
255282        self .tokenizer_max_length  =  (
256283            self .tokenizer .model_max_length  if  hasattr (self , "tokenizer" ) and  self .tokenizer  is  not None  else  77 
@@ -780,6 +807,7 @@ def prepare_latents(
780807
781808        return  latents , image_latents , image_reference_latents , latent_ids , image_ids , image_reference_ids , noise 
782809
810+     # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents 
783811    def  prepare_mask_latents (
784812        self ,
785813        mask ,
@@ -880,7 +908,6 @@ def __call__(
880908        image : Optional [PipelineImageInput ] =  None ,
881909        image_reference : Optional [PipelineImageInput ] =  None ,
882910        mask_image : PipelineImageInput  =  None ,
883-         masked_image_latents : PipelineImageInput  =  None ,
884911        prompt : Union [str , List [str ]] =  None ,
885912        prompt_2 : Optional [Union [str , List [str ]]] =  None ,
886913        negative_prompt : Union [str , List [str ]] =  None ,
@@ -918,13 +945,13 @@ def __call__(
918945
919946        Args: 
920947            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): 
921-                 `Image`, numpy array or tensor representing an image batch to be used as  the starting point. For both  
922-                 numpy array  and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list  
923-                 or tensors,  the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a  
924-                 list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image  
925-                 latents as `image`, but if passing latents directly it is not encoded again. 
948+                 `Image`, numpy array or tensor representing an image batch to be be inpainted (which parts of  the image to be masked out  
949+                 with `mask_image`  and repainted according to `prompt` and `image_reference`). For both numpy array and pytorch tensor,  
950+                 the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the expected shape should be  
951+                 `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a  list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` 
952+                 It can also accept image  latents as `image`, but if passing latents directly it is not encoded again. 
926953            image_reference (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): 
927-                 `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both 
954+                 `Image`, numpy array or tensor representing an image batch to be used as the starting point for the masked area . For both 
928955                numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list 
929956                or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)` If it is a numpy array or a 
930957                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image 
@@ -936,9 +963,6 @@ def __call__(
936963                color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, 
937964                H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, 
938965                1)`, or `(H, W)`. 
939-             mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`): 
940-                 `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask 
941-                 latents tensor will ge generated by `mask_image`.            
942966            prompt (`str` or `List[str]`, *optional*): 
943967                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 
944968                instead. 
@@ -1121,8 +1145,10 @@ def __call__(
11211145                resize_mode  =  "default" 
11221146
11231147            image  =  self .image_processor .preprocess (image , image_height , image_width , crops_coords = crops_coords , resize_mode = resize_mode )
1148+         else :
1149+             raise  ValueError ("image must be provided correctly for inpainting" )
11241150
1125-         init_image  =  image .to (dtype = torch .float32 )   
1151+         init_image  =  image .to (dtype = torch .float32 )
11261152
11271153        #2.1 Preprocess image_reference 
11281154        if  image_reference  is  not None  and  not  (isinstance (image_reference , torch .Tensor ) and  image_reference .size (1 ) ==  self .latent_channels ):
@@ -1138,6 +1164,8 @@ def __call__(
11381164            image_reference_height  =  image_reference_height  //  multiple_of  *  multiple_of 
11391165            image_reference  =  self .image_processor .resize (image_reference , image_reference_height , image_reference_width )
11401166            image_reference  =  self .image_processor .preprocess (image_reference , image_reference_height , image_reference_width , crops_coords = crops_coords , resize_mode = resize_mode )
1167+         else :
1168+             image_reference  =  None 
11411169
11421170        # 3. Define call parameters 
11431171        if  prompt  is  not None  and  isinstance (prompt , str ):
@@ -1174,7 +1202,7 @@ def __call__(
11741202            (
11751203                negative_prompt_embeds ,
11761204                negative_pooled_prompt_embeds ,
1177-                 negative_text_ids ,      
1205+                 negative_text_ids ,
11781206            ) =  self .encode_prompt (
11791207                prompt = negative_prompt ,
11801208                prompt_2 = negative_prompt_2 ,
@@ -1239,12 +1267,9 @@ def __call__(
12391267            mask_image , height = height , width = width , resize_mode = resize_mode , crops_coords = crops_coords 
12401268        )
12411269
1242-         if  masked_image_latents  is  None :
1243-             masked_image  =  init_image  *  (mask_condition  <  0.5 )
1244-         else :
1245-             masked_image  =  masked_image_latents 
1270+         masked_image  =  init_image  *  (mask_condition  <  0.5 )
12461271
1247-         mask , masked_image_latents  =  self .prepare_mask_latents (
1272+         mask , _  =  self .prepare_mask_latents (
12481273            mask_condition ,
12491274            masked_image ,
12501275            batch_size ,
@@ -1355,7 +1380,7 @@ def __call__(
13551380                    init_latents_proper  =  self .scheduler .scale_noise (
13561381                        init_latents_proper , torch .tensor ([noise_timestep ]), noise 
13571382                    )
1358-                  
1383+ 
13591384                latents  =  (1  -  init_mask ) *  init_latents_proper  +  init_mask  *  latents 
13601385
13611386                if  latents .dtype  !=  latents_dtype :
0 commit comments