Skip to content

Commit 6fe2414

Browse files
committed
start adding compilation tests for quantization.
1 parent 16c955c commit 6fe2414

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

tests/quantization/bnb/test_4bit.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
DiffusionPipeline,
2929
FluxControlPipeline,
3030
FluxTransformer2DModel,
31+
PipelineQuantizationConfig,
3132
SD3Transformer2DModel,
3233
)
3334
from diffusers.utils import is_accelerate_version, logging
@@ -44,6 +45,8 @@
4445
require_peft_backend,
4546
require_torch,
4647
require_torch_accelerator,
48+
require_torch_gpu,
49+
require_torch_version_greater_equal,
4750
require_transformers_version_greater,
4851
slow,
4952
torch_device,
@@ -855,3 +858,34 @@ def test_fp4_double_unsafe(self):
855858

856859
def test_fp4_double_safe(self):
    """Round-trip fp4 double-quantized weights through safetensors serialization."""
    serialization_kwargs = dict(quant_type="fp4", double_quant=True, safe_serialization=True)
    self.test_serialization(**serialization_kwargs)
861+
862+
863+
@require_torch_gpu
@slow
class Bnb4BitCompileTests(unittest.TestCase):
    """Smoke tests for running ``torch.compile`` over a bitsandbytes 4-bit quantized pipeline.

    Loads a tiny Flux pipeline with its transformer quantized to 4-bit,
    compiles the transformer with ``fullgraph=True`` (so any graph break
    fails the test), and runs inference twice: the first call triggers
    compilation, the second exercises the cached compiled graph.
    """

    def setUp(self):
        # Start from a clean accelerator and compiler state so caches left
        # over from earlier tests cannot mask compilation issues here.
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    def tearDown(self):
        # Mirror setUp: leave no compiled artifacts or allocator residue
        # behind for subsequent tests.
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    @require_torch_version_greater_equal("2.8")
    def test_torch_compile_4bit(self):
        # Quantize only the transformer; the other pipeline components stay
        # in plain bfloat16.
        quantization_config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={"load_in_4bit": True},
            components_to_quantize=["transformer"],
        )
        # Use torch_device (rather than a hard-coded "cuda") for consistency
        # with setUp/tearDown; @require_torch_gpu guarantees it is CUDA here.
        pipe = DiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe", quantization_config=quantization_config, torch_dtype=torch.bfloat16
        ).to(torch_device)
        # fullgraph=True makes any graph break in the quantized transformer
        # raise instead of silently falling back to eager.
        pipe.transformer.compile(fullgraph=True)

        # First iteration compiles; second runs the cached compiled graph.
        for _ in range(2):
            pipe("a dog", num_inference_steps=4, max_sequence_length=16)

0 commit comments

Comments
 (0)