@@ -30,6 +30,14 @@ def remap_txt_attn_qkv_(key, state_dict):
     state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
 
 
+def remap_self_attn_qkv_(key, state_dict):
+    weight = state_dict.pop(key)
+    to_q, to_k, to_v = weight.chunk(3, dim=0)
+    state_dict[key.replace("self_attn_qkv", "attn.to_q")] = to_q
+    state_dict[key.replace("self_attn_qkv", "attn.to_k")] = to_k
+    state_dict[key.replace("self_attn_qkv", "attn.to_v")] = to_v
+
+
 def remap_single_transformer_blocks_(key, state_dict):
     hidden_size = 3072
 
@@ -69,6 +77,7 @@ def remap_single_transformer_blocks_(key, state_dict):
     # "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
     # "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
     "double_blocks": "transformer_blocks",
+    "individual_token_refiner.blocks": "token_refiner.refiner_blocks",
     "img_attn_q_norm": "attn.norm_q",
     "img_attn_k_norm": "attn.norm_k",
     "img_attn_proj": "attn.to_out.0",
@@ -83,6 +92,7 @@ def remap_single_transformer_blocks_(key, state_dict):
     "txt_norm1": "norm1.norm",
     "txt_norm2": "norm2_context",
     "txt_mlp": "ff_context",
+    "self_attn_proj": "attn.to_out.0",
     "modulation.linear": "norm.linear",
     "pre_norm": "norm.norm",
     "final_layer.norm_final": "norm_out.norm",
@@ -95,6 +105,7 @@ def remap_single_transformer_blocks_(key, state_dict):
     "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
     "img_attn_qkv": remap_img_attn_qkv_,
     "txt_attn_qkv": remap_txt_attn_qkv_,
+    "self_attn_qkv": remap_self_attn_qkv_,
     "single_blocks": remap_single_transformer_blocks_,
 }
 
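For context, a minimal, self-contained sketch of what the new remap hook does: it pops a fused QKV projection weight out of the state dict, splits it into equal Q/K/V chunks along dim 0, and re-inserts the pieces under diffusers-style key names. This assumes torch tensors; the key name and hidden size below are hypothetical, chosen only to illustrate the shape handling, and are not taken from the actual checkpoint.

import torch

def remap_self_attn_qkv_(key, state_dict):
    # Pop the fused (3 * hidden, hidden) QKV weight and split it into
    # three equal (hidden, hidden) chunks along the output dimension.
    weight = state_dict.pop(key)
    to_q, to_k, to_v = weight.chunk(3, dim=0)
    state_dict[key.replace("self_attn_qkv", "attn.to_q")] = to_q
    state_dict[key.replace("self_attn_qkv", "attn.to_k")] = to_k
    state_dict[key.replace("self_attn_qkv", "attn.to_v")] = to_v

# Hypothetical key and hidden size, for illustration only.
hidden_size = 1024
key = "token_refiner.refiner_blocks.0.self_attn_qkv.weight"
sd = {key: torch.randn(3 * hidden_size, hidden_size)}
remap_self_attn_qkv_(key, sd)
# The fused key is gone; three (hidden_size, hidden_size) weights remain:
#   token_refiner.refiner_blocks.0.attn.to_q.weight
#   token_refiner.refiner_blocks.0.attn.to_k.weight
#   token_refiner.refiner_blocks.0.attn.to_v.weight
print(sorted(sd.keys()))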