5252 >>> from diffusers.utils import load_image
5353
5454 >>> prompt = "Change the yellow dinosaur to green one"
55- >>> img_url = "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_input.jpeg?raw=true"
56- >>> mask_url = "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_mask.png?raw=true"
55+ >>> img_url = (
56+ ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_input.jpeg?raw=true"
57+ ... )
58+ >>> mask_url = (
59+ ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_mask.png?raw=true"
60+ ... )
5761
5862 >>> source = load_image(img_url)
5963 >>> mask = load_image(mask_url)
6064
61- >>> pipe = FluxKontextInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16)
65+ >>> pipe = FluxKontextInpaintPipeline.from_pretrained(
66+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
67+ ... )
6268 >>> pipe.to("cuda")
6369
6470 >>> image = pipe(prompt=prompt, image=source, mask_image=mask, strength=1.0).images[0]
7177 >>> from diffusers import FluxKontextInpaintPipeline
7278 >>> from diffusers.utils import load_image
7379
74- >>> pipe = FluxKontextInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16)
80+ >>> pipe = FluxKontextInpaintPipeline.from_pretrained(
81+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
82+ ... )
7583 >>> pipe.to("cuda")
7684
7785 >>> prompt = "Replace this ball"
7886 >>> img_url = "https://images.pexels.com/photos/39362/the-ball-stadion-football-the-pitch-39362.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
79- >>> mask_url = "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true"
80- >>> image_reference_url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s"
87+ >>> mask_url = (
88+ ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true"
89+ ... )
90+ >>> image_reference_url = (
91+ ... "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s"
92+ ... )
8193
8294 >>> source = load_image(img_url)
8395 >>> mask = load_image(mask_url)
8496 >>> image_reference = load_image(image_reference_url)
8597
8698 >>> mask = pipe.mask_processor.blur(mask, blur_factor=12)
8799 >>> image = pipe(
88- ... prompt=prompt,
89- ... image=source,
90- ... mask_image=mask,
91- ... image_reference=image_reference,
92- ... strength=1.0
100+ ... prompt=prompt, image=source, mask_image=mask, image_reference=image_reference, strength=1.0
93101 ... ).images[0]
94102 >>> image.save("kontext_inpainting_ref.png")
95103 ```
@@ -719,7 +727,7 @@ def prepare_latents(
719727 device : torch .device ,
720728 generator : Optional [Union [torch .Generator , List [torch .Generator ]]] = None ,
721729 latents : Optional [torch .Tensor ] = None ,
722- image_reference : Optional [torch .Tensor ]= None ,
730+ image_reference : Optional [torch .Tensor ] = None ,
723731 ):
724732 if isinstance (generator , list ) and len (generator ) != batch_size :
725733 raise ValueError (
@@ -793,15 +801,18 @@ def prepare_latents(
793801 if image_reference_latents is not None :
794802 image_reference_latent_height , image_reference_latent_width = image_reference_latents .shape [2 :]
795803 image_reference_latents = self ._pack_latents (
796- image_reference_latents , batch_size , num_channels_latents , image_reference_latent_height , image_reference_latent_width
804+ image_reference_latents ,
805+ batch_size ,
806+ num_channels_latents ,
807+ image_reference_latent_height ,
808+ image_reference_latent_width ,
797809 )
798810 image_reference_ids = self ._prepare_latent_image_ids (
799811 batch_size , image_reference_latent_height // 2 , image_reference_latent_width // 2 , device , dtype
800812 )
801813 # image_reference_ids are the same as latent ids with the first dimension set to 1 instead of 0
802814 image_reference_ids [..., 0 ] = 1
803815
804-
805816 noise = self ._pack_latents (noise , batch_size , num_channels_latents , height , width )
806817 latents = self ._pack_latents (latents , batch_size , num_channels_latents , height , width )
807818
@@ -945,17 +956,18 @@ def __call__(
945956
946957 Args:
947958 image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
948- `Image`, numpy array or tensor representing an image batch to be be inpainted (which parts of the image to be masked out
949- with `mask_image` and repainted according to `prompt` and `image_reference`). For both numpy array and pytorch tensor,
950- the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the expected shape should be
951- `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`
952- It can also accept image latents as `image`, but if passing latents directly it is not encoded again.
953- image_reference (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
954- `Image`, numpy array or tensor representing an image batch to be used as the starting point for the masked area. For both
959+ `Image`, numpy array or tensor representing an image batch to be be inpainted (which parts of the image
960+ to be masked out with `mask_image` and repainted according to `prompt` and `image_reference`). For both
955961 numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
956- or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)` If it is a numpy array or a
962+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
957963 list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
958964 latents as `image`, but if passing latents directly it is not encoded again.
965+ image_reference (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
966+ `Image`, numpy array or tensor representing an image batch to be used as the starting point for the
967+ masked area. For both numpy array and pytorch tensor, the expected value range is between `[0, 1]` If
968+ it's a tensor or a list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)` If it is
969+ a numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can
970+ also accept image latents as `image`, but if passing latents directly it is not encoded again.
959971 mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
960972 `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
961973 are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
@@ -1134,7 +1146,7 @@ def __call__(
11341146 image_height = image_height // multiple_of * multiple_of
11351147 image = self .image_processor .resize (image , image_height , image_width )
11361148
1137- #Choose the resolution of the image to be the same as the image
1149+ # Choose the resolution of the image to be the same as the image
11381150 width = image_width
11391151 height = image_height
11401152
@@ -1146,18 +1158,28 @@ def __call__(
11461158 crops_coords = None
11471159 resize_mode = "default"
11481160
1149- image = self .image_processor .preprocess (image , image_height , image_width , crops_coords = crops_coords , resize_mode = resize_mode )
1161+ image = self .image_processor .preprocess (
1162+ image , image_height , image_width , crops_coords = crops_coords , resize_mode = resize_mode
1163+ )
11501164 else :
11511165 raise ValueError ("image must be provided correctly for inpainting" )
11521166
11531167 init_image = image .to (dtype = torch .float32 )
11541168
1155- #2.1 Preprocess image_reference
1156- if image_reference is not None and not (isinstance (image_reference , torch .Tensor ) and image_reference .size (1 ) == self .latent_channels ):
1157- if isinstance (image_reference , list ) and isinstance (image_reference [0 ], torch .Tensor ) and image_reference [0 ].ndim == 4 :
1169+ # 2.1 Preprocess image_reference
1170+ if image_reference is not None and not (
1171+ isinstance (image_reference , torch .Tensor ) and image_reference .size (1 ) == self .latent_channels
1172+ ):
1173+ if (
1174+ isinstance (image_reference , list )
1175+ and isinstance (image_reference [0 ], torch .Tensor )
1176+ and image_reference [0 ].ndim == 4
1177+ ):
11581178 image_reference = torch .cat (image_reference , dim = 0 )
11591179 img_reference = image_reference [0 ] if isinstance (image_reference , list ) else image_reference
1160- image_reference_height , image_reference_width = self .image_processor .get_default_height_width (img_reference )
1180+ image_reference_height , image_reference_width = self .image_processor .get_default_height_width (
1181+ img_reference
1182+ )
11611183 aspect_ratio = image_reference_width / image_reference_height
11621184 if _auto_resize :
11631185 # Kontext is trained on specific resolutions, using one of them is recommended
@@ -1166,8 +1188,16 @@ def __call__(
11661188 )
11671189 image_reference_width = image_reference_width // multiple_of * multiple_of
11681190 image_reference_height = image_reference_height // multiple_of * multiple_of
1169- image_reference = self .image_processor .resize (image_reference , image_reference_height , image_reference_width )
1170- image_reference = self .image_processor .preprocess (image_reference , image_reference_height , image_reference_width , crops_coords = crops_coords , resize_mode = resize_mode )
1191+ image_reference = self .image_processor .resize (
1192+ image_reference , image_reference_height , image_reference_width
1193+ )
1194+ image_reference = self .image_processor .preprocess (
1195+ image_reference ,
1196+ image_reference_height ,
1197+ image_reference_width ,
1198+ crops_coords = crops_coords ,
1199+ resize_mode = resize_mode ,
1200+ )
11711201 else :
11721202 image_reference = None
11731203
@@ -1248,18 +1278,20 @@ def __call__(
12481278
12491279 # 5. Prepare latent variables
12501280 num_channels_latents = self .transformer .config .in_channels // 4
1251- latents , image_latents , image_reference_latents , latent_ids , image_ids , image_reference_ids , noise = self .prepare_latents (
1252- init_image ,
1253- latent_timestep ,
1254- batch_size * num_images_per_prompt ,
1255- num_channels_latents ,
1256- height ,
1257- width ,
1258- prompt_embeds .dtype ,
1259- device ,
1260- generator ,
1261- latents ,
1262- image_reference ,
1281+ latents , image_latents , image_reference_latents , latent_ids , image_ids , image_reference_ids , noise = (
1282+ self .prepare_latents (
1283+ init_image ,
1284+ latent_timestep ,
1285+ batch_size * num_images_per_prompt ,
1286+ num_channels_latents ,
1287+ height ,
1288+ width ,
1289+ prompt_embeds .dtype ,
1290+ device ,
1291+ generator ,
1292+ latents ,
1293+ image_reference ,
1294+ )
12631295 )
12641296
12651297 if image_reference_ids is not None :
0 commit comments