
Commit e9191fb

Author: fengyu05
Commit message: add llama transfer script

1 parent b91bfd4 · commit e9191fb

10 files changed (+470, -24 lines)

megatron/checkpointing.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -222,14 +222,17 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     if not args.deepspeed:
         model = unwrap_model(model)
 
-    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
+    print_rank_0('saving checkpoint at iteration {} to {}'.format(
         iteration, args.save))
 
     # Collect rng state across data parallel ranks.
     rng_state = get_rng_state()
 
     # Checkpoint name.
-    checkpoint_name = get_checkpoint_name(args.save, iteration)
+    if iteration == 'release':
+        checkpoint_name = get_checkpoint_name(args.save, iteration, release=True)
+    else:
+        checkpoint_name = get_checkpoint_name(args.save, iteration)
 
     # Save distributed optimizer's custom parameter state.
     if args.use_distributed_optimizer:
@@ -300,7 +303,7 @@ def state_dict_for_save_checkpoint_deepspeed(destination=None, prefix='', keep_vars=False):
     if torch.distributed.is_initialized():
         torch.distributed.barrier()
 
-    print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}' \
+    print_rank_0(' successfully saved checkpoint at iteration {} to {}' \
                  .format(iteration, args.save))
 
     # And update the latest iteration
@@ -509,6 +512,7 @@ def _set_arg(arg_name, old_arg_name=None, force=False):
     _set_arg('apply_layernorm_1p', force=True)
     _set_arg('tokenizer_type')
    _set_arg('padded_vocab_size')
+    _set_arg('normalization', force=True)
     if checkpoint_version < 3.0:
         _set_arg('tensor_model_parallel_size',
                  'model_parallel_size')
```
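The two `print_rank_0` edits above are needed because `iteration` may now be the string `'release'` rather than an int, and the `{:7d}` format spec only accepts integers; with `release=True`, upstream Megatron's `get_checkpoint_name` writes into a `release` directory instead of an iteration-numbered one. A minimal, Megatron-independent sketch of the formatting issue:

```python
# Why '{:7d}' had to become '{}': the old spec rejects the new string value 'release'.
for iteration in (1500, 'release'):
    try:
        print('saving checkpoint at iteration {:7d}'.format(iteration))
    except ValueError as err:
        print(f'old format fails for {iteration!r}: {err}')
    # the type-agnostic format works for both ints and strings
    print('saving checkpoint at iteration {}'.format(iteration))
```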

megatron/model/transformer.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -584,9 +584,11 @@ def __init__(self, config, layer_number,
             local_attn = FlashSelfAttention(causal=True, attention_dropout=config.attention_dropout)
         else:
             local_attn = CoreAttention(self.layer_number, config, self.attn_mask_type)
-
-        self.enable_ds_sequence_parallel = parallel_state.get_sequence_parallel_world_size() > 1 \
-                                            or args.force_ds_sequence_parallel
+        if hasattr(args, 'ckpt_transfer') and args.ckpt_transfer:
+            self.enable_ds_sequence_parallel = False
+        else:
+            self.enable_ds_sequence_parallel = parallel_state.get_sequence_parallel_world_size() > 1 \
+                                                or args.force_ds_sequence_parallel
         if self.enable_ds_sequence_parallel:
             assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version'
             assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0
```

pretrain_gpt.py

Lines changed: 10 additions & 1 deletion
```diff
@@ -28,14 +28,23 @@
 from torch import nn
 import torch.nn.functional as F
 
-def model_provider(pre_process=True, post_process=True):
+
+def model_provider(pre_process=True, post_process=True, ckpt_transfer_model=False):
     """Build the model."""
 
     print_rank_0('building GPT model ...')
     see_memory_usage(f"Before Building Model", force=True)
 
     args = get_args()
     config = core_transformer_config_from_args(args)
+
+    if ckpt_transfer_model:
+        return GPTModel(config=config,
+                        num_tokentypes=0,
+                        parallel_output=True,
+                        pre_process=pre_process,
+                        post_process=post_process)
+
     with deepspeed.zero.Init(sequence_data_parallel_group=mpu.get_sequence_data_parallel_group(),
                              remote_device=None if args.remote_device == 'none' else args.remote_device,
                              config_dict_or_path=args.deepspeed_config,
```

tools/checkpoint_loader_megatron.py

Lines changed: 13 additions & 5 deletions
```diff
@@ -56,6 +56,9 @@ def _load_checkpoint(queue, args):
 
     margs = parse_args()
     margs, checkpoint_args = load_args_from_checkpoint(margs)
+    if args.tokenizer_model:
+        margs.tokenizer_model = args.tokenizer_model
+    margs.ckpt_transfer = True
 
     # Arguments do sanity checks on the world size, but we don't care,
     # so trick it into thinking we are plenty of processes
@@ -124,14 +127,15 @@ def get_models(count, dtype):
                 post_process = mpu.is_pipeline_last_stage()
                 this_model = model_provider(
                     pre_process=pre_process,
-                    post_process=post_process
+                    post_process=post_process,
+                    ckpt_transfer_model=True
                 ).to(dtype)
                 model_.append(this_model)
         else:
             pre_process = mpu.is_pipeline_first_stage()
             post_process = mpu.is_pipeline_last_stage()
             model_rank = 0
-            model_ = [model_provider(pre_process, post_process).to(dtype)]
+            model_ = [model_provider(pre_process, post_process, ckpt_transfer_model=True).to(dtype)]
         margs.consumed_train_samples = 0
         margs.consumed_valid_samples = 0
         load_checkpoint(model_, None, None)
@@ -236,9 +240,11 @@ def queue_put(name, msg):
         # Get non-parallel tensors from tp_rank 0
         layer = models[0].language_model.encoder.layers[layer_num]
         message["input layernorm weight"] = layer.input_layernorm.weight.data
-        message["input layernorm bias"] = layer.input_layernorm.bias.data
         message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
-        message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
+        if margs.normalization != 'rmsnorm':
+            message["input layernorm bias"] = layer.input_layernorm.bias.data
+            message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
+
         if md.linear_bias:
             message["dense bias"] = layer.self_attention.dense.bias.data
             message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
@@ -291,8 +297,9 @@ def queue_put(name, msg):
     # Send final layernorm from tp_rank 0
     message = {
         "weight": models[0].language_model.encoder.final_layernorm.weight.data,
-        "bias": models[0].language_model.encoder.final_layernorm.bias.data
     }
+    if margs.normalization != 'rmsnorm':
+        message["bias"] = models[0].language_model.encoder.final_layernorm.bias.data
     queue_put("final layernorm", message)
 
     if md.output_layer:
@@ -334,3 +341,4 @@ def load_checkpoint(queue, args):
     except:
         queue.put("exit")
         raise
+
```
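The `rmsnorm` guards above (mirrored in the saver below) exist because LLaMA uses RMSNorm, which carries only a scale weight and no bias, so there is no bias tensor to send between loader and saver. A minimal, self-contained illustration of that parameter difference (a toy RMSNorm, not Megatron's class):

```python
import torch

class RMSNorm(torch.nn.Module):
    """Toy RMSNorm: scale-only normalization, as used by LLaMA."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(dim))  # no bias parameter
        self.eps = eps

    def forward(self, x):
        # normalize by the root mean square instead of mean/variance
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

print([name for name, _ in torch.nn.LayerNorm(8).named_parameters()])  # ['weight', 'bias']
print([name for name, _ in RMSNorm(8).named_parameters()])             # ['weight']
```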

tools/checkpoint_saver_megatron.py

Lines changed: 22 additions & 11 deletions
```diff
@@ -162,12 +162,15 @@ def check_message(msg):
         setattr(margs, arg, value)
 
     validate_args(margs)
-
+    margs.ckpt_transfer = True
+    if args.tokenizer_model:
+        margs.tokenizer_model = args.tokenizer_model
     set_global_variables(margs)
 
     # margs = megatron args
     margs = get_args()
 
+    print("args.tokenizer_model", args.tokenizer_model)
     if hasattr(md, 'consumed_train_samples'):
         margs.consumed_train_samples = md.consumed_train_samples
         margs.consumed_valid_samples = md.consumed_valid_samples
@@ -187,7 +190,7 @@ def check_message(msg):
         raise Exception(f'unrecognized model type: {args.model_type}')
 
     def get_models(count, dtype, pre_process, post_process):
-        models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]
+        models = [model_provider(pre_process, post_process, ckpt_transfer_model=True).to(dtype) for _ in range(count)]
         return models
 
     # fake initializing distributed
@@ -262,9 +265,11 @@ def get_models(count, dtype, pre_process, post_process):
 
         # duplicated tensors
         input_layernorm_weight = msg.pop("input layernorm weight")
-        input_layernorm_bias = msg.pop("input layernorm bias")
         post_layernorm_weight = msg.pop("post layernorm weight")
-        post_layernorm_bias = msg.pop("post layernorm bias")
+        if margs.normalization != 'rmsnorm':
+            post_layernorm_bias = msg.pop("post layernorm bias")
+            input_layernorm_bias = msg.pop("input layernorm bias")
+
         if md.linear_bias:
            dense_bias = msg.pop("dense bias")
            mlp_l1_bias = msg.pop("mlp l1 bias")
@@ -295,11 +300,12 @@ def get_models(count, dtype, pre_process, post_process):
         for tp_rank in range(args.target_tensor_parallel_size):
             l = models[tp_rank].language_model.encoder.layers[layer]
             l.input_layernorm.weight.data.copy_(input_layernorm_weight)
-            l.input_layernorm.bias.data.copy_(input_layernorm_bias)
+            if margs.normalization != 'rmsnorm':
+                l.input_layernorm.bias.data.copy_(input_layernorm_bias)
+                l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
             l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
             l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
             l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
-            l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
             l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank])
             l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank])
             if md.linear_bias:
@@ -315,15 +321,18 @@ def get_models(count, dtype, pre_process, post_process):
         if post_process:
             msg = queue_get("final layernorm")
             final_layernorm_weight = msg.pop("weight")
-            final_layernorm_bias = msg.pop("bias")
+            if margs.normalization != 'rmsnorm':
+                final_layernorm_bias = msg.pop("bias")
             for tp_rank in range(args.target_tensor_parallel_size):
                 models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight)
-                models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
+                if margs.normalization != 'rmsnorm':
+                    models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
                 if pp_rank != 0 and not md.output_layer:
                     # Copy word embeddings to final pipeline rank
                     models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
             del final_layernorm_weight
-            del final_layernorm_bias
+            if margs.normalization != 'rmsnorm':
+                del final_layernorm_bias
             check_message(msg)
 
         if md.output_layer:
@@ -361,12 +370,14 @@ def get_models(count, dtype, pre_process, post_process):
             lm_head_dense_weight = msg.pop("dense weight")
             lm_head_dense_bias = msg.pop("dense bias")
             lm_head_layernorm_weight = msg.pop("layernorm weight")
-            lm_head_layernorm_bias = msg.pop("layernorm bias")
+            if margs.normalization != 'rmsnorm':
+                lm_head_layernorm_bias = msg.pop("layernorm bias")
             for tp_rank in range(args.target_tensor_parallel_size):
                 models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight)
                 models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias)
                 models[tp_rank].lm_head.layernorm.weight.data.copy_(lm_head_layernorm_weight)
-                models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias)
+                if margs.normalization != 'rmsnorm':
+                    models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias)
             check_message(msg)
             msg = queue_get()
```

tools/checkpoint_util.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -124,16 +124,21 @@ def main():
     parser.add_argument('--no-checking', action='store_false',
                         help='Do not perform checking on the name and ordering of weights',
                         dest='checking')
+    parser.add_argument('--tokenizer-model', type=str, default=None,
+                        help='tokenizer-model, should be on python path')
+
 
     known_args, _ = parser.parse_known_args()
+
     loader = load_plugin('loader', known_args.loader)
     saver = load_plugin('saver', known_args.saver)
 
     loader.add_arguments(parser)
     saver.add_arguments(parser)
 
     args = parser.parse_args()
-
+    if args.tokenizer_model is None:
+        args.tokenizer_model = args.load_dir+"/tokenizer.model"
     queue = mp.Queue(maxsize=args.max_queue_size)
 
     print("Starting saver...")
```

Lines changed: 19 additions & 0 deletions (new file)

````markdown
# Introduction
This folder is a collection of scripts for converting hf checkpoints to megatron-DeepSpeed checkpoints.

# Usage
## huggingface to megatron
```bash
python tools/convert_checkpoint/weights2megatron/weights2megatron_llama.py llama2 --size=13 --out=${DEST_DIR} --cache-dir=${HF_CKPT_DIR} --tokenizer-size=32000
```

## split ckpt by TP and PP size
```bash
python3 tools/checkpoint_util.py \
    --target-tensor-parallel-size 4 \
    --target-pipeline-parallel-size 2 \
    --load-dir ${LOAD_DIR} \
    --save-dir ${SAVE_DIR} \
    --model-type GPT \
    --true-vocab-size 32000
```
````

Lines changed: 121 additions & 0 deletions (new file)

```python
import os
import re
from pathlib import Path
from typing import Optional
from collections import OrderedDict

import torch
from tqdm.auto import tqdm
from transformers import LlamaForCausalLM, AutoTokenizer


scale2emb = {
    '7B': 4096,
    '13B': 5120,
    '30B': 6656,
    '65B': 8192,
    '70B': 8192,
}


key_to_dim = {
    "w1": 0,
    "w2": -1,
    "w3": 0,
    "wo": -1,
    "wq": 0,
    "wk": 0,
    "wv": 0,
    "output": 0,
    "tok_embeddings": -1,
    "ffn_norm": None,
    "attention_norm": None,
    "norm": None,
    "rope": None,
}


def init_merged_ckpt(pth_00, num_pth=8, emb_dim=8192):
    merged_ckpt = OrderedDict()
    for parameter_name, parameter in pth_00.items():
        short_name = parameter_name.split(".")[-2]
        if key_to_dim[short_name] is None:
            merged_ckpt[parameter_name] = parameter
            del parameter
        elif key_to_dim[short_name] == 0:
            size = parameter.shape[0]
            merged_param_shape = [parameter.shape[0] * num_pth, parameter.shape[1]]
            merged_ckpt[parameter_name] = torch.zeros(merged_param_shape)
            merged_ckpt[parameter_name][0 : size, :] = parameter
            del parameter
        elif key_to_dim[short_name] == -1:
            size = parameter.shape[-1]
            merged_param_shape = [parameter.shape[0], parameter.shape[1] * num_pth]
            merged_ckpt[parameter_name] = torch.zeros(merged_param_shape)
            merged_ckpt[parameter_name][:, 0 : size] = parameter
            del parameter
    return merged_ckpt


def merge_meta_llama(size: int, root_dir: Path):
    paths = sorted(path for path in root_dir.iterdir()
                   if re.match(r"^consolidated\.[0-9]+\.pth$", path.name))
    if len(paths) == 1:  # no sharded checkpoints, return everything
        return torch.load(paths[0], map_location=torch.device("cpu"))

    num_pth = len(paths)
    for i, ckpt_path in enumerate(tqdm(paths, desc="Merging llama")):
        llama_config = torch.load(ckpt_path, map_location=torch.device('cpu'))
        if i == 0:
            merged_ckpt = init_merged_ckpt(llama_config, num_pth=num_pth,
                                           emb_dim=scale2emb[f"{size}B"])
        else:
            for parameter_name, parameter in llama_config.items():
                short_name = parameter_name.split(".")[-2]
                if key_to_dim[short_name] == 0:
                    size = parameter.shape[0]
                    merged_param_shape = [parameter.shape[0] * num_pth, parameter.shape[1]]
                    merged_ckpt[parameter_name][size * i : size * (i + 1), :] = parameter
                    del parameter
                if key_to_dim[short_name] == -1:
                    size = parameter.shape[-1]
                    merged_param_shape = [parameter.shape[0], parameter.shape[1] * num_pth]
                    merged_ckpt[parameter_name][:, size * i : size * (i + 1)] = parameter
                    del parameter
        del llama_config
    return merged_ckpt


def merge_hf_llama(size: int, version: int, cache_dir: Optional[Path] = None, model_path=None, tokenizer_len=32000):
    assert version == 2, "Only llama v2 available using huggingface"
    print(cache_dir)
    model = LlamaForCausalLM.from_pretrained(cache_dir, cache_dir=cache_dir, local_files_only=True, use_safetensors=False)
    # resize token embeddings size according saved tokenizer for model extend token size.
    # model.resize_token_embeddings(tokenizer_len)
    weights = model.state_dict()
    weights["tok_embeddings.weight"] = weights.pop("model.embed_tokens.weight")
    weights["norm.weight"] = weights.pop("model.norm.weight")
    weights["output.weight"] = weights.pop("lm_head.weight")
    for key in list(weights.keys()):
        if rmatch := re.match(r"^model\.(layers\.[0-9]+\.)(.+)(\.weight)$", key):
            new_key = {
                "self_attn.q_proj": "attention.wq",
                "self_attn.k_proj": "attention.wk",
                "self_attn.v_proj": "attention.wv",
                "self_attn.o_proj": "attention.wo",
                "mlp.gate_proj": "feed_forward.w1",
                "mlp.down_proj": "feed_forward.w2",
                "mlp.up_proj": "feed_forward.w3",
                "input_layernorm": "attention_norm",
                "post_attention_layernorm": "ffn_norm"
            }[rmatch.group(2)]
            weights[rmatch.group(1) + new_key + rmatch.group(3)] = weights.pop(key)
    return weights


def merge_llama(size: int, version: int, root_dir: Optional[Path] = None, tokenizer_len: Optional[int] = 32000):
    if root_dir is not None and (root_dir/"consolidated.00.pth").exists():
        return merge_meta_llama(size, root_dir), "meta"
    print(f"Weights at {root_dir} do not look like a meta checkpoint, assuming "
          "huggingface cache_dir instead")
    return merge_hf_llama(size, version, root_dir, tokenizer_len), "hf"
```
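For orientation, a usage sketch of the `merge_llama` entry point defined above. The import path is assumed (the new file's name is not shown in this view) and the checkpoint path is illustrative:

```python
from pathlib import Path
# from weights2megatron.merge_llama import merge_llama  # hypothetical module path, not confirmed by this view

weights, source = merge_llama(size=13, version=2, root_dir=Path("/data/llama-2-13b"))
print(source)                                   # "meta" for consolidated.*.pth shards, otherwise "hf"
print(weights["tok_embeddings.weight"].shape)   # merged token embedding table
print(weights["norm.weight"].shape)             # final RMSNorm scale (note: no ".bias" key)
```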
