Skip to content

Commit e8a3eec

Browse files
committed
add compile + offload tests for GGUF.
1 parent 05e8677 commit e8a3eec

File tree

1 file changed

+30
-1
lines changed

tests/quantization/gguf/test_gguf.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
HiDreamImageTransformer2DModel,
1616
SD3Transformer2DModel,
1717
StableDiffusion3Pipeline,
18+
DiffusionPipeline,
1819
)
1920
from diffusers.utils import load_image
2021
from diffusers.utils.testing_utils import (
@@ -28,11 +29,12 @@
2829
numpy_cosine_similarity_distance,
2930
require_accelerate,
3031
require_big_accelerator,
32+
require_torch_version_greater,
3133
require_gguf_version_greater_or_equal,
3234
require_peft_backend,
3335
torch_device,
3436
)
35-
37+
from ..test_torch_compile_utils import QuantCompileTests
3638

3739
if is_gguf_available():
3840
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
@@ -577,3 +579,30 @@ def get_dummy_inputs(self):
577579
).to(torch_device, self.torch_dtype),
578580
"timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype),
579581
}
582+
583+
584+
@require_torch_version_greater("2.7.1")
class GGUFCompileTests(QuantCompileTests):
    """Smoke tests for torch.compile combined with offloading on GGUF-quantized models.

    Builds an SD3.5-medium pipeline whose transformer is loaded from a single-file
    GGUF checkpoint, then reuses the shared ``QuantCompileTests`` harness for the
    actual compile / CPU-offload / group-offload checks.
    """

    torch_dtype = torch.bfloat16
    quantization_config = GGUFQuantizationConfig(compute_dtype=torch_dtype)
    # Q3_K_M single-file GGUF checkpoint hosted on the Hub.
    gguf_ckpt = "https://huggingface.co/city96/stable-diffusion-3.5-medium-gguf/blob/main/sd3.5_medium-Q3_K_M.gguf"

    def _init_pipeline(self, *args, **kwargs):
        """Return an SD3.5-medium pipeline with a GGUF-quantized transformer.

        Extra positional/keyword arguments are accepted for harness compatibility
        but intentionally unused here.
        """
        quantized_transformer = SD3Transformer2DModel.from_single_file(
            self.gguf_ckpt,
            quantization_config=self.quantization_config,
            torch_dtype=self.torch_dtype,
        )
        return DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-3.5-medium",
            transformer=quantized_transformer,
            torch_dtype=self.torch_dtype,
        )

    def test_torch_compile(self):
        # Delegate to the shared harness, passing the GGUF quantization config.
        super()._test_torch_compile(quantization_config=self.quantization_config)

    def test_torch_compile_with_cpu_offload(self):
        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

    def test_torch_compile_with_group_offload(self):
        super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)

0 commit comments

Comments (0)