@@ -4833,10 +4833,10 @@ def _maybe_expand_t2v_lora_for_i2v(
48334833 continue
48344834
48354835 state_dict [f"transformer.blocks.{ i } .attn2.{ c } .lora_A.weight" ] = torch .zeros_like (
4836- state_dict [f"transformer.blocks. { i } .attn2.to_k.lora_A.weight" ], device = target_device
4836+ state_dict [ref_key_lora_A ], device = target_device # Using original ref_key_lora_A
48374837 )
48384838 state_dict [f"transformer.blocks.{ i } .attn2.{ c } .lora_B.weight" ] = torch .zeros_like (
4839- state_dict [f"transformer.blocks. { i } .attn2.to_k.lora_B.weight" ], device = target_device
4839+ state_dict [ref_key_lora_B ], device = target_device # Using original ref_key_lora_B
48404840 )
48414841
48424842 # If the original LoRA had biases (indicated by has_bias)
@@ -4850,6 +4850,47 @@ def _maybe_expand_t2v_lora_for_i2v(
48504850 device = target_device ,
48514851 )
48524852
4853+ if hasattr (transformer , 'vace_blocks' ):
4854+ inferred_rank_for_vace = None
4855+ lora_weights_dtype_for_vace = next (iter (transformer .parameters ())).dtype # Fallback dtype
4856+
4857+ for k_lora_any , v_lora_tensor_any in state_dict .items ():
4858+ if k_lora_any .endswith (".lora_A.weight" ):
4859+ inferred_rank_for_vace = v_lora_tensor_any .shape [0 ]
4860+ lora_weights_dtype_for_vace = v_lora_tensor_any .dtype
4861+ break # Found one, good enough for rank and dtype
4862+
4863+ if inferred_rank_for_vace is not None :
4864+ current_lora_has_bias = any (".lora_B.bias" in k for k in state_dict .keys ())
4865+
4866+ for i , vace_block_module_in_model in enumerate (transformer .vace_blocks ):
4867+ if hasattr (vace_block_module_in_model , 'proj_out' ):
4868+
4869+ proj_out_linear_layer_in_model = vace_block_module_in_model .proj_out
4870+
4871+ vace_lora_A_key = f"vace_blocks.{ i } .proj_out.lora_A.weight"
4872+ vace_lora_B_key = f"vace_blocks.{ i } .proj_out.lora_B.weight"
4873+
4874+ if vace_lora_A_key not in state_dict :
4875+ state_dict [vace_lora_A_key ] = torch .zeros (
4876+ (inferred_rank_for_vace , proj_out_linear_layer_in_model .in_features ),
4877+ device = target_device , dtype = lora_weights_dtype_for_vace
4878+ )
4879+
4880+ if vace_lora_B_key not in state_dict :
4881+ state_dict [vace_lora_B_key ] = torch .zeros (
4882+ (proj_out_linear_layer_in_model .out_features , inferred_rank_for_vace ),
4883+ device = target_device , dtype = lora_weights_dtype_for_vace
4884+ )
4885+
4886+ if current_lora_has_bias and proj_out_linear_layer_in_model .bias is not None :
4887+ vace_lora_B_bias_key = f"vace_blocks.{ i } .proj_out.lora_B.bias"
4888+ if vace_lora_B_bias_key not in state_dict :
4889+ state_dict [vace_lora_B_bias_key ] = torch .zeros_like (
4890+ proj_out_linear_layer_in_model .bias ,
4891+ device = target_device
4892+ )
4893+
48534894 return state_dict
48544895
48554896 def load_lora_weights (
0 commit comments