12 changes: 12 additions & 0 deletions src/diffusers/utils/testing_utils.py
@@ -291,6 +291,18 @@ def decorator(test_case):
    return decorator


def require_torch_version_greater(torch_version):
    """Decorator marking a test that requires a torch version greater than the given version."""

    def decorator(test_case):
        correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
        return unittest.skipUnless(
            correct_torch_version, f"test requires a torch version greater than {torch_version}"
        )(test_case)

    return decorator
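
# Illustrative usage (hypothetical test, not part of this change): skip unless
# the installed torch is newer than the pinned release.
#
#   @require_torch_version_greater("2.7.1")
#   def test_needs_recent_torch(self):
#       ...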


def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""
return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
43 changes: 43 additions & 0 deletions tests/quantization/bnb/test_4bit.py
@@ -30,6 +30,7 @@
    FluxTransformer2DModel,
    SD3Transformer2DModel,
)
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import is_accelerate_version, logging
from diffusers.utils.testing_utils import (
    CaptureLogger,
@@ -44,6 +45,8 @@
    require_peft_backend,
    require_torch,
    require_torch_accelerator,
    require_torch_gpu,
    require_torch_version_greater,
    require_transformers_version_greater,
    slow,
    torch_device,
@@ -855,3 +858,43 @@ def test_fp4_double_unsafe(self):

    def test_fp4_double_safe(self):
        self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)


@require_torch_gpu
@slow
class Bnb4BitCompileTests(unittest.TestCase):
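    # Reset compiler and accelerator state around each test so compiled graphs
    # from one test cannot leak into the next.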
    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    @require_torch_version_greater("2.7.1")
    def test_torch_compile_4bit(self):
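        # bitsandbytes 4-bit dequantization can hit ops with data-dependent
        # output shapes; let dynamo capture them rather than graph-breaking,
        # which fullgraph=True below would turn into an error.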
        torch._dynamo.config.capture_dynamic_output_shape_ops = True

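        # Quantize only the transformer to 4-bit NF4, computing in bfloat16.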
        quantization_config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
            },
            components_to_quantize=["transformer"],
        )
        pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        ).to("cuda")
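        # fullgraph=True makes any graph break in the quantized transformer a
        # hard compilation error rather than a silent eager fallback.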
        pipe.transformer.compile(fullgraph=True)

        # The first call triggers compilation; the second should reuse the
        # compiled graph. Wrapping the call in
        # torch._dynamo.config.patch(error_on_recompile=True) would turn an
        # unexpected recompile into a hard failure.
        for _ in range(2):
            pipe("a dog", num_inference_steps=4, max_sequence_length=16, height=256, width=256)
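
# To run just this test (standard pytest selection; assumes a CUDA GPU and
# torch > 2.7.1):
#   pytest tests/quantization/bnb/test_4bit.py -k test_torch_compile_4bit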