
Commit 1a7d54b

Merge pull request #40 from bigcode-project/mqa-checkpoint-utils
support mqa in checkpoint-merging tools
2 parents 22b8611 + 57f21b7 commit 1a7d54b

5 files changed: 58 additions, 17 deletions

megatron/arguments.py
Lines changed: 1 addition & 1 deletion

@@ -430,7 +430,7 @@ def _add_network_size_args(parser):
                        'attention. This is set to '
                        ' args.hidden_size // args.num_attention_heads '
                        'if not provided.')
-    group.add_argument('--attention-head-type', type=str, default='multihead',
+    group.add_argument('--attention-head-type', type=str, default=None,
                        choices=['multihead', 'multiquery'],
                        help='Type of attention heads. `multihead` is the standard multi-head attention.'
                        '`multiquery` shares the values and keys across attention heads')
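Why the default moves from 'multihead' to None: with None, a tool can tell that the user never passed the flag and fill the value in from checkpoint metadata instead of silently assuming multi-head. A minimal standalone sketch of that pattern (the parser below is illustrative, not Megatron's actual one):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--attention-head-type', type=str, default=None,
                        choices=['multihead', 'multiquery'])
    args = parser.parse_args([])  # simulate a run where the flag is not given

    # default=None makes "not set" detectable, so a checkpoint loader can
    # override it with whatever the checkpoint itself recorded.
    if args.attention_head_type is None:
        args.attention_head_type = 'multiquery'  # e.g. value read from checkpoint args
    print(args.attention_head_type)  # -> multiquery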

megatron/checkpointing.py
Lines changed: 12 additions & 8 deletions

@@ -92,7 +92,7 @@ def ensure_directory_exists(filename):
 
 
 def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release=False,
-                         pipeline_parallel=None, tensor_rank=None, pipeline_rank=None):
+                         pipeline_parallel=None, tensor_rank=None, pipeline_rank=None, only_model=False):
     """Determine the directory name for this rank's checkpoint."""
     if release:
         directory = 'release'
@@ -119,7 +119,7 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
 
     if use_distributed_optimizer:
         model_name = os.path.join(common_path, "model_rng.pt")
-        optim_name = os.path.join(
+        optim_name = None if only_model else os.path.join(
             common_path + "_%03d" % mpu.get_data_parallel_rank(),
             "optim.pt")
     else:
@@ -139,14 +139,14 @@ def find_checkpoint_rank_0(checkpoints_path, iteration, use_distributed_optimize
     # Look for checkpoint with no pipelining
     filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
                                      pipeline_parallel=False,
-                                     tensor_rank=0, pipeline_rank=0)
+                                     tensor_rank=0, pipeline_rank=0, only_model=True)
     if os.path.isfile(filenames[0]):
         return filenames
 
     # Look for checkpoint with pipelining
     filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
                                      pipeline_parallel=True,
-                                     tensor_rank=0, pipeline_rank=0)
+                                     tensor_rank=0, pipeline_rank=0, only_model=True)
     if os.path.isfile(filenames[0]):
         return filenames
 
@@ -379,10 +379,11 @@ def fix_query_key_value_ordering(model, checkpoint_version):
     print_rank_0(" succesfully fixed query-key-values ordering for"
                  " checkpoint version {}".format(checkpoint_version))
 
-def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iteration=None, release=None):
+def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iteration=None, release=None, no_load_optim=False):
    """ Load the base state_dict from the given directory
 
    If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
+   If rank0 is true or no_load_optim is true, we do not care about the optimizer, only the model checkpoint.
    """
 
    # Read the tracker file and set the iteration.
@@ -408,7 +409,7 @@ def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iter
                                                 release)
     else:
         checkpoint_names = get_checkpoint_names(load_dir, iteration, use_distributed_optimizer,
-                                                release)
+                                                release, only_model=no_load_optim)
     if release:
         print_rank_0(f' loading release checkpoint from {load_dir}')
     else:
@@ -419,7 +420,9 @@ def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iter
     # Load the checkpoint.
     try:
         model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
-        if use_distributed_optimizer:
+        if rank0 or no_load_optim:
+            optim_state_dict = None
+        elif use_distributed_optimizer:
             optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
         else:
             optim_state_dict = model_state_dict
@@ -572,7 +575,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
         use_distributed_optimizer=args.use_distributed_optimizer,
         rank0=False,
         iteration=iteration,
-        release=release)
+        release=release,
+        no_load_optim=args.no_load_optim)
 
     if model_state_dict is None:
         return 0
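The common thread in these hunks: when only the model weights are wanted (rank-0 probing in find_checkpoint_rank_0, or --no-load-optim), the optimizer file is neither resolved nor read. A condensed, self-contained sketch of that control flow, with torch.load replaced by placeholder dicts and the file layout invented for illustration:

    import os

    def load_base_checkpoint(common_path, use_distributed_optimizer,
                             rank0=False, no_load_optim=False):
        # Skip the optimizer entirely when only the model is needed.
        only_model = rank0 or no_load_optim

        model_name = os.path.join(common_path, "model_rng.pt")  # illustrative layout
        # With a distributed optimizer the optim state lives in a separate,
        # data-parallel-rank-specific file; with only_model we never build its path.
        optim_name = None if only_model else os.path.join(common_path, "optim.pt")

        model_state_dict = {"loaded_from": model_name}  # stands in for torch.load(model_name)
        if only_model:
            optim_state_dict = None
        elif use_distributed_optimizer:
            optim_state_dict = {"loaded_from": optim_name}  # torch.load(optim_name)
        else:
            optim_state_dict = model_state_dict  # optim state packed into the model file
        return model_state_dict, optim_state_dict

    model_sd, optim_sd = load_base_checkpoint("ckpt/iter_0001000", True, no_load_optim=True)
    assert optim_sd is None  # optimizer shards were never touched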

tools/checkpoint_loader_megatron.py
Lines changed: 23 additions & 4 deletions

@@ -50,6 +50,9 @@ def _load_checkpoint(queue, args):
                 '--no-initialization',
                 '--load', args.load_dir
                 ]
+    if args.use_distributed_optimizer:
+        sys.argv.append("--use-distributed-optimizer")
+
 
     margs = parse_args()
     margs = load_args_from_checkpoint(margs)
@@ -78,6 +81,7 @@ def check_for_arg(arg_name):
     check_for_arg('iteration')
     check_for_arg('bert_binary_head')
     check_for_arg('params_dtype')
+    check_for_arg('attention_head_type')
 
     # Determine how to make our models
     if args.model_type == 'GPT':
@@ -147,6 +151,7 @@ def get_models(count, dtype, pre_process, post_process):
     # metadata
     md = types.SimpleNamespace()
     md.model_type = args.model_type
+    md.attention_head_type = margs.attention_head_type
     md.num_layers = margs.num_layers
     md.hidden_size = margs.hidden_size
     md.seq_length = margs.seq_length
@@ -202,26 +207,40 @@ def queue_put(name, msg):
             message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
             message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
             message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
+            if margs.attention_head_type == "multiquery":
+                # MQA: kv is shared across tp-ranks
+                message["kv weight"] = layer.self_attention.key_value.weight.data
+                message["kv bias"] = layer.self_attention.key_value.bias.data
 
             # Grab all parallel tensors for this layer
             qkv_weight = []
             qkv_bias = []
+            q_weight = []
+            q_bias = []
             dense_weight = []
             mlp_l0_weight = []
             mlp_l0_bias = []
             mlp_l1_weight = []
             for tp_rank, model in enumerate(models):
                 layer = model.language_model.encoder.layers[layer_num]
-                qkv_weight.append(layer.self_attention.query_key_value.weight.data)
-                qkv_bias.append(layer.self_attention.query_key_value.bias.data)
+                if margs.attention_head_type == "multihead":
+                    qkv_weight.append(layer.self_attention.query_key_value.weight.data)
+                    qkv_bias.append(layer.self_attention.query_key_value.bias.data)
+                elif margs.attention_head_type == "multiquery":
+                    q_weight.append(layer.self_attention.query.weight.data)
+                    q_bias.append(layer.self_attention.query.bias.data)
                 dense_weight.append(layer.self_attention.dense.weight.data)
                 mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
                 mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
                 mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
 
             # concat them
-            message["qkv weight"] = torch.cat(qkv_weight, dim=0)
-            message["qkv bias"] = torch.cat(qkv_bias, dim=0)
+            if margs.attention_head_type == "multihead":
+                message["qkv weight"] = torch.cat(qkv_weight, dim=0)
+                message["qkv bias"] = torch.cat(qkv_bias, dim=0)
+            elif margs.attention_head_type == "multiquery":
+                message["q weight"] = torch.cat(q_weight, dim=0)
+                message["q bias"] = torch.cat(q_bias, dim=0)
             message["dense weight"] = torch.cat(dense_weight, dim=1)
             message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
             message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)

tools/checkpoint_saver_megatron.py
Lines changed: 19 additions & 4 deletions

@@ -95,6 +95,7 @@ def check_message(msg):
                 '--seq-length', str(md.seq_length),
                 '--num-attention-heads', str(md.num_attention_heads),
                 '--max-position-embeddings', str(md.max_position_embeddings),
+                '--attention-head-type', str(md.attention_head_type),
                 '--tokenizer-type', str(md.tokenizer_type),
                 '--tensor-model-parallel-size', str(args.target_tensor_parallel_size),
                 '--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size),
@@ -225,10 +226,17 @@ def get_models(count, dtype, pre_process, post_process):
             post_layernorm_weight = msg.pop("post layernorm weight")
             post_layernorm_bias = msg.pop("post layernorm bias")
             mlp_l1_bias = msg.pop("mlp l1 bias")
+            if margs.attention_head_type == "multiquery":
+                kv_weight = msg.pop("kv weight")
+                kv_bias = msg.pop("kv bias")
 
             # Split up the parallel tensors
-            qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
-            qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
+            if margs.attention_head_type == "multihead":
+                qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
+                qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
+            elif margs.attention_head_type == "multiquery":
+                q_weight = torch.chunk(msg.pop("q weight"), args.target_tensor_parallel_size, dim=0)
+                q_bias = torch.chunk(msg.pop("q bias"), args.target_tensor_parallel_size, dim=0)
             dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
             mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
             mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
@@ -239,8 +247,15 @@ def get_models(count, dtype, pre_process, post_process):
                 l = models[tp_rank].language_model.encoder.layers[layer]
                 l.input_layernorm.weight.data.copy_(input_layernorm_weight)
                 l.input_layernorm.bias.data.copy_(input_layernorm_bias)
-                l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
-                l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
+                if margs.attention_head_type == "multihead":
+                    l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
+                    l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
+                elif margs.attention_head_type == "multiquery":
+                    # MQA: key-value are shared across tp-ranks
+                    l.self_attention.key_value.weight.data.copy_(kv_weight)
+                    l.self_attention.key_value.bias.data.copy_(kv_bias)
+                    l.self_attention.query.weight.data.copy_(q_weight[tp_rank])
+                    l.self_attention.query.bias.data.copy_(q_bias[tp_rank])
                 l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
                 l.self_attention.dense.bias.data.copy_(dense_bias)
                 l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
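On the saver side the transformation runs in reverse: the merged query tensor is chunked across the target tensor-parallel ranks, while the shared key-value tensor is copied whole into every rank. A toy sketch with placeholder sizes:

    import torch

    target_tp, num_heads, head_dim, hidden = 4, 8, 64, 512  # placeholder sizes

    q_full = torch.randn(num_heads * head_dim, hidden)  # merged query projection
    kv_weight = torch.randn(2 * head_dim, hidden)       # shared K/V projection

    # Query is re-split across the target tp ranks along the head dimension...
    q_shards = torch.chunk(q_full, target_tp, dim=0)
    # ...while every rank receives an identical copy of the key-value weights.
    for tp_rank in range(target_tp):
        assert q_shards[tp_rank].shape == ((num_heads // target_tp) * head_dim, hidden)
        rank_kv = kv_weight.clone()  # same tensor on every rank
        assert torch.equal(rank_kv, kv_weight)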

tools/checkpoint_util.py
Lines changed: 3 additions & 0 deletions

@@ -124,6 +124,9 @@ def main():
     parser.add_argument('--no-checking', action='store_false',
                         help='Do not perform checking on the name and ordering of weights',
                         dest='checking')
+
+    parser.add_argument('--use-distributed-optimizer', action='store_true',
+                        help='Loaded checkpoint uses distributed optimizer.')
 
     known_args, _ = parser.parse_known_args()
     loader = load_plugin('loader', known_args.loader)
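How the new flag travels: checkpoint_util.py parses it here, and the loader (see tools/checkpoint_loader_megatron.py above) re-injects it into the argv that Megatron's own parse_args() consumes. A minimal sketch of that hand-off, with a placeholder argv:

    import argparse

    # checkpoint_util.py side: accept the flag, mirroring the hunk above.
    parser = argparse.ArgumentParser()
    parser.add_argument('--use-distributed-optimizer', action='store_true',
                        help='Loaded checkpoint uses distributed optimizer.')
    args, _ = parser.parse_known_args(['--use-distributed-optimizer'])

    # loader side: forward the flag so Megatron resolves the right file layout.
    megatron_argv = ['script.py', '--load', '/path/to/checkpoint']  # placeholder argv
    if args.use_distributed_optimizer:
        megatron_argv.append('--use-distributed-optimizer')
    print(megatron_argv)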
