@@ -4859,11 +4859,57 @@ def _maybe_expand_t2v_lora_for_vace(
48594859 transformer : torch .nn .Module ,
48604860 state_dict ,
48614861 ):
4862+ target_device = transformer .device
4863+ if hasattr (transformer , 'vace_blocks' ):
4864+ inferred_rank_for_vace = None
4865+ lora_weights_dtype_for_vace = next (iter (transformer .parameters ())).dtype # Fallback dtype
4866+
4867+ for k_lora_any , v_lora_tensor_any in state_dict .items ():
4868+ if k_lora_any .endswith (".lora_A.weight" ):
4869+ inferred_rank_for_vace = v_lora_tensor_any .shape [0 ]
4870+ lora_weights_dtype_for_vace = v_lora_tensor_any .dtype
4871+ break # Found one, good enough for rank and dtype
4872+
4873+ if inferred_rank_for_vace is not None :
4874+ # Determine if the LoRA format (as potentially modified by I2V expansion) includes bias
4875+ # This re-checks 'has_bias' based on the *current* state_dict.
4876+ current_lora_has_bias = any (".lora_B.bias" in k for k in state_dict .keys ())
4877+
4878+ for i , vace_block_module_in_model in enumerate (transformer .vace_blocks ):
4879+ # Specifically target proj_out as per the error message
4880+ if hasattr (vace_block_module_in_model , 'proj_out' ) and \
4881+ isinstance (vace_block_module_in_model .proj_out , nn .Linear ):
4882+
4883+ proj_out_linear_layer_in_model = vace_block_module_in_model .proj_out
4884+
4885+ vace_lora_A_key = f"vace_blocks.{ i } .proj_out.lora_A.weight"
4886+ vace_lora_B_key = f"vace_blocks.{ i } .proj_out.lora_B.weight"
4887+
4888+ if vace_lora_A_key not in state_dict :
4889+ state_dict [vace_lora_A_key ] = torch .zeros (
4890+ (inferred_rank_for_vace , proj_out_linear_layer_in_model .in_features ),
4891+ device = target_device , dtype = lora_weights_dtype_for_vace
4892+ )
4893+
4894+ if vace_lora_B_key not in state_dict :
4895+ state_dict [vace_lora_B_key ] = torch .zeros (
4896+ (proj_out_linear_layer_in_model .out_features , inferred_rank_for_vace ),
4897+ device = target_device , dtype = lora_weights_dtype_for_vace
4898+ )
4899+
4900+ # Use 'current_lora_has_bias' to decide on padding bias for VACE blocks
4901+ if current_lora_has_bias and proj_out_linear_layer_in_model .bias is not None :
4902+ vace_lora_B_bias_key = f"vace_blocks.{ i } .proj_out.lora_B.bias"
4903+ if vace_lora_B_bias_key not in state_dict :
4904+ state_dict [vace_lora_B_bias_key ] = torch .zeros_like (
4905+ proj_out_linear_layer_in_model .bias , # Shape from model's bias
4906+ device = target_device # Dtype from model's bias implicitly by zeros_like
4907+ )
4908+
4909+ print ("AFTER 2:" , list (state_dict .keys ()))
4910+ return state_dict
48624911
4863- if not hasattr (transformer , 'vace_blocks' ):
4864- return state_dict
48654912
4866- target_device = transformer .device
48674913
48684914 return state_dict
48694915
@@ -4915,6 +4961,10 @@ def load_lora_weights(
49154961 transformer = getattr (self , self .transformer_name ) if not hasattr (self , "transformer" ) else self .transformer ,
49164962 state_dict = state_dict ,
49174963 )
4964+ state_dict = self ._maybe_expand_t2v_lora_for_vace (
4965+ transformer = getattr (self , self .transformer_name ) if not hasattr (self , "transformer" ) else self .transformer ,
4966+ state_dict = state_dict ,
4967+ )
49184968 is_correct_format = all ("lora" in key for key in state_dict .keys ())
49194969 if not is_correct_format :
49204970 raise ValueError ("Invalid LoRA checkpoint." )
0 commit comments