@@ -43,6 +43,24 @@ def _from_diffsynth(self, state_dict):
4343 dit_dict [key ] = lora_args
4444 return {"dit" : dit_dict }
4545
def _from_diffusers(self, state_dict):
    """Convert a ``lora_down``/``lora_up`` style LoRA state dict to the ``{"dit": ...}`` layout.

    Each key ending in ``.lora_down.weight`` is paired with the sibling
    ``.lora_up.weight`` entry and, when present, the sibling ``.alpha``
    entry. All other keys are ignored.

    Args:
        state_dict: Mapping from parameter name to weight. Weights stored
            under ``.lora_up.weight`` must expose ``.shape``.

    Returns:
        ``{"dit": {name: {"up", "down", "rank", "alpha"}}}`` where ``name`` is
        the source key with ``"diffusion_model."`` and the
        ``.lora_down.weight`` suffix removed. ``alpha`` falls back to the
        rank when the state dict carries no ``.alpha`` entry.

    Raises:
        KeyError: If a ``.lora_down.weight`` entry has no matching
            ``.lora_up.weight`` entry.
    """
    down_suffix = ".lora_down.weight"
    dit_dict = {}
    for key, down in state_dict.items():
        # Match on the suffix only, so the stem can be derived once by
        # slicing instead of repeating str.replace for every sibling key.
        if not key.endswith(down_suffix):
            continue
        stem = key[: -len(down_suffix)]

        up_key = stem + ".lora_up.weight"
        if up_key not in state_dict:
            # Still a KeyError, but one that names the unpaired weight.
            raise KeyError(f"missing {up_key!r} to pair with {key!r}")
        up = state_dict[up_key]

        # Rank is read from the second dimension of the up weight
        # (assumes up is laid out as (out_features, rank) -- TODO confirm).
        rank = up.shape[1]
        dit_dict[stem.replace("diffusion_model.", "")] = {
            "up": up,
            "down": down,
            "rank": rank,
            "alpha": state_dict.get(stem + ".alpha", rank),
        }
    return {"dit": dit_dict}
63+
4664 def _from_civitai (self , state_dict ):
4765 dit_dict = {}
4866 for key , param in state_dict .items ():
@@ -86,6 +104,9 @@ def convert(self, state_dict):
86104 if "lora_unet_blocks_0_cross_attn_k.lora_down.weight" in state_dict :
87105 state_dict = self ._from_fun (state_dict )
88106 logger .info ("use fun format state dict" )
107+ elif "diffusion_model.blocks.0.cross_attn.k.lora_down.weight" in state_dict :
108+ state_dict = self ._from_diffusers (state_dict )
109+ logger .info ("use diffusers format state dict" )
89110 elif "diffusion_model.blocks.0.cross_attn.k.lora_A.weight" in state_dict :
90111 state_dict = self ._from_civitai (state_dict )
91112 logger .info ("use civitai format state dict" )
@@ -480,8 +501,8 @@ def from_pretrained(cls, model_path_or_config: WanPipelineConfig) -> "WanVideoPi
480501
481502 dit_state_dict , dit2_state_dict = None , None
482503 if isinstance (config .model_path , list ):
483- high_noise_model_ckpt = [path for path in config .model_path if "high_noise_model " in path ]
484- low_noise_model_ckpt = [path for path in config .model_path if "low_noise_model " in path ]
504+ high_noise_model_ckpt = [path for path in config .model_path if "high_noise " in path ]
505+ low_noise_model_ckpt = [path for path in config .model_path if "low_noise " in path ]
485506 if high_noise_model_ckpt and low_noise_model_ckpt :
486507 logger .info (f"loading high noise model state dict from { high_noise_model_ckpt } ..." )
487508 dit_state_dict = cls .load_model_checkpoint (
@@ -681,8 +702,9 @@ def has_any_key(*xs):
681702 config .attn_params = VideoSparseAttentionParams (sparsity = 0.9 )
682703
683704 def update_weights (self , state_dicts : WanStateDicts ) -> None :
684- is_dual_model_state_dict = (isinstance (state_dicts .model , dict ) and
685- ("high_noise_model" in state_dicts .model or "low_noise_model" in state_dicts .model ))
705+ is_dual_model_state_dict = isinstance (state_dicts .model , dict ) and (
706+ "high_noise_model" in state_dicts .model or "low_noise_model" in state_dicts .model
707+ )
686708 is_dual_model_pipeline = self .dit2 is not None
687709
688710 if is_dual_model_state_dict != is_dual_model_pipeline :
@@ -694,15 +716,21 @@ def update_weights(self, state_dicts: WanStateDicts) -> None:
694716
695717 if is_dual_model_state_dict :
696718 if "high_noise_model" in state_dicts .model :
697- self .update_component (self .dit , state_dicts .model ["high_noise_model" ], self .config .device , self .config .model_dtype )
719+ self .update_component (
720+ self .dit , state_dicts .model ["high_noise_model" ], self .config .device , self .config .model_dtype
721+ )
698722 if "low_noise_model" in state_dicts .model :
699- self .update_component (self .dit2 , state_dicts .model ["low_noise_model" ], self .config .device , self .config .model_dtype )
723+ self .update_component (
724+ self .dit2 , state_dicts .model ["low_noise_model" ], self .config .device , self .config .model_dtype
725+ )
700726 else :
701727 self .update_component (self .dit , state_dicts .model , self .config .device , self .config .model_dtype )
702728
703729 self .update_component (self .text_encoder , state_dicts .t5 , self .config .device , self .config .t5_dtype )
704730 self .update_component (self .vae , state_dicts .vae , self .config .device , self .config .vae_dtype )
705- self .update_component (self .image_encoder , state_dicts .image_encoder , self .config .device , self .config .image_encoder_dtype )
731+ self .update_component (
732+ self .image_encoder , state_dicts .image_encoder , self .config .device , self .config .image_encoder_dtype
733+ )
706734
707735 def compile (self ):
708736 self .dit .compile_repeated_blocks ()
0 commit comments