Skip to content

Commit 17b678f

Browse files
authored
Merge branch 'main' into to-single-file/flux
2 parents b8f7fe6 + 06fd427 commit 17b678f

File tree

3 files changed

+40
-84
lines changed

3 files changed

+40
-84
lines changed

examples/controlnet/train_controlnet_sd3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1330,7 +1330,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
13301330
# controlnet(s) inference
13311331
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
13321332
controlnet_image = vae.encode(controlnet_image).latent_dist.sample()
1333-
controlnet_image = controlnet_image * vae.config.scaling_factor
1333+
controlnet_image = (controlnet_image - vae.config.shift_factor) * vae.config.scaling_factor
13341334

13351335
control_block_res_samples = controlnet(
13361336
hidden_states=noisy_model_input,

examples/server/requirements.txt

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# This file was autogenerated by uv via the following command:
22
# uv pip compile requirements.in -o requirements.txt
3-
aiohappyeyeballs==2.4.3
3+
aiohappyeyeballs==2.6.1
44
# via aiohttp
5-
aiohttp==3.10.10
5+
aiohttp==3.12.14
66
# via -r requirements.in
7-
aiosignal==1.3.1
7+
aiosignal==1.4.0
88
# via aiohttp
99
annotated-types==0.7.0
1010
# via pydantic
@@ -29,7 +29,6 @@ filelock==3.16.1
2929
# huggingface-hub
3030
# torch
3131
# transformers
32-
# triton
3332
frozenlist==1.5.0
3433
# via
3534
# aiohttp
@@ -111,7 +110,9 @@ prometheus-client==0.21.0
111110
prometheus-fastapi-instrumentator==7.0.0
112111
# via -r requirements.in
113112
propcache==0.2.0
114-
# via yarl
113+
# via
114+
# aiohttp
115+
# yarl
115116
py-consul==1.5.3
116117
# via -r requirements.in
117118
pydantic==2.9.2
@@ -155,7 +156,9 @@ triton==3.3.0
155156
# via torch
156157
typing-extensions==4.12.2
157158
# via
159+
# aiosignal
158160
# anyio
161+
# exceptiongroup
159162
# fastapi
160163
# huggingface-hub
161164
# multidict
@@ -168,5 +171,5 @@ urllib3==2.5.0
168171
# via requests
169172
uvicorn==0.32.0
170173
# via -r requirements.in
171-
yarl==1.16.0
174+
yarl==1.18.3
172175
# via aiohttp

tests/pipelines/flux/test_pipeline_flux.py

Lines changed: 30 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def test_flux_different_prompts(self):
154154

155155
# Outputs should be different here
156156
# For some reasons, they don't show large differences
157-
assert max_diff > 1e-6
157+
self.assertGreater(max_diff, 1e-6, "Outputs should be different for different prompts.")
158158

159159
def test_fused_qkv_projections(self):
160160
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -184,14 +184,17 @@ def test_fused_qkv_projections(self):
184184
image = pipe(**inputs).images
185185
image_slice_disabled = image[0, -3:, -3:, -1]
186186

187-
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
188-
"Fusion of QKV projections shouldn't affect the outputs."
187+
self.assertTrue(
188+
np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
189+
("Fusion of QKV projections shouldn't affect the outputs."),
189190
)
190-
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
191-
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
191+
self.assertTrue(
192+
np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
193+
("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
192194
)
193-
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
194-
"Original outputs should match when fused QKV projections are disabled."
195+
self.assertTrue(
196+
np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
197+
("Original outputs should match when fused QKV projections are disabled."),
195198
)
196199

197200
def test_flux_image_output_shape(self):
@@ -206,7 +209,11 @@ def test_flux_image_output_shape(self):
206209
inputs.update({"height": height, "width": width})
207210
image = pipe(**inputs).images[0]
208211
output_height, output_width, _ = image.shape
209-
assert (output_height, output_width) == (expected_height, expected_width)
212+
self.assertEqual(
213+
(output_height, output_width),
214+
(expected_height, expected_width),
215+
f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
216+
)
210217

211218
def test_flux_true_cfg(self):
212219
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
@@ -217,7 +224,9 @@ def test_flux_true_cfg(self):
217224
inputs["negative_prompt"] = "bad quality"
218225
inputs["true_cfg_scale"] = 2.0
219226
true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
220-
assert not np.allclose(no_true_cfg_out, true_cfg_out)
227+
self.assertFalse(
228+
np.allclose(no_true_cfg_out, true_cfg_out), "Outputs should be different when true_cfg_scale is set."
229+
)
221230

222231

223232
@nightly
@@ -266,45 +275,17 @@ def test_flux_inference(self):
266275

267276
image = pipe(**inputs).images[0]
268277
image_slice = image[0, :10, :10]
278+
# fmt: off
269279
expected_slice = np.array(
270-
[
271-
0.3242,
272-
0.3203,
273-
0.3164,
274-
0.3164,
275-
0.3125,
276-
0.3125,
277-
0.3281,
278-
0.3242,
279-
0.3203,
280-
0.3301,
281-
0.3262,
282-
0.3242,
283-
0.3281,
284-
0.3242,
285-
0.3203,
286-
0.3262,
287-
0.3262,
288-
0.3164,
289-
0.3262,
290-
0.3281,
291-
0.3184,
292-
0.3281,
293-
0.3281,
294-
0.3203,
295-
0.3281,
296-
0.3281,
297-
0.3164,
298-
0.3320,
299-
0.3320,
300-
0.3203,
301-
],
280+
[0.3242, 0.3203, 0.3164, 0.3164, 0.3125, 0.3125, 0.3281, 0.3242, 0.3203, 0.3301, 0.3262, 0.3242, 0.3281, 0.3242, 0.3203, 0.3262, 0.3262, 0.3164, 0.3262, 0.3281, 0.3184, 0.3281, 0.3281, 0.3203, 0.3281, 0.3281, 0.3164, 0.3320, 0.3320, 0.3203],
302281
dtype=np.float32,
303282
)
283+
# fmt: on
304284

305285
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
306-
307-
assert max_diff < 1e-4
286+
self.assertLess(
287+
max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}"
288+
)
308289

309290

310291
@slow
@@ -374,42 +355,14 @@ def test_flux_ip_adapter_inference(self):
374355
image = pipe(**inputs).images[0]
375356
image_slice = image[0, :10, :10]
376357

358+
# fmt: off
377359
expected_slice = np.array(
378-
[
379-
0.1855,
380-
0.1680,
381-
0.1406,
382-
0.1953,
383-
0.1699,
384-
0.1465,
385-
0.2012,
386-
0.1738,
387-
0.1484,
388-
0.2051,
389-
0.1797,
390-
0.1523,
391-
0.2012,
392-
0.1719,
393-
0.1445,
394-
0.2070,
395-
0.1777,
396-
0.1465,
397-
0.2090,
398-
0.1836,
399-
0.1484,
400-
0.2129,
401-
0.1875,
402-
0.1523,
403-
0.2090,
404-
0.1816,
405-
0.1484,
406-
0.2110,
407-
0.1836,
408-
0.1543,
409-
],
360+
[0.1855, 0.1680, 0.1406, 0.1953, 0.1699, 0.1465, 0.2012, 0.1738, 0.1484, 0.2051, 0.1797, 0.1523, 0.2012, 0.1719, 0.1445, 0.2070, 0.1777, 0.1465, 0.2090, 0.1836, 0.1484, 0.2129, 0.1875, 0.1523, 0.2090, 0.1816, 0.1484, 0.2110, 0.1836, 0.1543],
410361
dtype=np.float32,
411362
)
363+
# fmt: on
412364

413365
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
414-
415-
assert max_diff < 1e-4, f"{image_slice} != {expected_slice}"
366+
self.assertLess(
367+
max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}"
368+
)

0 commit comments

Comments
 (0)