@@ -586,17 +586,11 @@ def get_dummy_inputs(self):
 @require_torch_version_greater("2.7.1")
 class GGUFCompileTests(QuantCompileTests):
     torch_dtype = torch.bfloat16
-    quantization_config = GGUFQuantizationConfig(compute_dtype=torch_dtype)
     gguf_ckpt = "https://huggingface.co/city96/stable-diffusion-3.5-medium-gguf/blob/main/sd3.5_medium-Q3_K_M.gguf"
 
-    def _init_pipeline(self, *args, **kwargs):
-        transformer = SD3Transformer2DModel.from_single_file(
-            self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
-        )
-        pipe = DiffusionPipeline.from_pretrained(
-            "stabilityai/stable-diffusion-3.5-medium", transformer=transformer, torch_dtype=self.torch_dtype
-        )
-        return pipe
+    @property
+    def quantization_config(self):
+        return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
 
     def test_torch_compile(self):
         super()._test_torch_compile(quantization_config=self.quantization_config)
@@ -606,3 +600,12 @@ def test_torch_compile_with_cpu_offload(self):
 
     def test_torch_compile_with_group_offload(self):
         super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)
+
+    def _init_pipeline(self, *args, **kwargs):
+        transformer = SD3Transformer2DModel.from_single_file(
+            self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
+        )
+        pipe = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-3.5-medium", transformer=transformer, torch_dtype=self.torch_dtype
+        )
+        return pipe
0 commit comments