
Commit bd12802

support mqa in checkpoint-merging tools
1 parent 659295a commit bd12802
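
For readers unfamiliar with the abbreviation, MQA is multi-query attention: as the `--attention-head-type` help text below puts it, the keys and values are shared across attention heads while each head keeps its own query. A minimal PyTorch sketch of the idea (illustrative shapes only, not code from this repository):

import torch

# Multi-query attention: every query head attends with one shared key head and
# one shared value head (multi-head attention would use num_heads of each).
batch, seq, num_heads, head_dim = 2, 8, 4, 16

q = torch.randn(batch, num_heads, seq, head_dim)  # per-head queries
k = torch.randn(batch, 1, seq, head_dim)          # single shared key head
v = torch.randn(batch, 1, seq, head_dim)          # single shared value head

# Broadcasting over the head dimension applies the shared k/v to every query head.
scores = torch.matmul(q, k.transpose(-1, -2)) / head_dim ** 0.5
out = torch.matmul(torch.softmax(scores, dim=-1), v)  # (batch, num_heads, seq, head_dim)
print(out.shape)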

3 files changed: +40 -9 lines changed

megatron/arguments.py

Lines changed: 1 addition & 1 deletion
@@ -430,7 +430,7 @@ def _add_network_size_args(parser):
                             'attention. This is set to '
                             ' args.hidden_size // args.num_attention_heads '
                             'if not provided.')
-    group.add_argument('--attention-head-type', type=str, default='multihead',
+    group.add_argument('--attention-head-type', type=str, default=None,
                        choices=['multihead', 'multiquery'],
                        help='Type of attention heads. `multihead` is the standard multi-head attention.'
                             '`multiquery` shares the values and keys across attention heads')
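
One hedged reading of this default change: with default=None, downstream tools can tell whether the user set the flag explicitly or whether the value should be taken from the checkpoint being converted. A small self-contained argparse sketch of that pattern (hypothetical fallback, not the actual Megatron logic):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='network size')
group.add_argument('--attention-head-type', type=str, default=None,
                   choices=['multihead', 'multiquery'],
                   help='Type of attention heads.')

args = parser.parse_args([])  # user did not pass the flag

# Hypothetical fallback: when the flag is unset, take the value recorded in the
# checkpoint instead of overriding it with a hard-coded default.
checkpoint_args = {'attention_head_type': 'multiquery'}
if args.attention_head_type is None:
    args.attention_head_type = checkpoint_args['attention_head_type']
print(args.attention_head_type)  # multiquery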

tools/checkpoint_loader_megatron.py

Lines changed: 20 additions & 4 deletions
@@ -78,6 +78,7 @@ def check_for_arg(arg_name):
     check_for_arg('iteration')
     check_for_arg('bert_binary_head')
     check_for_arg('params_dtype')
+    check_for_arg('attention_head_type')
 
     # Determine how to make our models
     if args.model_type == 'GPT':
@@ -147,6 +148,7 @@ def get_models(count, dtype, pre_process, post_process):
     # metadata
     md = types.SimpleNamespace()
     md.model_type = args.model_type
+    md.attention_head_type = margs.attention_head_type
     md.num_layers = margs.num_layers
     md.hidden_size = margs.hidden_size
     md.seq_length = margs.seq_length
@@ -202,26 +204,40 @@ def queue_put(name, msg):
             message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
             message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
             message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
+            if margs.attention_head_type == "multiquery":
+                # MQA: kv is shared across tp-ranks
+                message["kv weight"] = layer.self_attention.key_value.weight.data
+                message["kv bias"] = layer.self_attention.key_value.bias.data
 
             # Grab all parallel tensors for this layer
             qkv_weight = []
             qkv_bias = []
+            q_weight = []
+            q_bias = []
             dense_weight = []
             mlp_l0_weight = []
             mlp_l0_bias = []
             mlp_l1_weight = []
             for tp_rank, model in enumerate(models):
                 layer = model.language_model.encoder.layers[layer_num]
-                qkv_weight.append(layer.self_attention.query_key_value.weight.data)
-                qkv_bias.append(layer.self_attention.query_key_value.bias.data)
+                if margs.attention_head_type == "multihead":
+                    qkv_weight.append(layer.self_attention.query_key_value.weight.data)
+                    qkv_bias.append(layer.self_attention.query_key_value.bias.data)
+                elif margs.attention_head_type == "multiquery":
+                    q_weight.append(layer.self_attention.query.weight.data)
+                    q_bias.append(layer.self_attention.query.bias.data)
                 dense_weight.append(layer.self_attention.dense.weight.data)
                 mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
                 mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
                 mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
 
             # concat them
-            message["qkv weight"] = torch.cat(qkv_weight, dim=0)
-            message["qkv bias"] = torch.cat(qkv_bias, dim=0)
+            if margs.attention_head_type == "multihead":
+                message["qkv weight"] = torch.cat(qkv_weight, dim=0)
+                message["qkv bias"] = torch.cat(qkv_bias, dim=0)
+            elif margs.attention_head_type == "multiquery":
+                message["q weight"] = torch.cat(q_weight, dim=0)
+                message["q bias"] = torch.cat(q_bias, dim=0)
             message["dense weight"] = torch.cat(dense_weight, dim=1)
             message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
             message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)

tools/checkpoint_saver_megatron.py

Lines changed: 19 additions & 4 deletions
@@ -95,6 +95,7 @@ def check_message(msg):
         '--seq-length', str(md.seq_length),
         '--num-attention-heads', str(md.num_attention_heads),
         '--max-position-embeddings', str(md.max_position_embeddings),
+        '--attention-head-type', str(md.attention_head_type),
         '--tokenizer-type', str(md.tokenizer_type),
         '--tensor-model-parallel-size', str(args.target_tensor_parallel_size),
         '--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size),
@@ -225,10 +226,17 @@ def get_models(count, dtype, pre_process, post_process):
             post_layernorm_weight = msg.pop("post layernorm weight")
             post_layernorm_bias = msg.pop("post layernorm bias")
             mlp_l1_bias = msg.pop("mlp l1 bias")
+            if margs.attention_head_type == "multiquery":
+                kv_weight = msg.pop("kv weight")
+                kv_bias = msg.pop("kv bias")
 
             # Split up the parallel tensors
-            qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
-            qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
+            if margs.attention_head_type == "multihead":
+                qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
+                qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
+            elif margs.attention_head_type == "multiquery":
+                q_weight = torch.chunk(msg.pop("q weight"), args.target_tensor_parallel_size, dim=0)
+                q_bias = torch.chunk(msg.pop("q bias"), args.target_tensor_parallel_size, dim=0)
             dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
             mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
             mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
@@ -239,8 +247,15 @@ def get_models(count, dtype, pre_process, post_process):
                 l = models[tp_rank].language_model.encoder.layers[layer]
                 l.input_layernorm.weight.data.copy_(input_layernorm_weight)
                 l.input_layernorm.bias.data.copy_(input_layernorm_bias)
-                l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
-                l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
+                if margs.attention_head_type == "multihead":
+                    l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
+                    l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
+                elif margs.attention_head_type == "multiquery":
+                    # MQA: key-value are shared across tp-ranks
+                    l.self_attention.key_value.weight.data.copy_(kv_weight)
+                    l.self_attention.key_value.bias.data.copy_(kv_bias)
+                    l.self_attention.query.weight.data.copy_(q_weight[tp_rank])
+                    l.self_attention.query.bias.data.copy_(q_bias[tp_rank])
                 l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
                 l.self_attention.dense.bias.data.copy_(dense_bias)
                 l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
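
And the saver's side, resharding the merged tensors for the target tensor-parallel size: chunk the query projection across ranks and give every rank the same shared key-value projection. A minimal round-trip sketch under the same hypothetical shapes as above:

import torch

hidden_size, head_dim, target_tp = 16, 4, 2

q_full = torch.randn(hidden_size, hidden_size)     # merged query projection
kv_full = torch.randn(2 * head_dim, hidden_size)   # merged (shared) key-value projection

# Resharding: split q across the target ranks, replicate kv unchanged on every rank.
q_shards = torch.chunk(q_full, target_tp, dim=0)
per_rank = [(q_shards[rank], kv_full) for rank in range(target_tp)]

# Sanity check: concatenating the query shards reproduces the merged tensor,
# so a merge followed by a reshard is lossless.
assert torch.equal(torch.cat(q_shards, dim=0), q_full)
print([w.shape for w, _ in per_rank])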
