Commit 73e4ebf

feat: Add test case and fix with pytest
1 parent 1f7939d commit 73e4ebf

File tree: 2 files changed (+186, -0)

src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py
Lines changed: 4 additions & 0 deletions

@@ -1120,6 +1120,8 @@ def __call__(
 
         # 2. Preprocess image
         if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+            if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
+                image = torch.cat(image, dim=0)
             img = image[0] if isinstance(image, list) else image
             image_height, image_width = self.image_processor.get_default_height_width(img)
             aspect_ratio = image_width / image_height
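
The two lines added above cover an input case the surrounding code did not: a list of 4D tensors. Previously such a list was forwarded to preprocessing un-concatenated, with only image[0] used to infer the default height and width; the new guard merges the list into a single batched tensor first. A minimal sketch of the guard in isolation (variable names are illustrative, not from the commit):

    import torch

    # Three single-image batches, each of shape (1, 3, 32, 32).
    images = [torch.randn(1, 3, 32, 32) for _ in range(3)]

    # Mirrors the added check: a list of 4D tensors is concatenated
    # along the batch dimension into one (3, 3, 32, 32) tensor.
    if isinstance(images, list) and isinstance(images[0], torch.Tensor) and images[0].ndim == 4:
        images = torch.cat(images, dim=0)

    print(images.shape)  # torch.Size([3, 3, 32, 32])
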
@@ -1152,6 +1154,8 @@ def __call__(
 
         #2.1 Preprocess image_reference
         if image_reference is not None and not (isinstance(image_reference, torch.Tensor) and image_reference.size(1) == self.latent_channels):
+            if isinstance(image_reference, list) and isinstance(image_reference[0], torch.Tensor) and image_reference[0].ndim == 4:
+                image_reference = torch.cat(image_reference, dim=0)
             img_reference = image_reference[0] if isinstance(image_reference, list) else image_reference
             image_reference_height, image_reference_width = self.image_processor.get_default_height_width(img_reference)
             aspect_ratio = image_reference_width / image_reference_height
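
The same guard is mirrored for image_reference, so both arguments accept a list of 4D tensors. A hypothetical end-to-end call follows; the checkpoint id, shapes, and argument values are assumptions for illustration, not part of this commit:

    import torch
    from diffusers import FluxKontextInpaintPipeline

    # Assumed checkpoint; any Flux Kontext checkpoint with matching
    # components should behave the same way here.
    pipe = FluxKontextInpaintPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Two single-image batches; after this fix the pipeline concatenates
    # them along the batch dimension instead of preprocessing the raw list.
    images = [torch.rand(1, 3, 512, 512) for _ in range(2)]
    mask = torch.ones(2, 1, 512, 512)

    result = pipe(
        prompt="restore the damaged area",
        image=images,
        mask_image=mask,
        num_inference_steps=28,
    )
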
New test file
Lines changed: 182 additions & 0 deletions

@@ -0,0 +1,182 @@
import random
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel

from diffusers import (
    AutoencoderKL,
    FasterCacheConfig,
    FlowMatchEulerDiscreteScheduler,
    FluxKontextInpaintPipeline,
    FluxTransformer2DModel,
)
from diffusers.utils.testing_utils import floats_tensor, torch_device

from ..test_pipelines_common import (
    FasterCacheTesterMixin,
    FluxIPAdapterTesterMixin,
    PipelineTesterMixin,
    PyramidAttentionBroadcastTesterMixin,
)


class FluxKontextInpaintPipelineFastTests(
    unittest.TestCase,
    PipelineTesterMixin,
    FluxIPAdapterTesterMixin,
    PyramidAttentionBroadcastTesterMixin,
    FasterCacheTesterMixin,
):
    pipeline_class = FluxKontextInpaintPipeline
    params = frozenset(
        ["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
    )
    batch_params = frozenset(["image", "prompt"])

    # there is no xformers processor for Flux
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    faster_cache_config = FasterCacheConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(-1, 901),
        unconditional_batch_skip_range=2,
        attention_weight_callback=lambda _: 0.5,
        is_guidance_distilled=True,
    )

    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        torch.manual_seed(0)
        transformer = FluxTransformer2DModel(
            patch_size=1,
            in_channels=4,
            num_layers=num_layers,
            num_single_layers=num_single_layers,
            attention_head_dim=16,
            num_attention_heads=2,
            joint_attention_dim=32,
            pooled_projection_dim=32,
            axes_dims_rope=[4, 4, 8],
        )
        clip_text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            hidden_act="gelu",
            projection_dim=32,
        )

        torch.manual_seed(0)
        text_encoder = CLIPTextModel(clip_text_encoder_config)

        torch.manual_seed(0)
        text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")

        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        vae = AutoencoderKL(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=1,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
            shift_factor=0.0609,
            scaling_factor=1.5035,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "text_encoder_2": text_encoder_2,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
            "transformer": transformer,
            "vae": vae,
            "image_encoder": None,
            "feature_extractor": None,
        }

    def get_dummy_inputs(self, device, seed=0):
        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        mask_image = torch.ones((1, 1, 32, 32)).to(device)
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "image": image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "height": 32,
            "width": 32,
            "max_sequence_length": 48,
            "strength": 0.8,
            "output_type": "np",
            "_auto_resize": False,
        }
        return inputs

    def test_flux_inpaint_different_prompts(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)

        inputs = self.get_dummy_inputs(torch_device)
        output_same_prompt = pipe(**inputs).images[0]

        inputs = self.get_dummy_inputs(torch_device)
        inputs["prompt_2"] = "a different prompt"
        output_different_prompts = pipe(**inputs).images[0]

        max_diff = np.abs(output_same_prompt - output_different_prompts).max()

        # The outputs should differ. With these tiny dummy models the difference
        # stays small, so only a loose lower bound is asserted.
        assert max_diff > 1e-6

    def test_flux_image_output_shape(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        height_width_pairs = [(32, 32), (72, 56)]
        for height, width in height_width_pairs:
            expected_height = height - height % (pipe.vae_scale_factor * 2)
            expected_width = width - width % (pipe.vae_scale_factor * 2)
            # Because the output shape follows the input image shape, a fresh dummy
            # image and mask image have to be created for every (height, width) pair.
            image = floats_tensor((1, 3, height, width), rng=random.Random(0)).to(torch_device)
            mask_image = torch.ones((1, 1, height, width)).to(torch_device)

            inputs.update(
                {
                    "height": height,
                    "width": width,
                    "max_area": height * width,
                    "image": image,
                    "mask_image": mask_image,
                }
            )
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            assert (output_height, output_width) == (expected_height, expected_width)

    def test_flux_true_cfg(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)
        inputs.pop("generator")

        no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
        inputs["negative_prompt"] = "bad quality"
        inputs["true_cfg_scale"] = 2.0
        true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
        assert not np.allclose(no_true_cfg_out, true_cfg_out)
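
The commit title mentions pytest; one way to run only the new suite locally is sketched below. The test file path is an assumption here, since the commit view does not show it:

    import pytest

    # -k selects the test class added in this commit by name.
    pytest.main(["-q", "tests/pipelines/flux/", "-k", "FluxKontextInpaintPipelineFastTests"])
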
