|
@@ -1,19 +1,18 @@
-import tempfile
 import gc
+import tempfile
 import unittest

-import torch
-
-from diffusers import QuantoConfig, FluxTransformer2DModel, FluxPipeline
-from diffusers.utils import is_torch_available, is_optimum_quanto_available
+from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
+from diffusers.models.attention_processor import Attention
+from diffusers.utils import is_optimum_quanto_available, is_torch_available
 from diffusers.utils.testing_utils import (
     nightly,
     numpy_cosine_similarity_distance,
     require_accelerate,
     require_big_gpu_with_torch_cuda,
     torch_device,
 )
-from diffusers.models.attention_processor import Attention
+

 if is_optimum_quanto_available():
     from optimum.quanto import QLinear
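The hunk above drops the unconditional `import torch` while keeping `is_torch_available` among the imports, so the module presumably guards the torch import behind that check, following the usual diffusers test-module pattern (a sketch of the assumed guard, which is not shown in this hunk):

    # assumed to exist elsewhere in the module, per the common diffusers convention
    if is_torch_available():
        import torch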
@@ -192,7 +191,11 @@ def test_torch_compile(self):
         with torch.no_grad():
             compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample

-        assert torch.allclose(model_output, compiled_model_output, rtol=1e-2, atol=1e-3)
+        model_output = model_output.detach().float().cpu().numpy()
+        compiled_model_output = compiled_model_output.detach().float().cpu().numpy()
+
+        max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten())
+        assert max_diff < 1e-3


 class FluxTransformerQuantoMixin(QuantoBaseTesterMixin):
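This hunk swaps the elementwise `torch.allclose` check for a cosine-distance check, which tolerates the per-element numeric drift that `torch.compile` can introduce as long as the two outputs remain directionally aligned. For reference, `numpy_cosine_similarity_distance` from `diffusers.utils.testing_utils` computes roughly the following (a from-memory sketch, not the verbatim implementation):

    import numpy as np

    def numpy_cosine_similarity_distance(a, b):
        # 1 - cosine similarity of the two flattened arrays;
        # 0.0 means the outputs point in exactly the same direction.
        similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
        return 1.0 - similarity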
@@ -275,7 +278,7 @@ def test_model_cpu_offload(self):
             "hf-internal-testing/tiny-flux-pipe", transformer=transformer, torch_dtype=torch.bfloat16
         )
         pipe.enable_model_cpu_offload(device=torch_device)
-        images = pipe("a cat holding a sign that says hello", num_inference_steps=2)
+        _ = pipe("a cat holding a sign that says hello", num_inference_steps=2)

     def test_training(self):
         quantization_config = QuantoConfig(**self.get_dummy_init_kwargs())
@@ -311,7 +314,6 @@ def test_training(self):

 class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
     expected_memory_reduction = 0.3
-    _test_torch_compile = True

     def get_dummy_init_kwargs(self):
         return {"weights_dtype": "float8"}
@@ -341,7 +343,6 @@ def get_dummy_init_kwargs(self):

 class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
     expected_memory_reduction = 0.55
-    _test_torch_compile = True

     def get_dummy_init_kwargs(self):
         return {"weights_dtype": "int4"}
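Both weight-dtype subclasses drop their `_test_torch_compile = True` overrides, which suggests the per-subclass opt-in is no longer needed. The usual shape of such a flag is a class-level gate on the shared mixin, so the overrides become redundant once the default covers these configurations (a hypothetical sketch of that gating pattern, not the repository's exact code):

    class QuantoBaseTesterMixin:
        # Class-level gate; subclasses formerly opted in with `_test_torch_compile = True`.
        _test_torch_compile = False

        def test_torch_compile(self):
            if not self._test_torch_compile:
                self.skipTest("torch.compile path not enabled for this quantization config")
            # ... compile the model and compare outputs, as in the hunk above ...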