54 changes: 53 additions & 1 deletion megatron/checkpointing.py
@@ -31,6 +31,19 @@
get_tokenizer)
from megatron.enums import PositionEmbeddingType


from deepspeed.checkpoint import (
ORIGINAL_VOCAB_SIZE,
PADDED_VOCAB_SIZE,
UNIVERSAL_CHECKPOINT_INFO,
UNIVERSAL_CHECKPOINT_VERSION_KEY,
UNIVERSAL_CHECKPOINT_VERSION_VALUE,
VOCABULARY_PARAMETERS_PATTERN,
PIPELINE_REPLICATED_PARAMETERS_PATTERN,
PARAMETERS_TO_AVERAGE_PATTERN,
PARAMETERS_WITH_ROW_PARALLELISM_PATTERN,
)

_CHECKPOINT_VERSION = None

def set_checkpoint_version(value):
@@ -133,6 +146,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
state_dict['iteration'] = iteration
state_dict['tokens'] = args.consumed_train_tokens
state_dict['checkpoint_info'] = _checkpoint_info()
state_dict[UNIVERSAL_CHECKPOINT_INFO] = _universal_checkpoint_info()

# DeepSpeed saves the model/optimizer/scheduler
if not args.deepspeed:
@@ -480,4 +494,42 @@ def _checkpoint_info():
return {
"padded_vocab_size": args.padded_vocab_size,
"original_vocab_size": tokenizer.vocab_size,
}
}

def _universal_checkpoint_info():
args = get_args()
tokenizer = get_tokenizer()

info = dict()
info[UNIVERSAL_CHECKPOINT_VERSION_KEY] = UNIVERSAL_CHECKPOINT_VERSION_VALUE
info[ORIGINAL_VOCAB_SIZE] = tokenizer.vocab_size
info[PADDED_VOCAB_SIZE] = args.padded_vocab_size

# Vocabulary parameters (embeddings) that require special handling due to padding.
info[VOCABULARY_PARAMETERS_PATTERN] = ["word_embeddings.weight"]

# Replicated (shared) parameters on the pipeline dimension
info[PIPELINE_REPLICATED_PARAMETERS_PATTERN] = ["word_embeddings.weight"]

# Parameter slices that should be averaged, not concatenated.
info[PARAMETERS_TO_AVERAGE_PATTERN] = [
r"tied_modules.embed.word_embeddings.norm.weight",
r"tied_modules.embed.word_embeddings.norm.bias",
r"\d+.input_layernorm.weight",
r"\d+.input_layernorm.bias",
r"\d+.post_attention_layernorm.weight",
r"\d+.post_attention_layernorm.bias",
r"\d+.self_attention.dense.bias",
r"\d+.mlp.dense_4h_to_h.bias",
r"\d+.weight",
r"\d+.bias",
]

# Parameters that are sliced on the row dimension
info[PARAMETERS_WITH_ROW_PARALLELISM_PATTERN] = [
"dense_4h_to_h.weight",
"self_attention.dense.weight",
]

return info
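
Editor's note on how these patterns are consumed downstream: the converter matches them with re.match, which anchors at the start of the parameter name, so a pattern like \d+.input_layernorm.weight matches layer-prefixed names such as 4.input_layernorm.weight. A minimal sketch of the matching semantics (the parameter names below are illustrative, not taken from a real checkpoint):

import re

# Two of the patterns defined in _universal_checkpoint_info() above.
parameters_to_average = [
    r"\d+.input_layernorm.weight",
    r"\d+.self_attention.dense.bias",
]

# Hypothetical parameter names, for illustration only.
names = [
    "4.input_layernorm.weight",                   # matches: leading layer index
    "4.self_attention.dense.weight",              # no match: pattern expects .bias
    "tied_modules.embed.word_embeddings.weight",  # no match: no leading digits
]

for name in names:
    averaged = any(re.match(p, name) for p in parameters_to_average)
    print(f"{name}: {'average' if averaged else 'concatenate'}")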

118 changes: 36 additions & 82 deletions tools/convert_checkpoint/ds_to_universal.py
@@ -23,21 +23,23 @@
if root_repo_path not in sys.path:
sys.path.insert(0, root_repo_path)


from deepspeed.checkpoint import DeepSpeedCheckpoint

MODEL_KEY = 'model'
ARGS_KEY = 'args'
LANGUGAGE_MODEL_KEY = 'language_model'
EMBEDDING_KEY = 'embedding'
ENCODER_KEY = 'encoder'
WORD_EMBEDDINGS_FOR_HEAD_KEY = 'word_embeddings_for_head'
WORD_EMBEDDINGS_KEY = 'word_embeddings'
FINAL_LAYER_NORM_KEY = 'final_layernorm'
CHECKPOINT_VERSION_KEY = 'checkpoint_version'
CHECKPOINT_VERSION_VALUE = 3.0
ITERATION_KEY = 'iteration'

from deepspeed.checkpoint import (
OPTIMIZER_STATE_DICT,
BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS,
PARAM_SLICE_MAPPINGS,
PARAM_SHAPES,
PARAM,
CAT_DIM,
VOCAB_DIVISIBILITY_PADDING_TENSOR,
ORIGINAL_VOCAB_SIZE,
UNIVERSAL_CHECKPOINT_INFO,
VOCABULARY_PARAMETERS_PATTERN,
PIPELINE_REPLICATED_PARAMETERS_PATTERN,
PARAMETERS_TO_AVERAGE_PATTERN,
PARAMETERS_WITH_ROW_PARALLELISM_PATTERN,
)

def parse_arguments():
parser = argparse.ArgumentParser()
@@ -72,16 +74,6 @@ def parse_arguments():
return args


def _convert_ds_transformer_state(sd_list):
new_sd = OrderedDict()
for i, sd in enumerate(sd_list):
for key, value in sd.items():
new_key = f'layers.{i}.{key}'
new_sd[new_key] = value

return new_sd


def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree):
path_list = []
iter_folder = f'iter_{iteration:07d}'
@@ -96,17 +88,6 @@ def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree):
return path_list


def _create_megatron_dict():
language_model_dict = {EMBEDDING_KEY: {}, ENCODER_KEY: {}}
megatron_dict = {
MODEL_KEY: {
LANGUGAGE_MODEL_KEY: language_model_dict
},
CHECKPOINT_VERSION_KEY: CHECKPOINT_VERSION_VALUE
}
return megatron_dict


def _save_checkpoint(file_path, chkpt_sd):
dir, _ = os.path.split(file_path)
os.makedirs(dir, exist_ok=True)
@@ -123,13 +104,14 @@ def extract_zero_shards(dir, slice_shapes, ds_checkpoint, indices_3D):

#pprint(f"Processing {dp_index=} {pp_index=}, {tp_index=}")

optim_sd = sd["optimizer_state_dict"]
param_slice_mappings = optim_sd["param_slice_mappings"]

optim_sd = sd[OPTIMIZER_STATE_DICT]
param_slice_mappings = optim_sd[PARAM_SLICE_MAPPINGS]
universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
pipeline_replicated_params = universal_checkpoint_info.get(PIPELINE_REPLICATED_PARAMETERS_PATTERN, [])
# dict
state_groups = optim_sd["base_optimizer_state"]["state"]
state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"]
# list
fp32_groups = optim_sd["single_partition_of_fp32_groups"]
fp32_groups = optim_sd[SINGLE_PARTITION_OF_FP32_GROUPS]
param_groups_cnt = len(state_groups)

for param_group_id in range(param_groups_cnt):
@@ -141,7 +123,7 @@ def extract_zero_shards(dir, slice_shapes, ds_checkpoint, indices_3D):
)

for name,fragment_mapping in param_slice_mappings[param_group_id].items():
if "word_embeddings.weight" in name and pp_index > 0:
if pp_index > 0 and any(re.match(pattern, name) for pattern in pipeline_replicated_params):
# Skip tied weights that are replicated in the first and last pp stages
continue

@@ -176,41 +158,13 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
for tp_index in range(tp_degree):
prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
paths = sorted(list(glob.glob(f"{prefix_path}.0*")))
#print(paths)
shards = [torch.load(p) for p in paths]
slice = torch.cat(shards, dim=0).reshape(slice_shape)
slices.append(slice)

return slices


ORIGINAL_VOCAB_SIZE = 'original_vocab_size'
def _strip_vocab_padding(ds_checkpoint, padded_vocab_tensor):
checkpoint_info = ds_checkpoint.get_checkpoint_info()
padding_tensor = padded_vocab_tensor.narrow(0, checkpoint_info[ORIGINAL_VOCAB_SIZE], padded_vocab_tensor.shape[0]-checkpoint_info[ORIGINAL_VOCAB_SIZE])
#print(f'{padded_vocab_tensor[checkpoint_info[ORIGINAL_VOCAB_SIZE]-3:,:]=}')
return padded_vocab_tensor.narrow(0, 0, checkpoint_info[ORIGINAL_VOCAB_SIZE])


WEIGHTS_TO_AVERAGE_PATTERNS = [
r"tied_modules.embed.word_embeddings.norm.weight",
r"tied_modules.embed.word_embeddings.norm.bias",
r"\d+.input_layernorm.weight",
r"\d+.input_layernorm.bias",
r"\d+.post_attention_layernorm.weight",
r"\d+.post_attention_layernorm.bias",
r"\d+.self_attention.dense.bias",
r"\d+.mlp.dense_4h_to_h.bias",
r"\d+.weight",
r"\d+.bias",
]

WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [
"dense_4h_to_h.weight",
"self_attention.dense.weight",
]


def _get_vocab_divisibility_padding_tensor(ds_checkpoint, padded_vocab_tensor):
checkpoint_info = ds_checkpoint.get_checkpoint_info()
if padded_vocab_tensor.shape[0] > checkpoint_info[ORIGINAL_VOCAB_SIZE]:
@@ -223,37 +177,38 @@ def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
slice_base_path = os.path.join(slice_dir, name)
param_base_path = os.path.join(dir, name)

universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
parameters_to_average = universal_checkpoint_info.get(PARAMETERS_TO_AVERAGE_PATTERN, [])
parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETERS_WITH_ROW_PARALLELISM_PATTERN, [])
vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETERS_PATTERN, [])
for state in ("fp32", "exp_avg", "exp_avg_sq"):
slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape)
final_path = os.path.join(param_base_path, f"{state}.pt")

#print(f"Expected shape: {shape}")
#print(f"Fragment sizes:", list(frag.shape for frag in slices))
ckpt_dict = {}
if any(re.match(pattern, name) for pattern in WEIGHTS_TO_AVERAGE_PATTERNS):
if any(re.match(pattern, name) for pattern in parameters_to_average):
param = sum(slices) / len(slices)
else:
cat_dim = 1 if any(text in name for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
cat_dim = 1 if any(re.match(pattern, name) for pattern in parameters_with_row_parallelism) else 0
#print(f"CAT DIM: {cat_dim}")
param = torch.cat(slices, dim=cat_dim)
ckpt_dict['cat_dim'] = cat_dim
ckpt_dict[CAT_DIM] = cat_dim

if "word_embeddings.weight" in name:
if any(re.match(pattern, name) for pattern in vocabulary_parameters):
#print(f"Before {param.shape=}")
# strip padding
#param = _strip_vocab_padding(ds_checkpoint, param)
ckpt_dict['vocab_divisibility_padding_tensor'] = _get_vocab_divisibility_padding_tensor(ds_checkpoint, param)
ckpt_dict[VOCAB_DIVISIBILITY_PADDING_TENSOR] = _get_vocab_divisibility_padding_tensor(ds_checkpoint, param)
#print(f"After {param.shape=}")

#print(f"Final shape: {param.shape}")
ckpt_dict['param'] = param
ckpt_dict[PARAM] = param
_save_checkpoint(final_path, ckpt_dict)






def _get_chunks(l, n):
for i in range(0, len(l), n):
yield l[i:i + n]
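
Editor's note (not part of the diff): _get_chunks splits the flat work list into fixed-size batches, and _extract_zero_shard_files below feeds it the full (pp, tp, dp) index product. A short sketch of the enumeration, with illustrative degrees rather than values from a real checkpoint:

import itertools

# Illustrative degrees: pp=2, tp=2, dp=2 -> 8 shard-extraction tasks.
work = list(itertools.product(range(2), range(2), range(2)))
# work == [(0, 0, 0), (0, 0, 1), (0, 1, 0), ..., (1, 1, 1)]

# Chunking as _get_chunks(work, num_extract_workers) would do for 4 workers:
chunks = [work[i:i + 4] for i in range(0, len(work), 4)]
# -> two chunks of 4 tasks each, which _do_parallel_work appears to
#    hand to the worker pool one chunk at a time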
@@ -268,9 +223,9 @@ def _do_parallel_work(do_work, work_chunks, num_workers):

def _extract_zero_shard_files(args, ds_checkpoint, slice_shapes, temp_dir):
_3d_range_list = list(itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree), range(ds_checkpoint.dp_degree)))
#pprint(_3d_range_list)
#pprint(f'{_3d_range_list=}')
work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers))
#pprint(work_chunks)
#pprint(f'{work_chunks=}')

do_work = partial(extract_zero_shards, temp_dir, slice_shapes, ds_checkpoint)
_do_parallel_work(do_work, work_chunks, args.num_extract_workers)
@@ -295,7 +250,6 @@ def main():
)

ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)#, 1, 2) # args.target_tp, args.target_pp)

iteration = ds_checkpoint.get_iteration()
#_create_latest_file(args.output_folder, iteration)
checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration,
@@ -305,7 +259,7 @@
slice_shapes = []
for mp_rank_file in ds_checkpoint.mp_rank_files:
mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'))
slice_shapes += mp_sd["param_shapes"]
slice_shapes += mp_sd[PARAM_SHAPES]

# fix back to normal flat dict, merge duplicates for tp>1
slice_shapes = dict((k,v) for d in slice_shapes for k,v in d.items() )
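
Editor's note to make the end state concrete: judging from merge_tp_slices above, each parameter ends up as a directory of per-optimizer-state files, each holding a small dict keyed by the deepspeed.checkpoint constants (PARAM always; CAT_DIM and VOCAB_DIVISIBILITY_PADDING_TENSOR only where applicable). A sketch of reading one back; the output path and parameter name here are hypothetical:

import os
import torch
from deepspeed.checkpoint import PARAM, CAT_DIM

# Hypothetical layout: <output_folder>/<param_name>/<state>.pt
param_dir = os.path.join("universal_ckpt", "4.mlp.dense_4h_to_h.weight")
for state in ("fp32", "exp_avg", "exp_avg_sq"):
    ckpt_dict = torch.load(os.path.join(param_dir, f"{state}.pt"),
                           map_location="cpu")
    param = ckpt_dict[PARAM]             # full, unsharded tensor for this state
    cat_dim = ckpt_dict.get(CAT_DIM, 0)  # present only for concatenated params
    print(state, tuple(param.shape), cat_dim)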