@@ -577,3 +577,90 @@ def get_dummy_inputs(self):
577577 ).to (torch_device , self .torch_dtype ),
578578 "timesteps" : torch .tensor ([1 ]).to (torch_device , self .torch_dtype ),
579579 }
580+
581+
582+ class WanGGUFTexttoVideoSingleFileTests (GGUFSingleFileTesterMixin , unittest .TestCase ):
583+ ckpt_path = "https://huggingface.co/city96/Wan2.1-T2V-14B-gguf/blob/main/wan2.1-t2v-14b-Q3_K_S.gguf"
584+ torch_dtype = torch .bfloat16
585+ model_cls = WanTransformer3DModel
586+ expected_memory_use_in_gb = 9
587+
588+ def get_dummy_inputs (self ):
589+ return {
590+ "hidden_states" : torch .randn ((1 , 16 , 2 , 64 , 64 ), generator = torch .Generator ("cpu" ).manual_seed (0 )).to (
591+ torch_device , self .torch_dtype
592+ ),
593+ "encoder_hidden_states" : torch .randn (
594+ (1 , 512 , 4096 ),
595+ generator = torch .Generator ("cpu" ).manual_seed (0 ),
596+ ).to (torch_device , self .torch_dtype ),
597+ "timestep" : torch .tensor ([1 ]).to (torch_device , self .torch_dtype ),
598+ }
599+
600+
601+ class WanGGUFImagetoVideoSingleFileTests (GGUFSingleFileTesterMixin , unittest .TestCase ):
602+ ckpt_path = "https://huggingface.co/city96/Wan2.1-I2V-14B-480P-gguf/blob/main/wan2.1-i2v-14b-480p-Q3_K_S.gguf"
603+ torch_dtype = torch .bfloat16
604+ model_cls = WanTransformer3DModel
605+ expected_memory_use_in_gb = 9
606+
607+ def get_dummy_inputs (self ):
608+ return {
609+ "hidden_states" : torch .randn ((1 , 36 , 2 , 64 , 64 ), generator = torch .Generator ("cpu" ).manual_seed (0 )).to (
610+ torch_device , self .torch_dtype
611+ ),
612+ "encoder_hidden_states" : torch .randn (
613+ (1 , 512 , 4096 ),
614+ generator = torch .Generator ("cpu" ).manual_seed (0 ),
615+ ).to (torch_device , self .torch_dtype ),
616+ "encoder_hidden_states_image" : torch .randn (
617+ (1 , 257 , 1280 ), generator = torch .Generator ("cpu" ).manual_seed (0 )
618+ ).to (torch_device , self .torch_dtype ),
619+ "timestep" : torch .tensor ([1 ]).to (torch_device , self .torch_dtype ),
620+ }
621+
622+
623+ class WanVACEGGUFSingleFileTests (GGUFSingleFileTesterMixin , unittest .TestCase ):
624+ ckpt_path = "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
625+ torch_dtype = torch .bfloat16
626+ model_cls = WanVACETransformer3DModel
627+ expected_memory_use_in_gb = 9
628+
629+ def get_dummy_inputs (self ):
630+ return {
631+ "hidden_states" : torch .randn ((1 , 16 , 2 , 64 , 64 ), generator = torch .Generator ("cpu" ).manual_seed (0 )).to (
632+ torch_device , self .torch_dtype
633+ ),
634+ "encoder_hidden_states" : torch .randn (
635+ (1 , 512 , 4096 ),
636+ generator = torch .Generator ("cpu" ).manual_seed (0 ),
637+ ).to (torch_device , self .torch_dtype ),
638+ "control_hidden_states" : torch .randn (
639+ (1 , 96 , 2 , 64 , 64 ),
640+ generator = torch .Generator ("cpu" ).manual_seed (0 ),
641+ ).to (torch_device , self .torch_dtype ),
642+ "control_hidden_states_scale" : torch .randn (
643+ (8 ,),
644+ generator = torch .Generator ("cpu" ).manual_seed (0 ),
645+ ).to (torch_device , self .torch_dtype ),
646+ "timestep" : torch .tensor ([1 ]).to (torch_device , self .torch_dtype ),
647+ }
648+
649+
650+ @require_torch_version_greater ("2.7.1" )
651+ class GGUFCompileTests (QuantCompileTests , unittest .TestCase ):
652+ torch_dtype = torch .bfloat16
653+ gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
654+
655+ @property
656+ def quantization_config (self ):
657+ return GGUFQuantizationConfig (compute_dtype = self .torch_dtype )
658+
659+ def _init_pipeline (self , * args , ** kwargs ):
660+ transformer = FluxTransformer2DModel .from_single_file (
661+ self .gguf_ckpt , quantization_config = self .quantization_config , torch_dtype = self .torch_dtype
662+ )
663+ pipe = DiffusionPipeline .from_pretrained (
664+ "black-forest-labs/FLUX.1-dev" , transformer = transformer , torch_dtype = self .torch_dtype
665+ )
666+ return pipe
0 commit comments