
Commit 95a55f9

add conversion script for vae 2.2
1 parent f5da83c commit 95a55f9

File tree

1 file changed: +314 -1 lines changed


scripts/convert_wan_to_diffusers.py

Lines changed: 314 additions & 1 deletion
@@ -560,6 +560,305 @@ def convert_vae():
     vae.load_state_dict(new_state_dict, strict=True, assign=True)
     return vae
 
+vae22_diffusers_config = {
+    "base_dim": 160,
+    "z_dim": 48,
+    "is_residual": True,
+    "in_channels": 12,
+    "out_channels": 12,
+    "decoder_base_dim": 256,
+    "latents_mean": [
+        -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, -0.0174, -0.1838, -0.1557,
+        -0.1382, -0.0542, -0.2813, -0.0891, -0.1570, -0.0098, -0.0375, -0.1825,
+        -0.2246, -0.1207, -0.0698, -0.5109, -0.2665, -0.2108, -0.2158, -0.2502,
+        -0.2055, -0.0322, -0.1109, -0.1567, -0.0729, -0.0899, -0.2799, -0.1230,
+        -0.0313, -0.1649, -0.0117, -0.0723, -0.2839, -0.2083, -0.0520, -0.3748,
+        -0.0152, -0.1957, -0.1433, -0.2944, -0.3573, -0.0548, -0.1681, -0.0667,
+    ],
+    "latents_std": [
+        -0.4765, -1.0364, -0.4514, -1.1677, -0.5313, -0.4990, -0.4818, -0.5013,
+        -0.8158, -1.0344, -0.5894, -1.0901, -0.6885, -0.6165, -0.8454, -0.4978,
+        -0.5759, -0.3523, -0.7135, -0.6804, -0.5833, -1.4146, -0.8986, -0.5659,
+        -0.7069, -0.5338, -0.4889, -0.4917, -0.4069, -0.4999, -0.6866, -0.4093,
+        -0.5709, -0.6065, -0.6415, -0.4944, -0.5726, -1.2042, -0.5458, -1.6887,
+        -0.3971, -1.0600, -0.3943, -0.5537, -0.5444, -0.4089, -0.7468, -0.7744,
+    ],
+}
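The latents_mean and latents_std entries above hold one statistic per latent channel (z_dim = 48). As a rough illustration of how such per-channel statistics are consumed at decode time, here is a minimal sketch; the shape convention [B, z_dim, T, H, W] and the helper name are assumptions for the example, not code from this script or from WanPipeline:

    import torch

    def denormalize_latents(latents: torch.Tensor, latents_mean, latents_std) -> torch.Tensor:
        # latents: [B, z_dim, T, H, W]; undo per-channel normalization z -> z * std + mean
        z_dim = latents.shape[1]
        mean = torch.tensor(latents_mean).view(1, z_dim, 1, 1, 1).to(latents)
        std = torch.tensor(latents_std).view(1, z_dim, 1, 1, 1).to(latents)
        return latents * std + mean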
+
+
+def convert_vae_22():
+    vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.2-TI2V-5B", "Wan2.2_VAE.pth")
+    old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
+    new_state_dict = {}
+
+    # Create mappings for specific components
+    middle_key_mapping = {
+        # Encoder middle block
+        "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
+        "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
+        "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
+        "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
+        "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
+        "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
+        "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
+        "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
+        "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
+        "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
+        "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
+        "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
+        # Decoder middle block
+        "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
+        "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
+        "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
+        "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
+        "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
+        "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
+        "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
+        "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
+        "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
+        "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
+        "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
+        "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
+    }
+
+    # Create a mapping for attention blocks
+    attention_mapping = {
+        # Encoder middle attention
+        "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
+        "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
+        "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
+        "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
+        "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
+        # Decoder middle attention
+        "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
+        "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
+        "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
+        "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
+        "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
+    }
+
+    # Create a mapping for the head components
+    head_mapping = {
+        # Encoder head
+        "encoder.head.0.gamma": "encoder.norm_out.gamma",
+        "encoder.head.2.bias": "encoder.conv_out.bias",
+        "encoder.head.2.weight": "encoder.conv_out.weight",
+        # Decoder head
+        "decoder.head.0.gamma": "decoder.norm_out.gamma",
+        "decoder.head.2.bias": "decoder.conv_out.bias",
+        "decoder.head.2.weight": "decoder.conv_out.weight",
+    }
+
+    # Create a mapping for the quant components
+    quant_mapping = {
+        "conv1.weight": "quant_conv.weight",
+        "conv1.bias": "quant_conv.bias",
+        "conv2.weight": "post_quant_conv.weight",
+        "conv2.bias": "post_quant_conv.bias",
+    }
+
+    # Process each key in the state dict
+    for key, value in old_state_dict.items():
+        # Handle middle block keys using the mapping
+        if key in middle_key_mapping:
+            new_key = middle_key_mapping[key]
+            new_state_dict[new_key] = value
+        # Handle attention blocks using the mapping
+        elif key in attention_mapping:
+            new_key = attention_mapping[key]
+            new_state_dict[new_key] = value
+        # Handle head keys using the mapping
+        elif key in head_mapping:
+            new_key = head_mapping[key]
+            new_state_dict[new_key] = value
+        # Handle quant keys using the mapping
+        elif key in quant_mapping:
+            new_key = quant_mapping[key]
+            new_state_dict[new_key] = value
+        # Handle encoder conv1
+        elif key == "encoder.conv1.weight":
+            new_state_dict["encoder.conv_in.weight"] = value
+        elif key == "encoder.conv1.bias":
+            new_state_dict["encoder.conv_in.bias"] = value
+        # Handle decoder conv1
+        elif key == "decoder.conv1.weight":
+            new_state_dict["decoder.conv_in.weight"] = value
+        elif key == "decoder.conv1.bias":
+            new_state_dict["decoder.conv_in.bias"] = value
+        # Handle encoder downsamples
+        elif key.startswith("encoder.downsamples."):
+            # Change encoder.downsamples to encoder.down_blocks
+            new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
+
+            # Handle residual blocks - change downsamples to resnets and rename components
+            if "residual" in new_key or "shortcut" in new_key:
+                # Change the second downsamples to resnets
+                new_key = new_key.replace(".downsamples.", ".resnets.")
+
+                # Rename residual components
+                if ".residual.0.gamma" in new_key:
+                    new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
+                elif ".residual.2.weight" in new_key:
+                    new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
+                elif ".residual.2.bias" in new_key:
+                    new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
+                elif ".residual.3.gamma" in new_key:
+                    new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
+                elif ".residual.6.weight" in new_key:
+                    new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
+                elif ".residual.6.bias" in new_key:
+                    new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
+                elif ".shortcut.weight" in new_key:
+                    new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
+                elif ".shortcut.bias" in new_key:
+                    new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
+
+            # Handle resample blocks - change downsamples to downsampler and remove index
+            elif "resample" in new_key or "time_conv" in new_key:
+                # Change the second downsamples to downsampler and remove the index
+                parts = new_key.split(".")
+                # Find the pattern: encoder.down_blocks.X.downsamples.Y.resample...
+                # We want to change it to: encoder.down_blocks.X.downsampler.resample...
+                if len(parts) >= 4 and parts[3] == "downsamples":
+                    # Remove the index (parts[4]) and change downsamples to downsampler
+                    new_parts = parts[:3] + ["downsampler"] + parts[5:]
+                    new_key = ".".join(new_parts)
+
+            new_state_dict[new_key] = value
+
+        # Handle decoder upsamples
+        elif key.startswith("decoder.upsamples."):
+            # Change decoder.upsamples to decoder.up_blocks
+            new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+
+            # Handle residual blocks - change upsamples to resnets and rename components
+            if "residual" in new_key or "shortcut" in new_key:
+                # Change the second upsamples to resnets
+                new_key = new_key.replace(".upsamples.", ".resnets.")
+
+                # Rename residual components
+                if ".residual.0.gamma" in new_key:
+                    new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
+                elif ".residual.2.weight" in new_key:
+                    new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
+                elif ".residual.2.bias" in new_key:
+                    new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
+                elif ".residual.3.gamma" in new_key:
+                    new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
+                elif ".residual.6.weight" in new_key:
+                    new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
+                elif ".residual.6.bias" in new_key:
+                    new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
+                elif ".shortcut.weight" in new_key:
+                    new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
+                elif ".shortcut.bias" in new_key:
+                    new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
+
+            # Handle resample blocks - change upsamples to upsampler and remove index
+            elif "resample" in new_key or "time_conv" in new_key:
+                # Change the second upsamples to upsampler and remove the index
+                parts = new_key.split(".")
+                # Find the pattern: decoder.up_blocks.X.upsamples.Y.resample...
+                # We want to change it to: decoder.up_blocks.X.upsampler.resample...
+                if len(parts) >= 4 and parts[3] == "upsamples":
+                    # Remove the index (parts[4]) and change upsamples to upsampler
+                    new_parts = parts[:3] + ["upsampler"] + parts[5:]
+                    new_key = ".".join(new_parts)
+
+            new_state_dict[new_key] = value
+        else:
+            # Keep other keys unchanged
+            new_state_dict[key] = value
+
+    with init_empty_weights():
+        vae = AutoencoderKLWan(**vae22_diffusers_config)
+    vae.load_state_dict(new_state_dict, strict=True, assign=True)
+    return vae
+
 
 def get_args():
     parser = argparse.ArgumentParser()
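The encoder/decoder branches in convert_vae_22 are purely mechanical key renames. Here is a small sketch of what those rules produce on a few hypothetical checkpoint keys; the keys below are made up for illustration, not taken from the real state dict:

    # Hypothetical before -> after pairs under the rename rules above:
    examples = {
        # residual sub-block: downsamples -> down_blocks/resnets, residual.2 -> conv1
        "encoder.downsamples.0.downsamples.1.residual.2.weight": "encoder.down_blocks.0.resnets.1.conv1.weight",
        # resample sub-block: the inner index is dropped, downsamples -> downsampler
        "encoder.downsamples.2.downsamples.0.resample.1.weight": "encoder.down_blocks.2.downsampler.resample.1.weight",
        # the decoder side mirrors the encoder rules
        "decoder.upsamples.1.upsamples.0.time_conv.weight": "decoder.up_blocks.1.upsampler.time_conv.weight",
    }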
@@ -586,7 +885,11 @@ def get_args():
     transformer = convert_transformer(args.model_type)
     transformer_2 = None
 
-    vae = convert_vae()
+    if "Wan2.2" in args.model_type and "TI2V" in args.model_type:
+        vae = convert_vae_22()
+    else:
+        vae = convert_vae()
+
     text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
     tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
     flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
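For concreteness, only model types that mention both "Wan2.2" and "TI2V" take the new VAE path; everything else keeps convert_vae(). The strings below are hypothetical examples; the actual model_type values accepted by get_args() are not shown in this diff:

    for model_type in ["Wan2.2-TI2V-5B", "Wan2.2-T2V-A14B", "Wan2.1-I2V-14B-480P"]:
        uses_vae22 = "Wan2.2" in model_type and "TI2V" in model_type
        print(model_type, "->", "convert_vae_22" if uses_vae22 else "convert_vae")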
@@ -609,6 +912,16 @@ def get_args():
             scheduler=scheduler,
             boundary_ratio=0.9,
         )
+    elif "Wan2.2" in args.model_type and "T2V" in args.model_type:
+        pipe = WanPipeline(
+            transformer=transformer,
+            transformer_2=transformer_2,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            vae=vae,
+            scheduler=scheduler,
+            boundary_ratio=0.875,
+        )
     elif "I2V" in args.model_type or "FLF2V" in args.model_type:
         image_encoder = CLIPVisionModelWithProjection.from_pretrained(
             "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
