diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 0c6f3cda666e..e9556398b98f 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -135,6 +135,10 @@ "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers, "default_subfolder": "transformer", }, + "WanVACETransformer3DModel": { + "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers, + "default_subfolder": "transformer", + }, "AutoencoderKLWan": { "checkpoint_mapping_fn": convert_wan_vae_to_diffusers, "default_subfolder": "vae", diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index d8d183304e9a..19ff83ffabea 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -126,6 +126,7 @@ ], "wan": ["model.diffusion_model.head.modulation", "head.modulation"], "wan_vae": "decoder.middle.0.residual.0.gamma", + "wan_vace": "vace_blocks.0.after_proj.bias", "hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias", } @@ -192,6 +193,8 @@ "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"}, "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"}, "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"}, + "wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"}, + "wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"}, "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"}, } @@ -698,12 +701,19 @@ def infer_diffusers_model_type(checkpoint): else: target_key = "patch_embedding.weight" - if checkpoint[target_key].shape[0] == 1536: + if CHECKPOINT_KEY_NAMES["wan_vace"] in checkpoint: + if checkpoint[target_key].shape[0] == 1536: + model_type = "wan-vace-1.3B" + elif checkpoint[target_key].shape[0] == 5120: + model_type = "wan-vace-14B" + + elif checkpoint[target_key].shape[0] == 1536: model_type = "wan-t2v-1.3B" elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16: model_type = "wan-t2v-14B" else: model_type = "wan-i2v-14B" + elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint: # All Wan models use the same VAE so we can use the same default model repo to fetch the config model_type = "wan-t2v-14B" @@ -3093,6 +3103,9 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs): "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", "img_emb.proj.4": "condition_embedder.image_embedder.norm2", + # For the VACE model + "before_proj": "proj_in", + "after_proj": "proj_out", } for key in list(checkpoint.keys()): diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index 5d1fa4c22e2a..0d786de7e78f 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -15,6 +15,8 @@ HiDreamImageTransformer2DModel, SD3Transformer2DModel, StableDiffusion3Pipeline, + WanTransformer3DModel, + WanVACETransformer3DModel, ) from diffusers.utils import load_image from diffusers.utils.testing_utils import ( @@ -577,3 +579,71 @@ def get_dummy_inputs(self): ).to(torch_device, self.torch_dtype), "timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype), } + + +class WanGGUFTexttoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): + ckpt_path = "https://huggingface.co/city96/Wan2.1-T2V-14B-gguf/blob/main/wan2.1-t2v-14b-Q3_K_S.gguf" + torch_dtype = torch.bfloat16 + model_cls = WanTransformer3DModel + expected_memory_use_in_gb = 9 + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + } + + +class WanGGUFImagetoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): + ckpt_path = "https://huggingface.co/city96/Wan2.1-I2V-14B-480P-gguf/blob/main/wan2.1-i2v-14b-480p-Q3_K_S.gguf" + torch_dtype = torch.bfloat16 + model_cls = WanTransformer3DModel + expected_memory_use_in_gb = 9 + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "encoder_hidden_states_image": torch.randn( + (1, 257, 1280), generator=torch.Generator("cpu").manual_seed(0) + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + } + + +class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): + ckpt_path = "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf" + torch_dtype = torch.bfloat16 + model_cls = WanVACETransformer3DModel + expected_memory_use_in_gb = 9 + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "control_hidden_states": torch.randn( + (1, 96, 2, 64, 64), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "control_hidden_states_scale": torch.randn( + (8,), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + }