@@ -28,6 +28,7 @@
     DiffusionPipeline,
     FluxControlPipeline,
     FluxTransformer2DModel,
+    PipelineQuantizationConfig,
     SD3Transformer2DModel,
 )
 from diffusers.utils import is_accelerate_version, logging
@@ -44,6 +45,8 @@
     require_peft_backend,
     require_torch,
     require_torch_accelerator,
+    require_torch_gpu,
+    require_torch_version_greater_equal,
     require_transformers_version_greater,
     slow,
     torch_device,
@@ -855,3 +858,34 @@ def test_fp4_double_unsafe(self):
 
     def test_fp4_double_safe(self):
         self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)
+
+
+@require_torch_gpu
+@slow
+class Bnb4BitCompileTests(unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+        gc.collect()
+        backend_empty_cache(torch_device)
+        torch.compiler.reset()
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        backend_empty_cache(torch_device)
+        torch.compiler.reset()
+
+    @require_torch_version_greater_equal("2.8")
+    def test_torch_compile_4bit(self):
+        quantization_config = PipelineQuantizationConfig(
+            quant_backend="bitsandbytes_4bit",
+            quant_kwargs={"load_in_4bit": True},
+            components_to_quantize=["transformer"],
+        )
+        pipe = DiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-flux-pipe", quantization_config=quantization_config, torch_dtype=torch.bfloat16
+        ).to("cuda")
+        pipe.transformer.compile(fullgraph=True)
+
+        for _ in range(2):
+            pipe("a dog", num_inference_steps=4, max_sequence_length=16)
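
For reference, outside the test harness the new flow reduces to the following standalone script. This is a minimal sketch, not part of the diff: the checkpoint (black-forest-labs/FLUX.1-dev, a large download), the CUDA device, and the prompt are illustrative assumptions — the test above uses the tiny hf-internal-testing/tiny-flux-pipe fixture instead — and fullgraph compilation of the 4-bit transformer expects torch >= 2.8, per the decorator above.

import torch
from diffusers import DiffusionPipeline, PipelineQuantizationConfig

# Quantize only the transformer component to 4-bit via the bitsandbytes backend.
quantization_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True},
    components_to_quantize=["transformer"],
)
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumption: any Flux checkpoint works here
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
).to("cuda")

# Compile the quantized transformer in place; fullgraph=True disallows graph breaks.
pipe.transformer.compile(fullgraph=True)

# The first call triggers compilation; subsequent calls reuse the compiled graph.
image = pipe("a dog", num_inference_steps=4, max_sequence_length=16).images[0]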