@@ -4255,22 +4255,23 @@ def _maybe_expand_t2v_lora_for_i2v(
42554255 transformer : torch .nn .Module ,
42564256 state_dict ,
42574257 ):
4258- num_blocks = len ({k .split ("blocks." )[1 ].split ("." )[0 ] for k in state_dict })
4259- is_i2v_lora = any ("k_img" in k for k in state_dict ) and any ("v_img" in k for k in state_dict )
4260- if not is_i2v_lora :
4261- return state_dict
4262-
4263- if transformer .config .image_dim is None :
4264- return state_dict
4265-
4266- for i in range (num_blocks ):
4267- for o , c in zip (["k_img" , "v_img" ], ["add_k_proj" , "add_v_proj" ]):
4268- state_dict [f"blocks.{ i } .attn2.{ c } .lora_A.weight" ] = torch .zeros_like (
4269- state_dict [f"blocks.{ i } .attn2.{ o .replace ('_img' , '' )} .lora_A.weight" ]
4270- )
4271- state_dict [f"blocks.{ i } .attn2.{ c } .lora_B.weight" ] = torch .zeros_like (
4272- state_dict [f"blocks.{ i } .attn2.{ o .replace ('_img' , '' )} .lora_B.weight" ]
4273- )
4258+ if any (k .startswith ("blocks." ) for k in state_dict ):
4259+ num_blocks = len ({k .split ("blocks." )[1 ].split ("." )[0 ] for k in state_dict })
4260+ is_i2v_lora = any ("k_img" in k for k in state_dict ) and any ("v_img" in k for k in state_dict )
4261+ if not is_i2v_lora :
4262+ return state_dict
4263+
4264+ if transformer .config .image_dim is None :
4265+ return state_dict
4266+
4267+ for i in range (num_blocks ):
4268+ for o , c in zip (["k_img" , "v_img" ], ["add_k_proj" , "add_v_proj" ]):
4269+ state_dict [f"blocks.{ i } .attn2.{ c } .lora_A.weight" ] = torch .zeros_like (
4270+ state_dict [f"blocks.{ i } .attn2.{ o .replace ('_img' , '' )} .lora_A.weight" ]
4271+ )
4272+ state_dict [f"blocks.{ i } .attn2.{ c } .lora_B.weight" ] = torch .zeros_like (
4273+ state_dict [f"blocks.{ i } .attn2.{ o .replace ('_img' , '' )} .lora_B.weight" ]
4274+ )
42744275
42754276 return state_dict
42764277
0 commit comments