
Commit 36159dd

add tests
1 parent 01e521a commit 36159dd

File tree

2 files changed: +187 −6 lines changed


src/diffusers/pipelines/flux/pipeline_flux_kontext.py

Lines changed: 10 additions & 6 deletions
@@ -740,6 +740,7 @@ def __call__(
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
         max_area: int = 1024**2,
+        _auto_resize: bool = True,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -937,13 +938,16 @@ def __call__(
 
         # 3. Preprocess image
         if not torch.is_tensor(image) or image.size(1) == self.latent_channels:
-            image_width, image_height = self.image_processor.get_default_height_width(image)
+            if isinstance(image, list):
+                image_width, image_height = self.image_processor.get_default_height_width(image[0])
+            else:
+                image_width, image_height = self.image_processor.get_default_height_width(image)
             aspect_ratio = image_width / image_height
-
-            # Kontext is trained on specific resolutions, using one of them is recommended
-            _, image_width, image_height = min(
-                (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
-            )
+            if _auto_resize:
+                # Kontext is trained on specific resolutions, using one of them is recommended
+                _, image_width, image_height = min(
+                    (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+                )
             image_width = image_width // multiple_of * multiple_of
             image_height = image_height // multiple_of * multiple_of
             image = self.image_processor.resize(image, image_height, image_width)
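For reference, the new _auto_resize flag gates the snap-to-preferred-resolution step: left at its default of True, the input image's aspect ratio is matched against PREFERRED_KONTEXT_RESOLUTIONS and the closest entry is used; set to False, only the rounding to multiple_of is applied. A minimal usage sketch follows; the checkpoint id, image URL, and prompt are illustrative placeholders, not part of this commit.

import torch

from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

# Placeholder checkpoint id; substitute the Flux Kontext checkpoint you actually use.
pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Placeholder input image.
image = load_image("https://example.com/input.png")

# Default behavior: the image is resized to the closest entry in PREFERRED_KONTEXT_RESOLUTIONS.
snapped = pipe(image=image, prompt="Add snow to the scene").images[0]

# With the new flag disabled, the preferred-resolution snapping is skipped and only the
# rounding to the pipeline's size multiple is applied, so the input size is followed more closely.
unsnapped = pipe(image=image, prompt="Add snow to the scene", _auto_resize=False).images[0]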
Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
+import unittest
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+from diffusers import (
+    AutoencoderKL,
+    FasterCacheConfig,
+    FlowMatchEulerDiscreteScheduler,
+    FluxKontextPipeline,
+    FluxTransformer2DModel,
+)
+from diffusers.utils.testing_utils import torch_device
+
+from ..test_pipelines_common import (
+    FasterCacheTesterMixin,
+    FluxIPAdapterTesterMixin,
+    PipelineTesterMixin,
+    PyramidAttentionBroadcastTesterMixin,
+)
+
+
+class FluxKontextPipelineFastTests(
+    unittest.TestCase,
+    PipelineTesterMixin,
+    FluxIPAdapterTesterMixin,
+    PyramidAttentionBroadcastTesterMixin,
+    FasterCacheTesterMixin,
+):
+    pipeline_class = FluxKontextPipeline
+    params = frozenset(
+        ["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
+    )
+    batch_params = frozenset(["image", "prompt"])
+
+    # there is no xformers processor for Flux
+    test_xformers_attention = False
+    test_layerwise_casting = True
+    test_group_offloading = True
+
+    faster_cache_config = FasterCacheConfig(
+        spatial_attention_block_skip_range=2,
+        spatial_attention_timestep_skip_range=(-1, 901),
+        unconditional_batch_skip_range=2,
+        attention_weight_callback=lambda _: 0.5,
+        is_guidance_distilled=True,
+    )
+
+    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
+        torch.manual_seed(0)
+        transformer = FluxTransformer2DModel(
+            patch_size=1,
+            in_channels=4,
+            num_layers=num_layers,
+            num_single_layers=num_single_layers,
+            attention_head_dim=16,
+            num_attention_heads=2,
+            joint_attention_dim=32,
+            pooled_projection_dim=32,
+            axes_dims_rope=[4, 4, 8],
+        )
+        clip_text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+            hidden_act="gelu",
+            projection_dim=32,
+        )
+
+        torch.manual_seed(0)
+        text_encoder = CLIPTextModel(clip_text_encoder_config)
+
+        torch.manual_seed(0)
+        text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            sample_size=32,
+            in_channels=3,
+            out_channels=3,
+            block_out_channels=(4,),
+            layers_per_block=1,
+            latent_channels=1,
+            norm_num_groups=1,
+            use_quant_conv=False,
+            use_post_quant_conv=False,
+            shift_factor=0.0609,
+            scaling_factor=1.5035,
+        )
+
+        scheduler = FlowMatchEulerDiscreteScheduler()
+
+        return {
+            "scheduler": scheduler,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+            "transformer": transformer,
+            "vae": vae,
+            "image_encoder": None,
+            "feature_extractor": None,
+        }
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device="cpu").manual_seed(seed)
+
+        image = PIL.Image.new("RGB", (32, 32), 0)
+        inputs = {
+            "image": image,
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 5.0,
+            "height": 8,
+            "width": 8,
+            "max_area": 8 * 8,
+            "max_sequence_length": 48,
+            "output_type": "np",
+            "_auto_resize": False,
+        }
+        return inputs
+
+    def test_flux_different_prompts(self):
+        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_same_prompt = pipe(**inputs).images[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["prompt_2"] = "a different prompt"
+        output_different_prompts = pipe(**inputs).images[0]
+
+        max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+        # Outputs should be different here
+        # For some reason, they don't show large differences
+        assert max_diff > 1e-6
+
+    def test_flux_image_output_shape(self):
+        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+        inputs = self.get_dummy_inputs(torch_device)
+
+        height_width_pairs = [(32, 32), (72, 57)]
+        for height, width in height_width_pairs:
+            expected_height = height - height % (pipe.vae_scale_factor * 2)
+            expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+            inputs.update({"height": height, "width": width, "max_area": height * width})
+            image = pipe(**inputs).images[0]
+            output_height, output_width, _ = image.shape
+            assert (output_height, output_width) == (expected_height, expected_width)
+
+    def test_flux_true_cfg(self):
+        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs.pop("generator")
+
+        no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+        inputs["negative_prompt"] = "bad quality"
+        inputs["true_cfg_scale"] = 2.0
+        true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+        assert not np.allclose(no_true_cfg_out, true_cfg_out)
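The new test module's path is not shown in this view. Assuming it sits alongside the other fast pipeline tests (in diffusers that would typically be under tests/pipelines/flux/), a quick way to run just this class with the standard unittest runner is sketched below; the import path is an assumption and should be adjusted to the file's actual location.

import unittest

# Assumed module name for the new file added in this commit; adjust the import
# to wherever the test file actually lives in the repository.
from test_pipeline_flux_kontext import FluxKontextPipelineFastTests

# Load and run only the FluxKontextPipelineFastTests class.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(FluxKontextPipelineFastTests)
unittest.TextTestRunner(verbosity=2).run(suite)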
