@@ -1896,18 +1896,16 @@ def get_alpha_scales(down_weight, alpha_key):
18961896 original_key_B = f"blocks.{ i } .cross_attn.{ o } .{ lora_up_key } .weight"
18971897 converted_key_B = f"blocks.{ i } .attn2.{ c } .lora_B.weight"
18981898
1899- if has_alpha :
1899+ if original_key_A in original_state_dict :
19001900 down_weight = original_state_dict .pop (original_key_A )
1901+ converted_state_dict [converted_key_A ] = down_weight
1902+ if original_key_B in original_state_dict :
19011903 up_weight = original_state_dict .pop (original_key_B )
1904+ converted_state_dict [converted_key_B ] = up_weight
1905+ if has_alpha :
19021906 scale_down , scale_up = get_alpha_scales (down_weight , alpha_key )
1903- converted_state_dict [converted_key_A ] = down_weight * scale_down
1904- converted_state_dict [converted_key_B ] = up_weight * scale_up
1905- else :
1906- if original_key_A in original_state_dict :
1907- converted_state_dict [converted_key_A ] = original_state_dict .pop (original_key_A )
1908-
1909- if original_key_B in original_state_dict :
1910- converted_state_dict [converted_key_B ] = original_state_dict .pop (original_key_B )
1907+ converted_state_dict [converted_key_A ] *= scale_down
1908+ converted_state_dict [converted_key_B ] *= scale_up
19111909
19121910 original_key = f"blocks.{ i } .cross_attn.{ o } .diff_b"
19131911 converted_key = f"blocks.{ i } .attn2.{ c } .lora_B.bias"
@@ -1924,18 +1922,16 @@ def get_alpha_scales(down_weight, alpha_key):
19241922 original_key_B = f"blocks.{ i } .cross_attn.{ o } .{ lora_up_key } .weight"
19251923 converted_key_B = f"blocks.{ i } .attn2.{ c } .lora_B.weight"
19261924
1927- if has_alpha :
1925+ if original_key_A in original_state_dict :
19281926 down_weight = original_state_dict .pop (original_key_A )
1927+ converted_state_dict [converted_key_A ] = down_weight
1928+ if original_key_B in original_state_dict :
19291929 up_weight = original_state_dict .pop (original_key_B )
1930+ converted_state_dict [converted_key_B ] = up_weight
1931+ if has_alpha :
19301932 scale_down , scale_up = get_alpha_scales (down_weight , alpha_key )
1931- converted_state_dict [converted_key_A ] = down_weight * scale_down
1932- converted_state_dict [converted_key_B ] = up_weight * scale_up
1933- else :
1934- if original_key_A in original_state_dict :
1935- converted_state_dict [converted_key_A ] = original_state_dict .pop (original_key_A )
1936-
1937- if original_key_B in original_state_dict :
1938- converted_state_dict [converted_key_B ] = original_state_dict .pop (original_key_B )
1933+ converted_state_dict [converted_key_A ] *= scale_down
1934+ converted_state_dict [converted_key_B ] *= scale_up
19391935
19401936 original_key = f"blocks.{ i } .cross_attn.{ o } .diff_b"
19411937 converted_key = f"blocks.{ i } .attn2.{ c } .lora_B.bias"
@@ -1952,18 +1948,16 @@ def get_alpha_scales(down_weight, alpha_key):
19521948 original_key_B = f"blocks.{ i } .{ o } .{ lora_up_key } .weight"
19531949 converted_key_B = f"blocks.{ i } .ffn.{ c } .lora_B.weight"
19541950
1955- if has_alpha :
1951+ if original_key_A in original_state_dict :
19561952 down_weight = original_state_dict .pop (original_key_A )
1953+ converted_state_dict [converted_key_A ] = down_weight
1954+ if original_key_B in original_state_dict :
19571955 up_weight = original_state_dict .pop (original_key_B )
1956+ converted_state_dict [converted_key_B ] = up_weight
1957+ if has_alpha :
19581958 scale_down , scale_up = get_alpha_scales (down_weight , alpha_key )
1959- converted_state_dict [converted_key_A ] = down_weight * scale_down
1960- converted_state_dict [converted_key_B ] = up_weight * scale_up
1961- else :
1962- if original_key_A in original_state_dict :
1963- converted_state_dict [converted_key_A ] = original_state_dict .pop (original_key_A )
1964-
1965- if original_key_B in original_state_dict :
1966- converted_state_dict [converted_key_B ] = original_state_dict .pop (original_key_B )
1959+ converted_state_dict [converted_key_A ] *= scale_down
1960+ converted_state_dict [converted_key_B ] *= scale_up
19671961
19681962 original_key = f"blocks.{ i } .{ o } .diff_b"
19691963 converted_key = f"blocks.{ i } .ffn.{ c } .lora_B.bias"
0 commit comments