@@ -560,6 +560,305 @@ def convert_vae():
560560    vae .load_state_dict (new_state_dict , strict = True , assign = True )
561561    return  vae 
562562
# Diffusers-format config for the Wan2.2 (TI2V-5B) 48-channel VAE
# (AutoencoderKLWan). The per-channel latent statistics are used by the
# pipeline to normalize/denormalize latents as (z - mean) / std.
vae22_diffusers_config = {
    "base_dim": 160,
    "z_dim": 48,
    "is_residual": True,
    "in_channels": 12,
    "out_channels": 12,
    "decoder_base_dim": 256,
    # NOTE(review): signs below are copied verbatim from the checkpoint dump —
    # verify against the published Wan2.2 TI2V-5B statistics.
    "latents_mean": [
        -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, -0.0174,
        -0.1838, -0.1557, -0.1382, -0.0542, -0.2813, -0.0891,
        -0.1570, -0.0098, -0.0375, -0.1825, -0.2246, -0.1207,
        -0.0698, -0.5109, -0.2665, -0.2108, -0.2158, -0.2502,
        -0.2055, -0.0322, -0.1109, -0.1567, -0.0729, -0.0899,
        -0.2799, -0.1230, -0.0313, -0.1649, -0.0117, -0.0723,
        -0.2839, -0.2083, -0.0520, -0.3748, -0.0152, -0.1957,
        -0.1433, -0.2944, -0.3573, -0.0548, -0.1681, -0.0667,
    ],
    # Standard deviations are strictly positive by definition; the source
    # listed them negated, which would flip every latent channel's sign
    # during normalization.
    "latents_std": [
        0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990,
        0.4818, 0.5013, 0.8158, 1.0344, 0.5894, 1.0901,
        0.6885, 0.6165, 0.8454, 0.4978, 0.5759, 0.3523,
        0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
        0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999,
        0.6866, 0.4093, 0.5709, 0.6065, 0.6415, 0.4944,
        0.5726, 1.2042, 0.5458, 1.6887, 0.3971, 1.0600,
        0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
    ],
}
671+ 
672+ 
def convert_vae_22():
    """Convert the original Wan2.2 (TI2V-5B) VAE checkpoint to diffusers format.

    Downloads ``Wan2.2_VAE.pth`` from the Hub, renames every state-dict key
    from the original layout (``middle``/``head``/``downsamples``/``upsamples``)
    to the diffusers ``AutoencoderKLWan`` layout (``mid_block``/``norm_out``/
    ``down_blocks``/``up_blocks``), and strictly loads the result into an
    empty-weight model.

    Returns:
        AutoencoderKLWan: the converted VAE with loaded weights.
    """
    vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.2-TI2V-5B", "Wan2.2_VAE.pth")
    old_state_dict = torch.load(vae_ckpt_path, weights_only=True)

    # One-to-one renames: middle-block resnets/attentions, heads, quant convs
    # and the encoder/decoder input convs. The four groups are disjoint, so a
    # single lookup replaces the original chain of per-group mappings.
    exact_key_mapping = {
        # Encoder/decoder input convs.
        "encoder.conv1.weight": "encoder.conv_in.weight",
        "encoder.conv1.bias": "encoder.conv_in.bias",
        "decoder.conv1.weight": "decoder.conv_in.weight",
        "decoder.conv1.bias": "decoder.conv_in.bias",
        # Encoder middle block resnets.
        "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
        "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
        "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
        "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
        "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
        "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
        "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
        "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
        "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
        "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
        "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
        "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
        # Decoder middle block resnets.
        "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
        "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
        "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
        "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
        "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
        "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
        "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
        "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
        "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
        "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
        "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
        "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
        # Encoder middle attention.
        "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
        "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
        "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
        "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
        "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
        # Decoder middle attention.
        "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
        "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
        "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
        "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
        "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
        # Encoder head.
        "encoder.head.0.gamma": "encoder.norm_out.gamma",
        "encoder.head.2.bias": "encoder.conv_out.bias",
        "encoder.head.2.weight": "encoder.conv_out.weight",
        # Decoder head.
        "decoder.head.0.gamma": "decoder.norm_out.gamma",
        "decoder.head.2.bias": "decoder.conv_out.bias",
        "decoder.head.2.weight": "decoder.conv_out.weight",
        # Quant convs.
        "conv1.weight": "quant_conv.weight",
        "conv1.bias": "quant_conv.bias",
        "conv2.weight": "post_quant_conv.weight",
        "conv2.bias": "post_quant_conv.bias",
    }

    # Suffix renames shared by encoder and decoder residual blocks. At most
    # one suffix matches a given key (mirrors the original if/elif chain).
    residual_suffix_mapping = {
        ".residual.0.gamma": ".norm1.gamma",
        ".residual.2.weight": ".conv1.weight",
        ".residual.2.bias": ".conv1.bias",
        ".residual.3.gamma": ".norm2.gamma",
        ".residual.6.weight": ".conv2.weight",
        ".residual.6.bias": ".conv2.bias",
        ".shortcut.weight": ".conv_shortcut.weight",
        ".shortcut.bias": ".conv_shortcut.bias",
    }

    def _convert_updown_key(key, old_prefix, new_prefix, nested, sampler):
        # Shared rename logic for encoder down blocks and decoder up blocks
        # (the two branches in the original were copy-pastes of each other).
        new_key = key.replace(old_prefix, new_prefix)
        if "residual" in new_key or "shortcut" in new_key:
            # Nested "<nested>." level becomes "resnets.", then the layer
            # suffix is renamed to the diffusers norm/conv names.
            new_key = new_key.replace("." + nested + ".", ".resnets.")
            for old_suffix, new_suffix in residual_suffix_mapping.items():
                if old_suffix in new_key:
                    new_key = new_key.replace(old_suffix, new_suffix)
                    break
        elif "resample" in new_key or "time_conv" in new_key:
            # e.g. encoder.down_blocks.X.downsamples.Y.resample... ->
            #      encoder.down_blocks.X.downsampler.resample...
            # (rename the nested level and drop its index Y).
            parts = new_key.split(".")
            if len(parts) >= 4 and parts[3] == nested:
                new_key = ".".join(parts[:3] + [sampler] + parts[5:])
        return new_key

    new_state_dict = {}
    for key, value in old_state_dict.items():
        if key in exact_key_mapping:
            new_key = exact_key_mapping[key]
        elif key.startswith("encoder.downsamples."):
            new_key = _convert_updown_key(
                key, "encoder.downsamples.", "encoder.down_blocks.", "downsamples", "downsampler"
            )
        elif key.startswith("decoder.upsamples."):
            new_key = _convert_updown_key(
                key, "decoder.upsamples.", "decoder.up_blocks.", "upsamples", "upsampler"
            )
        else:
            # Keep other keys unchanged.
            new_key = key
        new_state_dict[new_key] = value

    with init_empty_weights():
        # Fix: the config defined above is ``vae22_diffusers_config``; the
        # original referenced an undefined ``vae22_config`` (NameError).
        vae = AutoencoderKLWan(**vae22_diffusers_config)
    vae.load_state_dict(new_state_dict, strict=True, assign=True)
    return vae
861+ 
563862
564863def  get_args ():
565864    parser  =  argparse .ArgumentParser ()
@@ -586,7 +885,11 @@ def get_args():
586885        transformer  =  convert_transformer (args .model_type )
587886        transformer_2  =  None 
588887
589-     vae  =  convert_vae ()
    # Wan2.2 TI2V checkpoints ship the new 48-channel VAE and need the
    # dedicated converter; every other model type keeps the Wan2.1 VAE path.
    if "Wan2.2" in args.model_type and "TI2V" in args.model_type:
        vae = convert_vae_22()
    else:
        vae = convert_vae()
590893    text_encoder  =  UMT5EncoderModel .from_pretrained ("google/umt5-xxl" , torch_dtype = torch .bfloat16 )
591894    tokenizer  =  AutoTokenizer .from_pretrained ("google/umt5-xxl" )
592895    flow_shift  =  16.0  if  "FLF2V"  in  args .model_type  else  3.0 
@@ -609,6 +912,16 @@ def get_args():
609912            scheduler = scheduler ,
610913            boundary_ratio = 0.9 ,
611914        )
915+     elif  "Wan2.2"  and  "T2V"  in  args .model_type :
916+         pipe  =  WanPipeline (
917+             transformer = transformer ,
918+             transformer_2 = transformer_2 ,
919+             text_encoder = text_encoder ,
920+             tokenizer = tokenizer ,
921+             vae = vae ,
922+             scheduler = scheduler ,
923+             boundary_ratio = 0.875 ,
924+         )
612925    elif  "I2V"  in  args .model_type  or  "FLF2V"  in  args .model_type :
613926        image_encoder  =  CLIPVisionModelWithProjection .from_pretrained (
614927            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" , torch_dtype = torch .bfloat16 
0 commit comments