@@ -95,6 +95,7 @@ def check_message(msg):
                 '--seq-length', str(md.seq_length),
                 '--num-attention-heads', str(md.num_attention_heads),
                 '--max-position-embeddings', str(md.max_position_embeddings),
+                '--attention-head-type', str(md.attention_head_type),
                 '--tokenizer-type', str(md.tokenizer_type),
                 '--tensor-model-parallel-size', str(args.target_tensor_parallel_size),
                 '--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size),
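The new '--attention-head-type' value is simply forwarded to the target model's argument parser along with the rest of the checkpoint metadata. As a rough, non-authoritative sketch of how such a flag could be declared on the receiving side (the actual definition is not part of this diff; only the two choice values are taken from the branches below):

# Hypothetical argparse registration for '--attention-head-type'.
# The choices "multihead" and "multiquery" are the values this diff
# branches on; the default and help text are assumptions.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--attention-head-type', type=str, default='multihead',
                    choices=['multihead', 'multiquery'],
                    help='Standard multi-head attention or multi-query '
                         'attention with a single shared key/value head.')

args = parser.parse_args(['--attention-head-type', 'multiquery'])
assert args.attention_head_type == 'multiquery'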
@@ -225,10 +226,17 @@ def get_models(count, dtype, pre_process, post_process):
             post_layernorm_weight = msg.pop("post layernorm weight")
             post_layernorm_bias = msg.pop("post layernorm bias")
             mlp_l1_bias = msg.pop("mlp l1 bias")
+            if margs.attention_head_type == "multiquery":
+                kv_weight = msg.pop("kv weight")
+                kv_bias = msg.pop("kv bias")

             # Split up the parallel tensors
-            qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
-            qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
+            if margs.attention_head_type == "multihead":
+                qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
+                qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
+            elif margs.attention_head_type == "multiquery":
+                q_weight = torch.chunk(msg.pop("q weight"), args.target_tensor_parallel_size, dim=0)
+                q_bias = torch.chunk(msg.pop("q bias"), args.target_tensor_parallel_size, dim=0)
             dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
             mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
             mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
@@ -239,8 +247,15 @@ def get_models(count, dtype, pre_process, post_process):
                 l = models[tp_rank].language_model.encoder.layers[layer]
                 l.input_layernorm.weight.data.copy_(input_layernorm_weight)
                 l.input_layernorm.bias.data.copy_(input_layernorm_bias)
-                l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
-                l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
+                if margs.attention_head_type == "multihead":
+                    l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
+                    l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
+                elif margs.attention_head_type == "multiquery":
+                    # MQA: key-value are shared across tp-ranks
+                    l.self_attention.key_value.weight.data.copy_(kv_weight)
+                    l.self_attention.key_value.bias.data.copy_(kv_bias)
+                    l.self_attention.query.weight.data.copy_(q_weight[tp_rank])
+                    l.self_attention.query.bias.data.copy_(q_bias[tp_rank])
                 l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
                 l.self_attention.dense.bias.data.copy_(dense_bias)
                 l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
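The branch above captures the core of the multi-query change: query projections are chunked across the target tensor-parallel ranks, while the single key/value projection is copied unchanged to every rank. A minimal standalone sketch of that sharding scheme, using purely illustrative sizes (none of these shapes come from the script):

# Hedged sketch of MQA weight distribution under tensor parallelism:
# queries are split per rank, the shared key/value projection is replicated.
import torch

hidden_size = 16                      # illustrative only
num_heads = 4
head_dim = hidden_size // num_heads
tp_size = 2

# Multi-query attention: per-head query projections, one shared K/V projection.
q_weight = torch.randn(num_heads * head_dim, hidden_size)   # split across ranks
kv_weight = torch.randn(2 * head_dim, hidden_size)          # replicated on every rank

# Queries: one shard per tensor-parallel rank, split along the output dimension.
q_shards = torch.chunk(q_weight, tp_size, dim=0)

for tp_rank in range(tp_size):
    rank_q = q_shards[tp_rank]        # this rank's slice of the query heads
    rank_kv = kv_weight.clone()       # the full key/value weights, same on every rank
    assert rank_q.shape[0] == num_heads * head_dim // tp_size
    assert torch.equal(rank_kv, kv_weight)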