Commit d91196e
Apply style fixes
Parent: 73e4ebf
4 files changed: +86 additions, -45 deletions


src/diffusers/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -381,8 +381,8 @@
         "FluxFillPipeline",
         "FluxImg2ImgPipeline",
         "FluxInpaintPipeline",
-        "FluxKontextPipeline",
         "FluxKontextInpaintPipeline",
+        "FluxKontextPipeline",
         "FluxPipeline",
         "FluxPriorReduxPipeline",
         "HiDreamImagePipeline",
```

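The only change in `src/diffusers/__init__.py` restores alphabetical ordering of the exported pipeline names, so `FluxKontextInpaintPipeline` precedes `FluxKontextPipeline`, presumably to satisfy the repo's sorted-exports convention. A quick standalone check of that invariant (a hedged sketch over a hypothetical excerpt, not the full export list):

```python
# Hypothetical excerpt of the pipeline export list; verifies alphabetical order.
pipeline_exports = [
    "FluxInpaintPipeline",
    "FluxKontextInpaintPipeline",
    "FluxKontextPipeline",
    "FluxPipeline",
]
assert pipeline_exports == sorted(pipeline_exports), "exports out of alphabetical order"
```
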
src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py

Lines changed: 74 additions & 42 deletions
```diff
@@ -52,13 +52,19 @@
         >>> from diffusers.utils import load_image

         >>> prompt = "Change the yellow dinosaur to green one"
-        >>> img_url = "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_input.jpeg?raw=true"
-        >>> mask_url = "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_mask.png?raw=true"
+        >>> img_url = (
+        ...     "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_input.jpeg?raw=true"
+        ... )
+        >>> mask_url = (
+        ...     "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_mask.png?raw=true"
+        ... )

         >>> source = load_image(img_url)
         >>> mask = load_image(mask_url)

-        >>> pipe = FluxKontextInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16)
+        >>> pipe = FluxKontextInpaintPipeline.from_pretrained(
+        ...     "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+        ... )
         >>> pipe.to("cuda")

         >>> image = pipe(prompt=prompt, image=source, mask_image=mask, strength=1.0).images[0]
```
````diff
@@ -71,25 +77,27 @@
         >>> from diffusers import FluxKontextInpaintPipeline
         >>> from diffusers.utils import load_image

-        >>> pipe = FluxKontextInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16)
+        >>> pipe = FluxKontextInpaintPipeline.from_pretrained(
+        ...     "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+        ... )
         >>> pipe.to("cuda")

         >>> prompt = "Replace this ball"
         >>> img_url = "https://images.pexels.com/photos/39362/the-ball-stadion-football-the-pitch-39362.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
-        >>> mask_url = "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true"
-        >>> image_reference_url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s"
+        >>> mask_url = (
+        ...     "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true"
+        ... )
+        >>> image_reference_url = (
+        ...     "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s"
+        ... )

         >>> source = load_image(img_url)
         >>> mask = load_image(mask_url)
         >>> image_reference = load_image(image_reference_url)

         >>> mask = pipe.mask_processor.blur(mask, blur_factor=12)
         >>> image = pipe(
-        ...     prompt=prompt,
-        ...     image=source,
-        ...     mask_image=mask,
-        ...     image_reference=image_reference,
-        ...     strength=1.0
+        ...     prompt=prompt, image=source, mask_image=mask, image_reference=image_reference, strength=1.0
         ... ).images[0]
         >>> image.save("kontext_inpainting_ref.png")
         ```
````
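Both docstring examples are only reflowed to the line-length limit. The second one also demonstrates `pipe.mask_processor.blur(mask, blur_factor=12)`, which feathers the mask edge so the inpainted region blends into its surroundings. A minimal sketch of the same effect using PIL directly (the file name is hypothetical, and this is akin to, not identical with, the processor's implementation):

```python
from PIL import Image, ImageFilter

# Soften a binary inpainting mask with a Gaussian blur, similar in effect to
# mask_processor.blur(mask, blur_factor=12) in the example above.
mask = Image.open("ball_mask.png").convert("L")  # hypothetical local file
blurred = mask.filter(ImageFilter.GaussianBlur(radius=12))
blurred.save("ball_mask_blurred.png")
```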
```diff
@@ -719,7 +727,7 @@ def prepare_latents(
         device: torch.device,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
-        image_reference: Optional[torch.Tensor]=None,
+        image_reference: Optional[torch.Tensor] = None,
     ):
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
```
```diff
@@ -793,15 +801,18 @@ def prepare_latents(
         if image_reference_latents is not None:
             image_reference_latent_height, image_reference_latent_width = image_reference_latents.shape[2:]
             image_reference_latents = self._pack_latents(
-                image_reference_latents, batch_size, num_channels_latents, image_reference_latent_height, image_reference_latent_width
+                image_reference_latents,
+                batch_size,
+                num_channels_latents,
+                image_reference_latent_height,
+                image_reference_latent_width,
             )
             image_reference_ids = self._prepare_latent_image_ids(
                 batch_size, image_reference_latent_height // 2, image_reference_latent_width // 2, device, dtype
             )
             # image_reference_ids are the same as latent ids with the first dimension set to 1 instead of 0
             image_reference_ids[..., 0] = 1

-
         noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
```
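The reflowed `_pack_latents` call changes only formatting; the operation itself rearranges a `(B, C, H, W)` latent into a sequence of 2x2 patches for the transformer. A minimal sketch of that packing under the shapes Flux pipelines use (an illustration, not the pipeline's own code):

```python
import torch

def pack_latents_sketch(latents: torch.Tensor) -> torch.Tensor:
    # (B, C, H, W) -> (B, (H//2) * (W//2), C * 4): every 2x2 spatial patch
    # becomes one sequence token whose channels stack the four patch values.
    b, c, h, w = latents.shape
    x = latents.view(b, c, h // 2, 2, w // 2, 2)
    x = x.permute(0, 2, 4, 1, 3, 5)  # (B, H//2, W//2, C, 2, 2)
    return x.reshape(b, (h // 2) * (w // 2), c * 4)

print(pack_latents_sketch(torch.randn(1, 16, 64, 64)).shape)  # torch.Size([1, 1024, 64])
```

The `image_reference_ids[..., 0] = 1` line then tags the packed reference tokens so they are distinguishable from the main latent tokens, whose first id coordinate stays 0.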
```diff
@@ -945,17 +956,18 @@ def __call__(

        Args:
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
-                `Image`, numpy array or tensor representing an image batch to be be inpainted (which parts of the image to be masked out
-                with `mask_image` and repainted according to `prompt` and `image_reference`). For both numpy array and pytorch tensor,
-                the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the expected shape should be
-                `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`
-                It can also accept image latents as `image`, but if passing latents directly it is not encoded again.
-            image_reference (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
-                `Image`, numpy array or tensor representing an image batch to be used as the starting point for the masked area. For both
+                `Image`, numpy array or tensor representing an image batch to be be inpainted (which parts of the image
+                to be masked out with `mask_image` and repainted according to `prompt` and `image_reference`). For both
                numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
-                or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)` If it is a numpy array or a
+                or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
                latents as `image`, but if passing latents directly it is not encoded again.
+            image_reference (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, numpy array or tensor representing an image batch to be used as the starting point for the
+                masked area. For both numpy array and pytorch tensor, the expected value range is between `[0, 1]` If
+                it's a tensor or a list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)` If it is
+                a numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can
+                also accept image latents as `image`, but if passing latents directly it is not encoded again.
            mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
                are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
```
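As the reflowed docstring spells out, array and tensor inputs use different layouts: numpy arrays are channels-last `(H, W, C)` while tensors are channels-first `(C, H, W)`, both with values in `[0, 1]`. A quick illustration of preparing equivalent inputs (hypothetical random data):

```python
import numpy as np
import torch

rgb = np.random.rand(512, 512, 3).astype(np.float32)  # numpy: (H, W, C), values in [0, 1]
tensor = torch.from_numpy(rgb).permute(2, 0, 1)        # tensor: (C, H, W), same value range
assert tensor.shape == (3, 512, 512)
```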
```diff
@@ -1134,7 +1146,7 @@ def __call__(
            image_height = image_height // multiple_of * multiple_of
            image = self.image_processor.resize(image, image_height, image_width)

-            #Choose the resolution of the image to be the same as the image
+            # Choose the resolution of the image to be the same as the image
            width = image_width
            height = image_height

```
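The `// multiple_of * multiple_of` idiom in this hunk rounds each dimension down to the nearest multiple, which the VAE's downsampling plus the 2x2 latent packing require. A worked example with an illustrative multiple (the pipeline derives the real value from the VAE scale factor):

```python
multiple_of = 16  # illustrative value
image_height = 1000
print(image_height // multiple_of * multiple_of)  # 992, the largest multiple of 16 <= 1000
```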
```diff
@@ -1146,18 +1158,28 @@ def __call__(
                crops_coords = None
                resize_mode = "default"

-            image = self.image_processor.preprocess(image, image_height, image_width, crops_coords=crops_coords, resize_mode=resize_mode)
+            image = self.image_processor.preprocess(
+                image, image_height, image_width, crops_coords=crops_coords, resize_mode=resize_mode
+            )
        else:
            raise ValueError("image must be provided correctly for inpainting")

        init_image = image.to(dtype=torch.float32)

-        #2.1 Preprocess image_reference
-        if image_reference is not None and not (isinstance(image_reference, torch.Tensor) and image_reference.size(1) == self.latent_channels):
-            if isinstance(image_reference, list) and isinstance(image_reference[0], torch.Tensor) and image_reference[0].ndim == 4:
+        # 2.1 Preprocess image_reference
+        if image_reference is not None and not (
+            isinstance(image_reference, torch.Tensor) and image_reference.size(1) == self.latent_channels
+        ):
+            if (
+                isinstance(image_reference, list)
+                and isinstance(image_reference[0], torch.Tensor)
+                and image_reference[0].ndim == 4
+            ):
                image_reference = torch.cat(image_reference, dim=0)
            img_reference = image_reference[0] if isinstance(image_reference, list) else image_reference
-            image_reference_height, image_reference_width = self.image_processor.get_default_height_width(img_reference)
+            image_reference_height, image_reference_width = self.image_processor.get_default_height_width(
+                img_reference
+            )
            aspect_ratio = image_reference_width / image_reference_height
            if _auto_resize:
                # Kontext is trained on specific resolutions, using one of them is recommended
```
```diff
@@ -1166,8 +1188,16 @@ def __call__(
                )
                image_reference_width = image_reference_width // multiple_of * multiple_of
                image_reference_height = image_reference_height // multiple_of * multiple_of
-                image_reference = self.image_processor.resize(image_reference, image_reference_height, image_reference_width)
-            image_reference = self.image_processor.preprocess(image_reference, image_reference_height, image_reference_width, crops_coords=crops_coords, resize_mode=resize_mode)
+                image_reference = self.image_processor.resize(
+                    image_reference, image_reference_height, image_reference_width
+                )
+            image_reference = self.image_processor.preprocess(
+                image_reference,
+                image_reference_height,
+                image_reference_width,
+                crops_coords=crops_coords,
+                resize_mode=resize_mode,
+            )
        else:
            image_reference = None

```
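The `_auto_resize` branch entered just above this hunk snaps the reference image to the trained Kontext resolution whose aspect ratio is closest to the input's. A minimal sketch of that selection, assuming a short illustrative list of `(width, height)` pairs rather than the pipeline's full `PREFERRED_KONTEXT_RESOLUTIONS`:

```python
# Pick the candidate resolution with the nearest aspect ratio (illustrative subset).
preferred = [(672, 1568), (880, 1184), (1024, 1024), (1184, 880), (1568, 672)]
aspect_ratio = 1280 / 720  # hypothetical reference image, ~1.78
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in preferred)
print(width, height)  # 1184 880, the closest trained aspect ratio in this subset
```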
```diff
@@ -1248,18 +1278,20 @@ def __call__(

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
-        latents, image_latents, image_reference_latents, latent_ids, image_ids, image_reference_ids, noise = self.prepare_latents(
-            init_image,
-            latent_timestep,
-            batch_size * num_images_per_prompt,
-            num_channels_latents,
-            height,
-            width,
-            prompt_embeds.dtype,
-            device,
-            generator,
-            latents,
-            image_reference,
+        latents, image_latents, image_reference_latents, latent_ids, image_ids, image_reference_ids, noise = (
+            self.prepare_latents(
+                init_image,
+                latent_timestep,
+                batch_size * num_images_per_prompt,
+                num_channels_latents,
+                height,
+                width,
+                prompt_embeds.dtype,
+                device,
+                generator,
+                latents,
+                image_reference,
+            )
        )

        if image_reference_ids is not None:
```

src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -706,6 +706,7 @@ def from_config(cls, *args, **kwargs):
     def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])

+
 class FluxKontextInpaintPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]

```
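For context, the dummy module mirrors every pipeline class so that installations without the required backends fail with a clear error at use time rather than at import time; the added line here is purely PEP 8 blank-line spacing before the new class. A simplified sketch of the pattern, with stand-ins for `DummyObject` and `requires_backends` (not diffusers' exact implementations):

```python
def requires_backends(obj, backends):
    # Stand-in: diffusers' version raises a detailed ImportError per missing backend.
    name = getattr(obj, "__name__", str(obj))
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")

class DummyObject(type):
    # Metaclass: instantiating any dummy class reports its missing backends.
    def __call__(cls, *args, **kwargs):
        requires_backends(cls, cls._backends)

class FluxKontextInpaintPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

# FluxKontextInpaintPipeline()  # raises ImportError naming torch and transformers
```
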
tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py

Lines changed: 10 additions & 2 deletions
```diff
@@ -161,11 +161,19 @@ def test_flux_image_output_shape(self):
        for height, width in height_width_pairs:
            expected_height = height - height % (pipe.vae_scale_factor * 2)
            expected_width = width - width % (pipe.vae_scale_factor * 2)
-            #Because output shape is the same as the input shape, we need to create a dummy image and mask image
+            # Because output shape is the same as the input shape, we need to create a dummy image and mask image
            image = floats_tensor((1, 3, height, width), rng=random.Random(0)).to(torch_device)
            mask_image = torch.ones((1, 1, height, width)).to(torch_device)

-            inputs.update({"height": height, "width": width, "max_area": height * width, "image": image, "mask_image": mask_image})
+            inputs.update(
+                {
+                    "height": height,
+                    "width": width,
+                    "max_area": height * width,
+                    "image": image,
+                    "mask_image": mask_image,
+                }
+            )
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            assert (output_height, output_width) == (expected_height, expected_width)
```
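The test's assertion relies on the same snap-to-multiple arithmetic as the pipeline: `height - height % (vae_scale_factor * 2)` is the requested height rounded down to a multiple of `vae_scale_factor * 2`. A worked example with an illustrative scale factor (the test reads the real one from the pipeline):

```python
vae_scale_factor = 8  # illustrative value
height = 100
expected_height = height - height % (vae_scale_factor * 2)
print(expected_height)  # 96, the largest multiple of 16 <= 100
```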
