Skip to content

Commit cc0a55a

Browse files
committed
prop.
1 parent 43e14d7 commit cc0a55a

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

tests/quantization/gguf/test_gguf.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -586,17 +586,11 @@ def get_dummy_inputs(self):
586586
@require_torch_version_greater("2.7.1")
587587
class GGUFCompileTests(QuantCompileTests):
588588
torch_dtype = torch.bfloat16
589-
quantization_config = GGUFQuantizationConfig(compute_dtype=torch_dtype)
590589
gguf_ckpt = "https://huggingface.co/city96/stable-diffusion-3.5-medium-gguf/blob/main/sd3.5_medium-Q3_K_M.gguf"
591590

592-
def _init_pipeline(self, *args, **kwargs):
593-
transformer = SD3Transformer2DModel.from_single_file(
594-
self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
595-
)
596-
pipe = DiffusionPipeline.from_pretrained(
597-
"stabilityai/stable-diffusion-3.5-medium", transformer=transformer, torch_dtype=self.torch_dtype
598-
)
599-
return pipe
591+
@property
592+
def quantization_config(self):
593+
return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
600594

601595
def test_torch_compile(self):
602596
super()._test_torch_compile(quantization_config=self.quantization_config)
@@ -606,3 +600,12 @@ def test_torch_compile_with_cpu_offload(self):
606600

607601
def test_torch_compile_with_group_offload(self):
608602
super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)
603+
604+
def _init_pipeline(self, *args, **kwargs):
605+
transformer = SD3Transformer2DModel.from_single_file(
606+
self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
607+
)
608+
pipe = DiffusionPipeline.from_pretrained(
609+
"stabilityai/stable-diffusion-3.5-medium", transformer=transformer, torch_dtype=self.torch_dtype
610+
)
611+
return pipe

0 commit comments

Comments
 (0)