@@ -4810,11 +4810,11 @@ def _maybe_expand_t2v_lora_for_i2v(
48104810 transformer : torch .nn .Module ,
48114811 state_dict ,
48124812 ):
4813- print ("wtf 0" , hasattr (transformer , 'vace_blocks' ))
4814- # if transformer.config.image_dim is None:
4815- # return state_dict
4813+ if transformer .config .image_dim is None :
4814+ return state_dict
48164815
48174816 target_device = transformer .device
4817+
48184818 if any (k .startswith ("transformer.blocks." ) for k in state_dict ):
48194819 num_blocks = len ({k .split ("blocks." )[1 ].split ("." )[0 ] for k in state_dict if "blocks." in k })
48204820 is_i2v_lora = any ("add_k_proj" in k for k in state_dict ) and any ("add_v_proj" in k for k in state_dict )
@@ -4833,10 +4833,10 @@ def _maybe_expand_t2v_lora_for_i2v(
48334833 continue
48344834
48354835 state_dict [f"transformer.blocks.{ i } .attn2.{ c } .lora_A.weight" ] = torch .zeros_like (
4836- state_dict [ref_key_lora_A ], device = target_device # Using original ref_key_lora_A
4836+ state_dict [f"transformer.blocks. { i } .attn2.to_k.lora_A.weight" ], device = target_device
48374837 )
48384838 state_dict [f"transformer.blocks.{ i } .attn2.{ c } .lora_B.weight" ] = torch .zeros_like (
4839- state_dict [ref_key_lora_B ], device = target_device # Using original ref_key_lora_B
4839+ state_dict [f"transformer.blocks. { i } .attn2.to_k.lora_B.weight" ], device = target_device
48404840 )
48414841
48424842 # If the original LoRA had biases (indicated by has_bias)
@@ -4849,52 +4849,7 @@ def _maybe_expand_t2v_lora_for_i2v(
48494849 ref_lora_B_bias_tensor ,
48504850 device = target_device ,
48514851 )
4852-
4853-
4854- if hasattr (transformer , 'vace_blocks' ):
4855- print (f"{ i } , WTF 0" )
4856- inferred_rank_for_vace = None
4857- lora_weights_dtype_for_vace = next (iter (transformer .parameters ())).dtype # Fallback dtype
4858-
4859- for k_lora_any , v_lora_tensor_any in state_dict .items ():
4860- if k_lora_any .endswith (".lora_A.weight" ):
4861- inferred_rank_for_vace = v_lora_tensor_any .shape [0 ]
4862- lora_weights_dtype_for_vace = v_lora_tensor_any .dtype
4863- break # Found one, good enough for rank and dtype
4864-
4865- if inferred_rank_for_vace is not None :
4866- current_lora_has_bias = any (".lora_B.bias" in k for k in state_dict .keys ())
4867-
4868- for i , vace_block_module_in_model in enumerate (transformer .vace_blocks ):
4869- if hasattr (vace_block_module_in_model , 'proj_out' ):
4870-
4871- proj_out_linear_layer_in_model = vace_block_module_in_model .proj_out
4872-
4873- vace_lora_A_key = f"vace_blocks.{ i } .proj_out.lora_A.weight"
4874- vace_lora_B_key = f"vace_blocks.{ i } .proj_out.lora_B.weight"
4875-
4876- if vace_lora_A_key not in state_dict :
4877- print (f"{ i } , WTF 1" )
4878- state_dict [vace_lora_A_key ] = torch .zeros (
4879- (inferred_rank_for_vace , proj_out_linear_layer_in_model .in_features ),
4880- device = target_device , dtype = lora_weights_dtype_for_vace
4881- )
4882-
4883- if vace_lora_B_key not in state_dict :
4884- print (f"{ i } , WTF 2" )
4885- state_dict [vace_lora_B_key ] = torch .zeros (
4886- (proj_out_linear_layer_in_model .out_features , inferred_rank_for_vace ),
4887- device = target_device , dtype = lora_weights_dtype_for_vace
4888- )
4889-
4890- if current_lora_has_bias and proj_out_linear_layer_in_model .bias is not None :
4891- print (f"{ i } , WTF 3" )
4892- vace_lora_B_bias_key = f"vace_blocks.{ i } .proj_out.lora_B.bias"
4893- if vace_lora_B_bias_key not in state_dict :
4894- state_dict [vace_lora_B_bias_key ] = torch .zeros_like (
4895- proj_out_linear_layer_in_model .bias ,
4896- device = target_device
4897- )
4852+ print (state_dict .keys )
48984853
48994854 return state_dict
49004855
@@ -4942,7 +4897,6 @@ def load_lora_weights(
49424897 # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
49434898 state_dict = self .lora_state_dict (pretrained_model_name_or_path_or_dict , ** kwargs )
49444899 # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
4945- print ("_maybe_expand_t2v_lora_for_i2v?????????????????" )
49464900 state_dict = self ._maybe_expand_t2v_lora_for_i2v (
49474901 transformer = getattr (self , self .transformer_name ) if not hasattr (self , "transformer" ) else self .transformer ,
49484902 state_dict = state_dict ,
0 commit comments