@@ -656,7 +656,7 @@ def get_dummy_inputs(self):
 @require_torch_version_greater("2.7.1")
 class GGUFCompileTests(QuantCompileTests):
     torch_dtype = torch.bfloat16
-    gguf_ckpt = "https://huggingface.co/city96/stable-diffusion-3.5-medium-gguf/blob/main/sd3.5_medium-Q3_K_M.gguf"
+    gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"

     @property
     def quantization_config(self):
@@ -668,14 +668,14 @@ def test_torch_compile(self):
     def test_torch_compile_with_cpu_offload(self):
         super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

-    def test_torch_compile_with_group_offload(self):
-        super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)
+    def test_torch_compile_with_group_offload_leaf(self):
+        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)

     def _init_pipeline(self, *args, **kwargs):
-        transformer = SD3Transformer2DModel.from_single_file(
+        transformer = FluxTransformer2DModel.from_single_file(
            self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
        )
        pipe = DiffusionPipeline.from_pretrained(
-            "stabilityai/stable-diffusion-3.5-medium", transformer=transformer, torch_dtype=self.torch_dtype
+            "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=self.torch_dtype
        )
        return pipe
0 commit comments