@@ -1888,7 +1888,8 @@ def get_alpha_scales(down_weight, alpha_key):
18881888
18891889 # Cross-attention
18901890 for o , c in zip (["q" , "k" , "v" , "o" ], ["to_q" , "to_k" , "to_v" , "to_out.0" ]):
1891- has_alpha = f"blocks.{ i } .cross_attn.{ o } .alpha" in original_state_dict
1891+ alpha_key = f"blocks.{ i } .cross_attn.{ o } .alpha"
1892+ has_alpha = alpha_key in original_state_dict
18921893 original_key_A = f"blocks.{ i } .cross_attn.{ o } .{ lora_down_key } .weight"
18931894 converted_key_A = f"blocks.{ i } .attn2.{ c } .lora_A.weight"
18941895
@@ -1898,7 +1899,7 @@ def get_alpha_scales(down_weight, alpha_key):
18981899 if has_alpha :
18991900 down_weight = original_state_dict .pop (original_key_A )
19001901 up_weight = original_state_dict .pop (original_key_B )
1901- scale_down , scale_up = get_alpha_scales (down_weight , f"blocks. { i } .cross_attn. { o } .alpha" )
1902+ scale_down , scale_up = get_alpha_scales (down_weight , alpha_key )
19021903 converted_state_dict [converted_key_A ] = down_weight * scale_down
19031904 converted_state_dict [converted_key_B ] = up_weight * scale_up
19041905 else :
@@ -1915,7 +1916,8 @@ def get_alpha_scales(down_weight, alpha_key):
19151916
19161917 if is_i2v_lora :
19171918 for o , c in zip (["k_img" , "v_img" ], ["add_k_proj" , "add_v_proj" ]):
1918- has_alpha = f"blocks.{ i } .cross_attn.{ o } .alpha" in original_state_dict
1919+ alpha_key = f"blocks.{ i } .cross_attn.{ o } .alpha"
1920+ has_alpha = alpha_key in original_state_dict
19191921 original_key_A = f"blocks.{ i } .cross_attn.{ o } .{ lora_down_key } .weight"
19201922 converted_key_A = f"blocks.{ i } .attn2.{ c } .lora_A.weight"
19211923
@@ -1925,7 +1927,7 @@ def get_alpha_scales(down_weight, alpha_key):
19251927 if has_alpha :
19261928 down_weight = original_state_dict .pop (original_key_A )
19271929 up_weight = original_state_dict .pop (original_key_B )
1928- scale_down , scale_up = get_alpha_scales (down_weight , f"blocks. { i } .cross_attn. { o } .alpha" )
1930+ scale_down , scale_up = get_alpha_scales (down_weight , alpha_key )
19291931 converted_state_dict [converted_key_A ] = down_weight * scale_down
19301932 converted_state_dict [converted_key_B ] = up_weight * scale_up
19311933 else :
@@ -1942,7 +1944,8 @@ def get_alpha_scales(down_weight, alpha_key):
19421944
19431945 # FFN
19441946 for o , c in zip (["ffn.0" , "ffn.2" ], ["net.0.proj" , "net.2" ]):
1945- has_alpha = f"blocks.{ i } .{ o } .alpha" in original_state_dict
1947+ alpha_key = f"blocks.{ i } .{ o } .alpha"
1948+ has_alpha = alpha_key in original_state_dict
19461949 original_key_A = f"blocks.{ i } .{ o } .{ lora_down_key } .weight"
19471950 converted_key_A = f"blocks.{ i } .ffn.{ c } .lora_A.weight"
19481951
@@ -1952,7 +1955,7 @@ def get_alpha_scales(down_weight, alpha_key):
19521955 if has_alpha :
19531956 down_weight = original_state_dict .pop (original_key_A )
19541957 up_weight = original_state_dict .pop (original_key_B )
1955- scale_down , scale_up = get_alpha_scales (down_weight , f"blocks. { i } . { o } .alpha" )
1958+ scale_down , scale_up = get_alpha_scales (down_weight , alpha_key )
19561959 converted_state_dict [converted_key_A ] = down_weight * scale_down
19571960 converted_state_dict [converted_key_B ] = up_weight * scale_up
19581961 else :
0 commit comments