Commit 0f23a72

tjruwase and stas00 authored
Reshape deepspeed checkpoint (#239)
* Reshape deepspeed checkpoint
* add checkpoint tests
* Validate input folder
* Tests for tp/pp reshape
* remove debug folders
* fix test_checkpoint_reshaping_empty_dir
* Fix unit tests
* Remove deepspeed checkpoint utils
* Use DS 3D reshaping utils
* convert to bf16
* wip universal chkpt
* rename
* rename
* wip on fragments dealing
* cleanup
* Loading universal checkpoint with reshaping
* all gpu1<->2 reshapes work
* param attrs
* make the tests adaptable to the number of available gpus
* WIP
* WIP
* WIP
* WIP
* Debug functions
* args should be required, don't create another latest file
* Parallelize shard extraction
* close+join pool; add tqdm; comment out noise
* rename
* parameterize
* Parallel slice merging
* Cleanup
* allow inspection on a machine w/o gpus
* test against the right DS branch
* DS size was merged

Co-authored-by: Stas Bekman <[email protected]>
1 parent 7b5f175 commit 0f23a72

15 files changed: +1349 -293 lines

megatron/arguments.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -641,6 +641,8 @@ def _add_learning_rate_args(parser):
                            '(learning rate, warmup iterations, minimum learning '
                            'rate, maximum number of iterations, and decay style '
                            'from checkpoint and ignore input arguments.')
+    group.add_argument('--universal-checkpoint', action='store_true',
+                       help='Loading a universal format checkpoint.')
 
     return parser
 
```
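For reference, a minimal standalone sketch (not part of this commit) of how a `store_true` flag like the new `--universal-checkpoint` behaves: it defaults to `False` and only becomes `True` when passed explicitly on the command line.

```python
# Illustrative only: mimics the argument registration above with plain argparse.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='learning rate')
group.add_argument('--universal-checkpoint', action='store_true',
                   help='Loading a universal format checkpoint.')

args = parser.parse_args([])
assert args.universal_checkpoint is False          # default

args = parser.parse_args(['--universal-checkpoint'])
assert args.universal_checkpoint is True           # flag supplied
```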

megatron/checkpointing.py

Lines changed: 15 additions & 2 deletions
```diff
@@ -27,7 +27,8 @@
                      mpu,
                      print_rank_0,
                      update_num_microbatches,
-                     utils)
+                     utils,
+                     get_tokenizer)
 from megatron.enums import PositionEmbeddingType
 
 _CHECKPOINT_VERSION = None
@@ -131,6 +132,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         state_dict['checkpoint_version'] = 3.0
         state_dict['iteration'] = iteration
         state_dict['tokens'] = args.consumed_train_tokens
+        state_dict['checkpoint_info'] = _checkpoint_info()
 
         # DeepSpeed saves the model/optimizer/scheduler
         if not args.deepspeed:
@@ -361,7 +363,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     assert args.consumed_valid_samples == 0
     if 'args' in state_dict:
         checkpoint_args = state_dict['args']
-        check_checkpoint_args(checkpoint_args)
+        if not args.universal_checkpoint:
+            check_checkpoint_args(checkpoint_args)
         args.consumed_train_samples = getattr(checkpoint_args,
                                               'consumed_train_samples', 0)
         update_num_microbatches(consumed_samples=args.consumed_train_samples)
@@ -468,3 +471,13 @@ def load_biencoder_checkpoint(model, only_query_model=False,
     print(' successfully loaded {}'.format(checkpoint_name))
 
     return model
+
+
+def _checkpoint_info():
+    args = get_args()
+    tokenizer = get_tokenizer()
+
+    return {
+        "padded_vocab_size": args.padded_vocab_size,
+        "original_vocab_size": tokenizer.vocab_size,
+    }
```
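The new `checkpoint_info` entry records the padded and original vocabulary sizes, so that reshaping tools can tell real embedding rows apart from padding. A hypothetical reader sketch (the function name and flow are illustrative, not part of this commit):

```python
# Assumes a Megatron state_dict saved with the new 'checkpoint_info' entry.
import torch

def read_vocab_info(checkpoint_path):
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    info = state_dict.get('checkpoint_info', {})
    padded = info.get('padded_vocab_size')
    original = info.get('original_vocab_size')
    if padded is not None and original is not None:
        # rows [original:padded) of the word-embedding matrix are padding only
        print(f"vocab padded from {original} to {padded} "
              f"({padded - original} pad rows)")
    return info
```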

megatron/training.py

Lines changed: 66 additions & 2 deletions
```diff
@@ -367,6 +367,32 @@ def get_learning_rate_scheduler(optimizer):
     return lr_scheduler
 
 
+def sync_hp_to_lp(optimizer):
+
+    optimizer.update_lp_params()
+
+    # for n,p in model.named_parameters():
+    #     print(n)
+
+    #     if p._hp_mapping is not None:
+    #         #print(f'rank {rank} fixing hp for input_layernorm')
+    #         #p._hp_mapping.update_hp()
+
+    #         hp = p._hp_mapping.hp_fragment
+
+
+
+    #         torch.distributed.all_reduce(hp, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
+
+    #         # 3. optim states
+    #         for key in ['exp_avg', 'exp_avg_sq']:
+    #             optim_state_fragment = p._hp_mapping.get_optim_state_fragment(key)
+    #             #print(f'rank {rank} before reduce optim state fragment {key} = {optim_state_fragment}')
+    #             torch.distributed.all_reduce(optim_state_fragment, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
+    #             #print(f'rank {rank} after reduce optim state fragment {key} = {optim_state_fragment}')
+
+
+
 def setup_model_and_optimizer(model_provider_func):
     """Setup model and optimizer."""
     args = get_args()
@@ -386,12 +412,21 @@ def setup_model_and_optimizer(model_provider_func):
 
     if args.deepspeed:
         print_rank_0("DeepSpeed is enabled.")
-        pp = mpu.get_pipeline_model_parallel_world_size()
+        #pp = mpu.get_pipeline_model_parallel_world_size()
+
+        import json
+        import io
+        with io.open(args.deepspeed_config, "r", encoding="utf-8") as f:
+            config = json.load(f)
+        if args.universal_checkpoint:
+            config["checkpoint"] = {"load_universal": True}
+
         model, optimizer, _, lr_scheduler = deepspeed.initialize(
             model=model[0],
             optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            config=config,
             args=args,
-            lr_scheduler=lr_scheduler
         )
 
         assert model.fp16_enabled() == args.fp16, "megatron fp16 config does not match deepspeed"
@@ -416,8 +451,37 @@ def setup_model_and_optimizer(model_provider_func):
         torch.distributed.barrier()
         timers('load-checkpoint').stop()
         timers.log(['load-checkpoint'])
+
+
+        # hp -> lp
+        if args.deepspeed and args.universal_checkpoint:
+            sync_hp_to_lp(optimizer)
+
+
     else:
         args.iteration = 0
+
+    from .utils import dump_weights
+    dump_weights(f'{args.universal_checkpoint=}', args.iteration, model, optimizer)
+
+    # tp_rank = mpu.get_tensor_model_parallel_rank()
+    # pp_rank = mpu.get_pipeline_model_parallel_rank()
+    # dp_rank = mpu.get_data_parallel_rank()
+    # for n,p in model[0].named_parameters():
+    #     if 'word_embeddings.weight' not in n:
+    #         continue
+    #     if tp_rank == 0 and pp_rank == 0:
+    #         print(f"{tp_rank=}{pp_rank=}{dp_rank=} bf16 {n=} {p[:10]=}")
+    #         if p._hp_mapping is not None:
+    #             hp = p._hp_mapping.hp_fragment
+    #             print(f'{tp_rank=}{pp_rank=}{dp_rank=} fp32 {n=} {hp[:10]=}')
+
+    #     if tp_rank == 0 and pp_rank == mpu.get_pipeline_model_parallel_world_size() - 1:
+    #         print(f"{tp_rank=}{pp_rank=}{dp_rank=} bf16 {n=} {p[:10]=}")
+    #         if p._hp_mapping is not None:
+    #             hp = p._hp_mapping.hp_fragment
+    #             print(f'{tp_rank=}{pp_rank=}{dp_rank=} fp32 {n=} {hp[:10]=}')
+
 
     # We only support local DDP with multiple micro-batches.
     if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
```
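The training change reads the DeepSpeed JSON config, injects `{"checkpoint": {"load_universal": True}}` when `--universal-checkpoint` is given, and passes the resulting dict (together with `lr_scheduler`) to `deepspeed.initialize`. A sketch of just the config-injection step, mirroring the diff above (the helper name is illustrative, not part of the commit):

```python
import io
import json

def load_ds_config(path, universal_checkpoint=False):
    """Read a DeepSpeed JSON config and optionally request universal loading."""
    with io.open(path, "r", encoding="utf-8") as f:
        config = json.load(f)
    if universal_checkpoint:
        # ask the DeepSpeed engine to load a universal-format checkpoint
        config["checkpoint"] = {"load_universal": True}
    return config

# Intended use, following the diff above:
# config = load_ds_config(args.deepspeed_config, args.universal_checkpoint)
# model, optimizer, _, lr_scheduler = deepspeed.initialize(
#     model=model[0], optimizer=optimizer, lr_scheduler=lr_scheduler,
#     config=config, args=args)
```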

megatron/utils.py

Lines changed: 76 additions & 0 deletions
```diff
@@ -461,3 +461,79 @@ def found_kill_switch():
         return True
     else:
         return False
+
+def get_fingerprint_header():
+    return f"{'min':^13} {'max':^13} {'mean':^13} {'l2 norm':^12} metadata"
+
+def get_fingerprint(p):
+    return f"{p.min():13.6e} {p.max():13.6e} {p.mean():13.6e} {p.norm():12.6e}"
+
+
+def dump_weights(preamble, iteration, model, optimizer, tensor=None):
+    tp_rank = mpu.get_tensor_model_parallel_rank()
+    pp_rank = mpu.get_pipeline_model_parallel_rank()
+    dp_rank = mpu.get_data_parallel_rank()
+    dp_size = mpu.get_data_parallel_world_size()
+    fn = f"debug-bf16-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-{preamble}.txt"
+
+    # only care for first and last pp stages and dp0 tp0
+    #if not (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()):
+    #    return
+
+    #if not (tp_rank == 0 and dp_rank == 0):
+    #    return
+
+    if tensor is not None:
+        orig_tensor = tensor
+        if hasattr(tensor, "_hp_param"):
+            numel = tensor._hp_param.numel() # // dp_size
+            tensor = tensor.flatten().narrow(0, 0, numel)
+
+    #print(fn)
+    with open(fn, "w") as fh:
+        fh.write(f"{get_fingerprint_header()}\n")
+
+        if tensor is not None:
+            fh.write(f"{get_fingerprint(tensor)} tensor {tensor.shape}\n")
+        else:
+            for n, p in model[0].named_parameters():
+                fh.write(f"{get_fingerprint(p)} {n} {p.shape}\n")
+
+
+    return
+
+
+    # until we figure out how to dump the actual fp32 values don't do this
+    fn = f"debug-fp32-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-{preamble}.txt"
+    with open(fn, "w") as fh:
+        fh.write(f"{get_fingerprint_header()}\n")
+        if tensor is not None:
+            tensor = orig_tensor
+            if hasattr(tensor, "_hp_param"):
+                fh.write(f"{get_fingerprint(tensor._hp_param)} tensor {tensor._hp_param.shape}\n")
+                #fh.write(f"{get_fingerprint(tensor._hp_grad)} tensor grad\n")
+            else:
+                fh.write(f"{get_fingerprint(tensor)} tensor {tensor.shape}\n")
+                #fh.write(f"{get_fingerprint(tensor.grad)} tensor grad\n")
+
+        else:
+            if hasattr(model[0].module.tied_modules, "embed"):
+                p = model[0].module.tied_modules.embed.word_embeddings.weight._hp_param
+                fh.write(f"{get_fingerprint(p)} module.tied_modules.embed.word_embeddings.weight._hp_param {p.shape}\n")
+
+        # for i, param_group in enumerate(optimizer.param_groups):
+        #     fh.write(f"{get_fingerprint(optimizer.fp32_groups_flat_partition[i])} group={i}\n")
+            #fh.write(f"{i}={optimizer.fp32_groups_flat_partition[i]}\n")
+        # if mpu.is_pipeline_first_stage():
+        #     x = optimizer.fp32_groups_flat_partition[0]
+        #     fh.write(f"fp32={x[:402432]}\n")
+        # if mpu.is_pipeline_last_stage()):
+        #     x = optimizer.fp32_groups_flat_partition[1]
+        #     fh.write(f"fp32={x[-402432:]}\n")
+
+    # import os
+    # import socket
+    # hostname = socket.gethostname()
+    # pid = os.getpid()
+    # global_rank = torch.distributed.get_rank()
+    #fn = f"debug-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-global{global_rank}-{preamble}-{pid}.txt"
```
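`dump_weights` writes one fingerprint line (min, max, mean, L2 norm) per parameter to a per-rank debug file, which makes it easy to compare weights before and after a reshape. A small usage sketch of the fingerprint helpers on a dummy tensor (assumes only `torch`; the model/`mpu` plumbing is omitted):

```python
import torch

def get_fingerprint_header():
    return f"{'min':^13} {'max':^13} {'mean':^13} {'l2 norm':^12} metadata"

def get_fingerprint(p):
    return f"{p.min():13.6e} {p.max():13.6e} {p.mean():13.6e} {p.norm():12.6e}"

p = torch.randn(1024, 1024)
print(get_fingerprint_header())
print(f"{get_fingerprint(p)} dummy.weight {tuple(p.shape)}")
```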

run_bf16.sh

Lines changed: 40 additions & 25 deletions
```diff
@@ -12,48 +12,58 @@ DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
 #DATASET_3="<PATH TO THE THIRD DATASET>"
 #DATASET="0.2 ${DATASET_1} 0.3 ${DATASET_2} 0.5 ${DATASET_3}"
 
-BASE_DATA_PATH=/data/Megatron-LM/data
+#BASE_DATA_PATH=tests/data/gpt2
+#DATASET=${BASE_DATA_PATH}/meg-gpt2-openwebtext_text_document
+#VOCAB_PATH=${BASE_DATA_PATH}/gpt2-tiny-vocab.json
+#MERGE_PATH=${BASE_DATA_PATH}/gpt2-tiny-merges.txt
+
+BASE_DATA_PATH=/vc_data/Megatron-LM/data
 DATASET=${BASE_DATA_PATH}/indexed_datasets/megatron
 VOCAB_PATH=${BASE_DATA_PATH}/gpt2-vocab.json
 MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt
 
 
 script_path=$(realpath $0)
 script_dir=$(dirname $script_path)
-#CONFIG_JSON="$script_dir/ds_config.json"
-CONFIG_JSON="/tmp/ds_config.json"
+CONFIG_JSON="$script_dir/ds_config.json"
+#CONFIG_JSON="/tmp/ds_config.json"
 
 USE_DEEPSPEED=1
 ZERO_STAGE=0
 
-
-# Debug
 #TP=4
 #PP=4
-#LAYERS=8
-#HIDDEN=512
-#SEQ=1024
-#GLOBAL_BATCH=128
-#WORKER_STR="-i worker-0"
 
-
-TP=1
-PP=1
-DP=2
+# Debug
+DEBUG_MODE=0
+if [[ $DEBUG_MODE == 1 ]]; then
+        LAYERS=4
+        HIDDEN=512
+        SEQ=512
+        EXIT_INTERVAL=3
+else
+        HIDDEN=1024
+        LAYERS=24
+        SEQ=1024
+        EXIT_INTERVAL=10
+fi
+
+TP=2
+PP=2
+DP=4
 WORLD_SIZE=$((TP*PP*DP))
-HIDDEN=1024
-LAYERS=24
-SEQ=1024
-GLOBAL_BATCH=1
-WORKER_STR=""
+GLOBAL_BATCH=4
 
 MICRO_BATCH=1
+TRAIN_ITERS=100000
+CHECKPOINT_PATH=checkpoints/gpt2/tp${TP}_pp${PP}_dp${DP}
+LOAD_CHECKPOINT_PATH=checkpoints/gpt2/tp${TP}_pp${PP}_dp${DP}
 
 LR=6.0e-4
 MIN_LR=6.0e-5
 DTYPE="bf16"
-EXP_DIR=${HOME}/experiments/results/bf16
-LOG_DIR="${EXP_DIR}/tensorboard/tp${TP}_pp${PP}_dp${DP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_fix3"
+EXP_DIR=${HOME}/experiments/results/ckpt_reshape
+LOG_DIR="${EXP_DIR}/tensorboard/tp${TP}_pp${PP}_dp${DP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_cont"
 mkdir -p $LOG_DIR
 
 while [[ $# -gt 0 ]]
@@ -89,7 +99,7 @@ options=" \
        --max-position-embeddings $SEQ \
        --micro-batch-size $MICRO_BATCH \
        --global-batch-size $GLOBAL_BATCH \
-       --train-iters 1000 \
+       --train-iters $TRAIN_ITERS \
        --lr $LR \
        --min-lr $MIN_LR \
        --lr-decay-style cosine \
@@ -99,7 +109,7 @@ options=" \
        --data-path ${DATASET} \
        --vocab-file ${VOCAB_PATH} \
        --merge-file ${MERGE_PATH} \
-       --save-interval 10000 \
+       --save-interval 1000 \
        --split 98,2,0 \
        --clip-grad 1.0 \
        --weight-decay 0.1 \
@@ -108,7 +118,12 @@ options=" \
        --init-method-std 0.006 \
        --${DTYPE} \
        --checkpoint-activations \
-       --exit-interval 10000 \
+       --exit-interval ${EXIT_INTERVAL} \
+       --save ${CHECKPOINT_PATH} \
+       --load ${LOAD_CHECKPOINT_PATH} \
+       --position-embedding-type alibi \
+       --override-lr-scheduler \
+       --embed-layernorm \
        --tensorboard-dir $LOG_DIR
        "
 
@@ -151,7 +166,7 @@ cat <<EOT > $CONFIG_JSON
 }
 EOT
 
-WORKER_STR="--num_nodes 1 --num_gpus $WORLD_SIZE"
+#WORKER_STR="--num_nodes 1 --num_gpus $WORLD_SIZE"
 #WORKER_STR="-i worker-0:0,1,2,3"
 #run_cmd="deepspeed -i worker-0:0,1,2,3 ${DIR}/pretrain_gpt.py $@ ${options}"
 #run_cmd="deepspeed -i worker-0 ${DIR}/pretrain_gpt.py $@ ${options}"
```
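The script derives the requested GPU count from the TP/PP/DP degrees and names checkpoint directories after that layout. A quick illustrative check of the topology math, assuming the usual Megatron constraint that the global batch is divisible by micro-batch size times the data-parallel degree (values mirror the non-debug branch above):

```python
TP, PP, DP = 2, 2, 4                 # tensor-, pipeline-, data-parallel degrees
MICRO_BATCH, GLOBAL_BATCH = 1, 4

world_size = TP * PP * DP            # GPUs requested from the launcher
assert GLOBAL_BATCH % (MICRO_BATCH * DP) == 0, \
    "global batch must be a multiple of micro-batch size x data-parallel size"

checkpoint_path = f"checkpoints/gpt2/tp{TP}_pp{PP}_dp{DP}"
print(world_size, checkpoint_path)   # -> 16 checkpoints/gpt2/tp2_pp2_dp4
```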
