Commit 29cca99

fixes

1 parent 6fe2414 commit 29cca99

2 files changed (+27, -6 lines)

src/diffusers/utils/testing_utils.py (12 additions, 0 deletions)

@@ -291,6 +291,18 @@ def decorator(test_case):
     return decorator


+def require_torch_version_greater(torch_version):
+    """Decorator marking a test that requires a torch version strictly greater than `torch_version`."""
+
+    def decorator(test_case):
+        correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
+        return unittest.skipUnless(
+            correct_torch_version, f"test requires torch version greater than {torch_version}"
+        )(test_case)
+
+    return decorator
+
+
 def require_torch_gpu(test_case):
     """Decorator marking a test that requires CUDA and PyTorch."""
     return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
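
For context, a minimal sketch of how the new decorator gates a test (the test class and body are illustrative, not part of this commit). Note that the comparison is strict: torch 2.7.1 itself is skipped, while 2.7.2 or 2.8.0 runs.

import unittest

from diffusers.utils.testing_utils import require_torch_version_greater


class ExampleCompileTests(unittest.TestCase):
    # Skipped unless the installed torch is strictly newer than 2.7.1;
    # importing torch inside the test keeps collection safe when torch is absent.
    @require_torch_version_greater("2.7.1")
    def test_compile_is_available(self):
        import torch

        self.assertTrue(callable(torch.compile))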

tests/quantization/bnb/test_4bit.py (15 additions, 6 deletions)

@@ -28,9 +28,9 @@
     DiffusionPipeline,
     FluxControlPipeline,
     FluxTransformer2DModel,
-    PipelineQuantizationConfig,
     SD3Transformer2DModel,
 )
+from diffusers.quantizers import PipelineQuantizationConfig
 from diffusers.utils import is_accelerate_version, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -46,7 +46,7 @@
     require_torch,
     require_torch_accelerator,
     require_torch_gpu,
-    require_torch_version_greater_equal,
+    require_torch_version_greater,
     require_transformers_version_greater,
     slow,
     torch_device,
@@ -875,17 +875,26 @@ def tearDown(self):
         backend_empty_cache(torch_device)
         torch.compiler.reset()

-    @require_torch_version_greater_equal("2.8")
+    @require_torch_version_greater("2.7.1")
     def test_torch_compile_4bit(self):
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
         quantization_config = PipelineQuantizationConfig(
             quant_backend="bitsandbytes_4bit",
-            quant_kwargs={"load_in_4bit": True},
+            quant_kwargs={
+                "load_in_4bit": True,
+                "bnb_4bit_quant_type": "nf4",
+                "bnb_4bit_compute_dtype": torch.bfloat16,
+            },
             components_to_quantize=["transformer"],
         )
         pipe = DiffusionPipeline.from_pretrained(
-            "hf-internal-testing/tiny-flux-pipe", quantization_config=quantization_config, torch_dtype=torch.bfloat16
+            "stabilityai/stable-diffusion-3-medium-diffusers",
+            quantization_config=quantization_config,
+            torch_dtype=torch.bfloat16,
         ).to("cuda")
         pipe.transformer.compile(fullgraph=True)

         for _ in range(2):
-            pipe("a dog", num_inference_steps=4, max_sequence_length=16)
+            # with torch._dynamo.config.patch(error_on_recompile=True):
+            pipe("a dog", num_inference_steps=4, max_sequence_length=16, height=256, width=256)
