|  | 
| 15 | 15 |     HiDreamImageTransformer2DModel, | 
| 16 | 16 |     SD3Transformer2DModel, | 
| 17 | 17 |     StableDiffusion3Pipeline, | 
|  | 18 | +    DiffusionPipeline, | 
| 18 | 19 | ) | 
| 19 | 20 | from diffusers.utils import load_image | 
| 20 | 21 | from diffusers.utils.testing_utils import ( | 
|  | 
| 28 | 29 |     numpy_cosine_similarity_distance, | 
| 29 | 30 |     require_accelerate, | 
| 30 | 31 |     require_big_accelerator, | 
|  | 32 | +    require_torch_version_greater, | 
| 31 | 33 |     require_gguf_version_greater_or_equal, | 
| 32 | 34 |     require_peft_backend, | 
| 33 | 35 |     torch_device, | 
| 34 | 36 | ) | 
| 35 |  | - | 
|  | 37 | +from ..test_torch_compile_utils import QuantCompileTests | 
| 36 | 38 |  | 
| 37 | 39 | if is_gguf_available(): | 
| 38 | 40 |     from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter | 
| @@ -577,3 +579,30 @@ def get_dummy_inputs(self): | 
| 577 | 579 |             ).to(torch_device, self.torch_dtype), | 
| 578 | 580 |             "timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype), | 
| 579 | 581 |         } | 
|  | 582 | + | 
|  | 583 | + | 
|  | 584 | +@require_torch_version_greater("2.7.1") | 
|  | 585 | +class GGUFCompileTests(QuantCompileTests): | 
|  | 586 | +    torch_dtype = torch.bfloat16 | 
|  | 587 | +    quantization_config = GGUFQuantizationConfig(compute_dtype=torch_dtype) | 
|  | 588 | +    gguf_ckpt = "https://huggingface.co/city96/stable-diffusion-3.5-medium-gguf/blob/main/sd3.5_medium-Q3_K_M.gguf" | 
|  | 589 | + | 
|  | 590 | +    def _init_pipeline(self, *args, **kwargs): | 
|  | 591 | +        transformer = SD3Transformer2DModel.from_single_file( | 
|  | 592 | +            self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype | 
|  | 593 | +        ) | 
|  | 594 | +        pipe = DiffusionPipeline.from_pretrained( | 
|  | 595 | +            "stabilityai/stable-diffusion-3.5-medium", | 
|  | 596 | +            transformer=transformer, | 
|  | 597 | +            torch_dtype=self.torch_dtype | 
|  | 598 | +        ) | 
|  | 599 | +        return pipe | 
|  | 600 | + | 
|  | 601 | +    def test_torch_compile(self): | 
|  | 602 | +        super()._test_torch_compile(quantization_config=self.quantization_config) | 
|  | 603 | + | 
|  | 604 | +    def test_torch_compile_with_cpu_offload(self): | 
|  | 605 | +        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config) | 
|  | 606 | + | 
|  | 607 | +    def test_torch_compile_with_group_offload(self): | 
|  | 608 | +        super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config) | 
0 commit comments