
Commit 66fb1f4

jiqing-feng authored and Beinsezii committed
fix input shape for WanGGUFTexttoVideoSingleFileTests (huggingface#12081)
Signed-off-by: jiqing-feng <[email protected]>
1 parent: e5bbf2f · commit: 66fb1f4

File tree: 1 file changed (+87 −0 lines)

tests/quantization/gguf/test_gguf.py

Lines changed: 87 additions & 0 deletions
@@ -577,3 +577,90 @@ def get_dummy_inputs(self):
             ).to(torch_device, self.torch_dtype),
             "timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype),
         }
+
+
+class WanGGUFTexttoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+    ckpt_path = "https://huggingface.co/city96/Wan2.1-T2V-14B-gguf/blob/main/wan2.1-t2v-14b-Q3_K_S.gguf"
+    torch_dtype = torch.bfloat16
+    model_cls = WanTransformer3DModel
+    expected_memory_use_in_gb = 9
+
+    def get_dummy_inputs(self):
+        return {
+            "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+                torch_device, self.torch_dtype
+            ),
+            "encoder_hidden_states": torch.randn(
+                (1, 512, 4096),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+        }
+
+
+class WanGGUFImagetoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+    ckpt_path = "https://huggingface.co/city96/Wan2.1-I2V-14B-480P-gguf/blob/main/wan2.1-i2v-14b-480p-Q3_K_S.gguf"
+    torch_dtype = torch.bfloat16
+    model_cls = WanTransformer3DModel
+    expected_memory_use_in_gb = 9
+
+    def get_dummy_inputs(self):
+        return {
+            "hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+                torch_device, self.torch_dtype
+            ),
+            "encoder_hidden_states": torch.randn(
+                (1, 512, 4096),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "encoder_hidden_states_image": torch.randn(
+                (1, 257, 1280), generator=torch.Generator("cpu").manual_seed(0)
+            ).to(torch_device, self.torch_dtype),
+            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+        }
+
+
+class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+    ckpt_path = "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
+    torch_dtype = torch.bfloat16
+    model_cls = WanVACETransformer3DModel
+    expected_memory_use_in_gb = 9
+
+    def get_dummy_inputs(self):
+        return {
+            "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+                torch_device, self.torch_dtype
+            ),
+            "encoder_hidden_states": torch.randn(
+                (1, 512, 4096),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "control_hidden_states": torch.randn(
+                (1, 96, 2, 64, 64),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "control_hidden_states_scale": torch.randn(
+                (8,),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+        }
+
+
+@require_torch_version_greater("2.7.1")
+class GGUFCompileTests(QuantCompileTests, unittest.TestCase):
+    torch_dtype = torch.bfloat16
+    gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
+
+    @property
+    def quantization_config(self):
+        return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+
+    def _init_pipeline(self, *args, **kwargs):
+        transformer = FluxTransformer2DModel.from_single_file(
+            self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
+        )
+        pipe = DiffusionPipeline.from_pretrained(
+            "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=self.torch_dtype
+        )
+        return pipe
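The new Wan tests exercise the same GGUF single-file loading path that GGUFCompileTests._init_pipeline uses for Flux. Below is a minimal sketch of that pattern applied to the T2V checkpoint referenced in WanGGUFTexttoVideoSingleFileTests, using the same dummy tensor shapes as its get_dummy_inputs. Running the forward pass standalone (outside GGUFSingleFileTesterMixin) and the device selection are assumptions added for illustration; they are not part of this commit.

# Illustrative sketch (not part of the commit): load the Wan 2.1 T2V GGUF checkpoint
# referenced in WanGGUFTexttoVideoSingleFileTests and run the test's dummy inputs through it.
import torch

from diffusers import GGUFQuantizationConfig, WanTransformer3DModel

ckpt_path = "https://huggingface.co/city96/Wan2.1-T2V-14B-gguf/blob/main/wan2.1-t2v-14b-Q3_K_S.gguf"
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"  # assumption: run on whichever device is available

transformer = WanTransformer3DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=dtype),
    torch_dtype=dtype,
).to(device)

# Shapes mirror get_dummy_inputs above: 16 latent channels for the T2V model
# (the I2V fixture uses 36 channels plus an extra encoder_hidden_states_image input).
hidden_states = torch.randn(1, 16, 2, 64, 64, dtype=dtype, device=device)
encoder_hidden_states = torch.randn(1, 512, 4096, dtype=dtype, device=device)
timestep = torch.tensor([1], dtype=dtype, device=device)

with torch.no_grad():
    sample = transformer(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        timestep=timestep,
        return_dict=False,
    )[0]

print(sample.shape)  # expected to match the latent input shape, (1, 16, 2, 64, 64)

The fixture shapes are the substance of the fix named in the commit title: the T2V dummy hidden_states use 16 latent channels, while the I2V variant keeps 36 channels and adds encoder_hidden_states_image, as shown in the diff above.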
