@@ -2082,6 +2082,7 @@ def load_lora_weights(
         state_dict, network_alphas, metadata = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
         )
+        print(f"{metadata = }")
 
         has_lora_keys = any("lora" in key for key in state_dict.keys())
 
@@ -2203,6 +2204,7 @@ def load_lora_into_transformer(
             state_dict,
             network_alphas=network_alphas,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -5137,7 +5139,8 @@ def load_lora_weights(
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
         # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
         state_dict = self._maybe_expand_t2v_lora_for_i2v(
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
@@ -5151,6 +5154,7 @@ def load_lora_weights(
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
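
For context, a minimal sketch of where per-checkpoint LoRA metadata of the kind threaded through above typically lives when the weights are saved as `.safetensors`. The file name is a placeholder and the snippet only uses the public `safetensors` API; it is an illustration, not the loader code in the diff.

```python
from safetensors import safe_open

# Placeholder path: any LoRA checkpoint stored as a .safetensors file.
lora_path = "pytorch_lora_weights.safetensors"

with safe_open(lora_path, framework="pt") as f:
    # Header-level metadata saved alongside the tensors; may be None or empty
    # if the checkpoint was written without any metadata.
    metadata = f.metadata()
    print(f"{metadata = }")
```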