10 changes: 7 additions & 3 deletions megatron/checkpointing.py
@@ -222,14 +222,17 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
if not args.deepspeed:
model = unwrap_model(model)

print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
print_rank_0('saving checkpoint at iteration {} to {}'.format(
iteration, args.save))

# Collect rng state across data parallel ranks.
rng_state = get_rng_state()

# Checkpoint name.
checkpoint_name = get_checkpoint_name(args.save, iteration)
if iteration == 'release':
checkpoint_name = get_checkpoint_name(args.save, iteration, release=True)
else:
checkpoint_name = get_checkpoint_name(args.save, iteration)

# Save distributed optimizer's custom parameter state.
if args.use_distributed_optimizer:
@@ -300,7 +303,7 @@ def state_dict_for_save_checkpoint_deepspeed(destination=None, prefix='', keep_v
if torch.distributed.is_initialized():
torch.distributed.barrier()

print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}' \
print_rank_0(' successfully saved checkpoint at iteration {} to {}' \
.format(iteration, args.save))

# And update the latest iteration
@@ -509,6 +512,7 @@ def _set_arg(arg_name, old_arg_name=None, force=False):
_set_arg('apply_layernorm_1p', force=True)
_set_arg('tokenizer_type')
_set_arg('padded_vocab_size')
_set_arg('normalization', force=True)
if checkpoint_version < 3.0:
_set_arg('tensor_model_parallel_size',
'model_parallel_size')
8 changes: 5 additions & 3 deletions megatron/model/transformer.py
@@ -584,9 +584,11 @@ def __init__(self, config, layer_number,
local_attn = FlashSelfAttention(causal=True, attention_dropout=config.attention_dropout)
else:
local_attn = CoreAttention(self.layer_number, config, self.attn_mask_type)

self.enable_ds_sequence_parallel = parallel_state.get_sequence_parallel_world_size() > 1 \
or args.force_ds_sequence_parallel
if hasattr(args, 'ckpt_transfer') and args.ckpt_transfer:

I did not notice --ckpt_transfer in the argument parsing code. How is this attribute added to args?

@uygnef (Author) · Oct 11, 2023

When the checkpoint-splitting program loads the model, it does not initialize parallel_state, so calling parallel_state.get_sequence_parallel_world_size() raises an error.

  File "/mnt/megatron-deepspeed/megatron/core/parallel_state.py", line 362, in get_sequence_parallel_group
    assert _SEQUENCE_PARALLEL_GROUP is not None, \
AssertionError: sequence parallel group is not initialized

Therefore, I used ckpt_transfer to skip the get_sequence_parallel_world_size call.
I also think this modification is not ideal; do you have any suggestions?
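
For reference, a --ckpt_transfer flag does not appear to be registered with the argument parser (as the reviewer notes); the conversion tools in this PR set the attribute directly on the Megatron args namespace before any model is built. A condensed sketch based on the tools/checkpoint_loader_megatron.py and tools/checkpoint_saver_megatron.py hunks below (the import paths shown are assumptions):

  # tools/checkpoint_loader_megatron.py (condensed from this PR)
  from megatron.arguments import parse_args                      # assumed import path
  from megatron.checkpointing import load_args_from_checkpoint   # assumed import path

  margs = parse_args()
  margs, checkpoint_args = load_args_from_checkpoint(margs)
  margs.ckpt_transfer = True   # later read via hasattr(args, 'ckpt_transfer') in transformer.py

  # tools/checkpoint_saver_megatron.py (condensed from this PR)
  margs = get_args()           # megatron.get_args(), called after the fake distributed init
  margs.ckpt_transfer = True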

@uygnef (Author)

> I did not notice --ckpt_transfer in the argument parsing code. How is this attribute added to args?

I understand that you are likely busy with many responsibilities, but I would greatly appreciate your feedback on this PR when you get a chance.

> I did not notice --ckpt_transfer in the argument parsing code. How is this attribute added to args?

> I understand that you are likely busy with many responsibilities, but I would greatly appreciate your feedback on this PR when you get a chance.

Hi @uygnef, thank you for your great work! I am trying to use this script to convert HF LLaMA to Megatron-DeepSpeed format and I hit the same error, AssertionError: sequence parallel group is not initialized. Did you solve this issue?

> I did not notice --ckpt_transfer in the argument parsing code. How is this attribute added to args?

> I understand that you are likely busy with many responsibilities, but I would greatly appreciate your feedback on this PR when you get a chance.

Hi @uygnef, I changed the ckpt_transfer parameter and it works now. But it seems the output format is the Megatron-LM format, not the Megatron-DeepSpeed format?

@inkcherry · Nov 28, 2023

Hi @uygnef, thank you so much for this PR! Would it be possible for you to provide an example launch script (pretrain or finetune) for it?

self.enable_ds_sequence_parallel = False
else:
self.enable_ds_sequence_parallel = parallel_state.get_sequence_parallel_world_size() > 1 \
or args.force_ds_sequence_parallel
if self.enable_ds_sequence_parallel:
assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version'
assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0
11 changes: 10 additions & 1 deletion pretrain_gpt.py
@@ -28,14 +28,23 @@
from torch import nn
import torch.nn.functional as F

def model_provider(pre_process=True, post_process=True):

def model_provider(pre_process=True, post_process=True, ckpt_transfer_model=False):
"""Build the model."""

print_rank_0('building GPT model ...')
see_memory_usage(f"Before Building Model", force=True)

args = get_args()
config = core_transformer_config_from_args(args)

if ckpt_transfer_model:
return GPTModel(config=config,
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process)

with deepspeed.zero.Init(sequence_data_parallel_group=mpu.get_sequence_data_parallel_group(),
@uygnef (Author)

There must be some better solution to initialize the model without initializing the distributed groups. Please help me.

The distributed initialization only occurs for args.zero_stage==3. Have you tried a different stage value on the command line?

@uygnef (Author)

> The distributed initialization only occurs for args.zero_stage==3. Have you tried a different stage value on the command line?

The problem is mpu.get_sequence_data_parallel_group(). How can I solve this problem?

  File "/mnt/megatron-deepspeed/pretrain_gpt.py", line 48, in model_provider
    with deepspeed.zero.Init(sequence_data_parallel_group=mpu.get_sequence_data_parallel_group(),
  File "/mnt/megatron-deepspeed/megatron/core/parallel_state.py", line 369, in get_sequence_data_parallel_group
    assert _SEQUENCE_DATA_PARALLEL_GROUP is not None, \
AssertionError: sequence data parallel group is not initialized
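
For reference, this PR sidesteps that call by having model_provider return a bare GPTModel when ckpt_transfer_model=True (the early-return hunk above), so deepspeed.zero.Init and mpu.get_sequence_data_parallel_group() are never reached during conversion. A condensed sketch of the loader-side call, based on the tools/checkpoint_loader_megatron.py hunk below:

  # tools/checkpoint_loader_megatron.py (condensed from this PR): build the model
  # without entering deepspeed.zero.Init or touching parallel-state groups.
  this_model = model_provider(
      pre_process=pre_process,
      post_process=post_process,
      ckpt_transfer_model=True,
  ).to(dtype)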

remote_device=None if args.remote_device == 'none' else args.remote_device,
config_dict_or_path=args.deepspeed_config,
18 changes: 13 additions & 5 deletions tools/checkpoint_loader_megatron.py
@@ -56,6 +56,9 @@ def _load_checkpoint(queue, args):

margs = parse_args()
margs, checkpoint_args = load_args_from_checkpoint(margs)
if args.tokenizer_model:
margs.tokenizer_model = args.tokenizer_model
margs.ckpt_transfer = True

# Arguments do sanity checks on the world size, but we don't care,
# so trick it into thinking we are plenty of processes
@@ -124,14 +127,15 @@ def get_models(count, dtype):
post_process = mpu.is_pipeline_last_stage()
this_model = model_provider(
pre_process=pre_process,
post_process=post_process
post_process=post_process,
ckpt_transfer_model=True
).to(dtype)
model_.append(this_model)
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
model_rank = 0
model_ = [model_provider(pre_process, post_process).to(dtype)]
model_ = [model_provider(pre_process, post_process, ckpt_transfer_model=True).to(dtype)]
margs.consumed_train_samples = 0
margs.consumed_valid_samples = 0
load_checkpoint(model_, None, None)
@@ -236,9 +240,11 @@ def queue_put(name, msg):
# Get non-parallel tensors from tp_rank 0
layer = models[0].language_model.encoder.layers[layer_num]
message["input layernorm weight"] = layer.input_layernorm.weight.data
message["input layernorm bias"] = layer.input_layernorm.bias.data
message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
if margs.normalization != 'rmsnorm':
message["input layernorm bias"] = layer.input_layernorm.bias.data
message["post layernorm bias"] = layer.post_attention_layernorm.bias.data

if md.linear_bias:
message["dense bias"] = layer.self_attention.dense.bias.data
message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
@@ -291,8 +297,9 @@ def queue_put(name, msg):
# Send final layernorm from tp_rank 0
message = {
"weight": models[0].language_model.encoder.final_layernorm.weight.data,
"bias": models[0].language_model.encoder.final_layernorm.bias.data
}
if margs.normalization != 'rmsnorm':
message["bias"] = models[0].language_model.encoder.final_layernorm.bias.data
queue_put("final layernorm", message)

if md.output_layer:
@@ -334,3 +341,4 @@ def load_checkpoint(queue, args):
except:
queue.put("exit")
raise

29 changes: 19 additions & 10 deletions tools/checkpoint_saver_megatron.py
@@ -167,6 +167,7 @@ def check_message(msg):

# margs = megatron args
margs = get_args()
margs.ckpt_transfer = True

if hasattr(md, 'consumed_train_samples'):
margs.consumed_train_samples = md.consumed_train_samples
@@ -187,7 +188,7 @@ def check_message(msg):
raise Exception(f'unrecognized model type: {args.model_type}')

def get_models(count, dtype, pre_process, post_process):
models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]
models = [model_provider(pre_process, post_process, ckpt_transfer_model=True).to(dtype) for _ in range(count)]
return models

# fake initializing distributed
@@ -262,9 +263,11 @@ def get_models(count, dtype, pre_process, post_process):

# duplicated tensors
input_layernorm_weight = msg.pop("input layernorm weight")
input_layernorm_bias = msg.pop("input layernorm bias")
post_layernorm_weight = msg.pop("post layernorm weight")
post_layernorm_bias = msg.pop("post layernorm bias")
if margs.normalization != 'rmsnorm':
post_layernorm_bias = msg.pop("post layernorm bias")
input_layernorm_bias = msg.pop("input layernorm bias")

if md.linear_bias:
dense_bias = msg.pop("dense bias")
mlp_l1_bias = msg.pop("mlp l1 bias")
@@ -295,11 +298,12 @@ def get_models(count, dtype, pre_process, post_process):
for tp_rank in range(args.target_tensor_parallel_size):
l = models[tp_rank].language_model.encoder.layers[layer]
l.input_layernorm.weight.data.copy_(input_layernorm_weight)
l.input_layernorm.bias.data.copy_(input_layernorm_bias)
if margs.normalization != 'rmsnorm':
l.input_layernorm.bias.data.copy_(input_layernorm_bias)
l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank])
l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank])
if md.linear_bias:
@@ -315,15 +319,18 @@ def get_models(count, dtype, pre_process, post_process):
if post_process:
msg = queue_get("final layernorm")
final_layernorm_weight = msg.pop("weight")
final_layernorm_bias = msg.pop("bias")
if margs.normalization != 'rmsnorm':
final_layernorm_bias = msg.pop("bias")
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight)
models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
if margs.normalization != 'rmsnorm':
models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
if pp_rank != 0 and not md.output_layer:
# Copy word embeddings to final pipeline rank
models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
del final_layernorm_weight
del final_layernorm_bias
if margs.normalization != 'rmsnorm':
del final_layernorm_bias
check_message(msg)

if md.output_layer:
@@ -361,12 +368,14 @@ def get_models(count, dtype, pre_process, post_process):
lm_head_dense_weight = msg.pop("dense weight")
lm_head_dense_bias = msg.pop("dense bias")
lm_head_layernorm_weight = msg.pop("layernorm weight")
lm_head_layernorm_bias = msg.pop("layernorm bias")
if margs.normalization != 'rmsnorm':
lm_head_layernorm_bias = msg.pop("layernorm bias")
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight)
models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias)
models[tp_rank].lm_head.layernorm.weight.data.copy_(lm_head_layernorm_weight)
models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias)
if margs.normalization != 'rmsnorm':
models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias)
check_message(msg)
msg = queue_get()

7 changes: 6 additions & 1 deletion tools/checkpoint_util.py
@@ -124,16 +124,21 @@ def main():
parser.add_argument('--no-checking', action='store_false',
help='Do not perform checking on the name and ordering of weights',
dest='checking')
parser.add_argument('--tokenizer-model', type=str, default=None,
help='tokenizer-model, should be on python path')


known_args, _ = parser.parse_known_args()

loader = load_plugin('loader', known_args.loader)
saver = load_plugin('saver', known_args.saver)

loader.add_arguments(parser)
saver.add_arguments(parser)

args = parser.parse_args()

if args.tokenizer_model is None:
args.tokenizer_model = args.load_dir+"/tokenizer.model"
queue = mp.Queue(maxsize=args.max_queue_size)

print("Starting saver...")