Skip to content

Some save and load problems when incorporated with BMCook #57

@isuco

Description

@isuco

Hi, I found there might be some problems between bmt.save() and bmt.load().
Following the BMCook examples, I loaded a gpt2-base model and trained it for several epochs. Note that all BMCook operations had been disabled in --cook-config. After training, I called bmt.save() to save the checkpoint. However, this checkpoint seems to be mismatched with the parameters of a freshly initialized model:

Traceback (most recent call last):
  File "eval.py", line 207, in <module>
 main()
  File "eval.py", line 202, in main
    bmt.load(model, args.load_path)
  File "/opt/conda/lib/python3.8/site-packages/bmtrain/store.py", line 202, in load
ret = model.load_state_dict()
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format()
RuntimeError: Error(s) in loading state_dict for GPT2:
        While copying the parameter named "encoder.layers.0.self_att.self_attention.project_q.weight", whose dimensions in the model are torch.Size([768, 768]) and whose dimensions in the checkpoint are torch.Size([768, 768]), an exception occurred : ('The size of tensor a (768) must match the size of tensor b (589824) at non-singleton dimension 1',).
        While copying the parameter named "encoder.layers.0.self_att.self_attention.project_k.weight", whose dimensions in the model are torch.Size([768, 768]) and whose dimensions in the checkpoint are torch.Size([768, 768]), an exception occurred : ('The size of tensor a (768) must match the size of tensor b (589824) at non-singleton dimension 1',).
        While copying the parameter named "encoder.layers.0.self_att.self_attention.project_v.weight", whose dimensions in the model are torch.Size([768, 768]) and whose dimensions in the checkpoint are torch.Size([768, 768]), an exception occurred : ('The size of tensor a (768) must match the size of tensor b (589824) at non-singleton dimension 1',).
        While copying the parameter named "encoder.layers.0.self_att.self_attention.attention_out.weight", whose dimensions in the model are torch.Size([768, 768]) and whose dimensions in the checkpoint are torch.Size([768, 768]), an exception occurred : ('The size of tensor a (768) must match the size of tensor b (589824) at non-singleton dimension 1',).
        While copying the parameter named "encoder.layers.0.ffn.ffn.w_in.w.weight", whose dimensions in the model are torch.Size([3072, 768]) and whose dimensions in the checkpoint are torch.Size([3072, 768]), an exception occurred : ('The size of tensor a (768) must match the size of tensor b (2359296) at non-singleton dimension 1',).
...
        While copying the parameter named "encoder.layers.11.ffn.ffn.w_out.weight", whose dimensions in the model are torch.Size([768, 3072]) and whose dimensions in the checkpoint are torch.Size([768, 3072]), an exception occurred : ('The size of tensor a (3072) must match the size of tensor b (2359296) at non-singleton dimension 1',).
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 253443) of binary: /opt/conda/bin/python
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launch.py", line 193, in <module>
    main()
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launch.py", line 189, in main
    launch(args)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launch.py", line 174, in launch
    run(args)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/run.py", line 752, in run
    elastic_launch()
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 245, in launch_agent
    raise ChildFailedError()
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
eval.py FAILED

It seems that bmt.load() could not align the saved parameters with the flattened parameters of a newly initialized model. I'm not sure whether this is caused by BMCook or BMTrain. All hyper-parameters have been aligned, including the BMCook preprocessing. This is the main part of my code, which is the same as the given BMCook example except for the final bmt.load():

# Initialise BMTrain's distributed state, parse the CLI arguments and make
# sure the output directories exist before anything is written to them.
bmt.init_distributed()
args = parse_args()
save_dir = Path(args.save_dir)
ckpt_dir = save_dir / 'checkpoints'
os.makedirs(ckpt_dir, exist_ok=True)
# Persist the run arguments next to the checkpoints. The context manager
# closes (and therefore flushes) the file deterministically instead of
# leaving the open handle to the garbage collector.
with open(save_dir / 'train_args.json', 'w') as args_file:
    json.dump(vars(args), args_file, indent=2)

# Student model and its config are both looked up by the same model name.
model_config = config_map[args.model].from_pretrained(args.model)
model = model_map[args.model].from_pretrained(args.model, config=model_config)
# teacher model has the same config as the student model
teacher = model_map[args.model].from_pretrained(args.model, config=model_config)

def new_forward(model_self, enc_input, enc_length, dec_input, dec_length, return_logits=False):
    """Adapt an encoder/decoder-style call to the model's saved ``forward_old``.

    ``enc_input`` and ``enc_length`` are accepted only so the signature matches
    the caller; they are not used. The decoder input and length are passed
    positionally to ``model_self.forward_old`` and ``return_logits`` is
    forwarded as its ``output_logits`` keyword.

    Returns:
        Whatever ``model_self.forward_old`` returns.
    """
    decoder_args = (dec_input, dec_length)
    return model_self.forward_old(*decoder_args, output_logits=return_logits)

# Keep the original forward of each model reachable as ``forward_old`` and
# bind ``new_forward`` in its place, so both student and teacher accept the
# (enc_input, enc_length, dec_input, dec_length, ...) call signature.
model.forward_old = model.forward
model.forward = types.MethodType(new_forward, model)
teacher.forward_old = teacher.forward
teacher.forward = types.MethodType(new_forward, teacher)

bmt.synchronize()

# data
batch_size = 8
dec_len = 512

# Targets equal to -100 are ignored by the loss.
loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100)
# Only the student's parameters are handed to the optimizer.
optimizer = bmt.optim.AdamOptimizer(model.parameters(), scale=2**20)
lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=args.start_lr, warmup_iter=2000, end_iter=100000)

# bmcook config
from bmcook.utils.config import ConfigParser
config = ConfigParser(args.cook_config)
# remove checkpointing
# remove checkpointing
# For every bmt.TransformerBlockList in the model: replace its forward with a
# plain sequential loop and swap each wrapped child for its inner ``_module``.
# NOTE(review): after this loop the children of each block list are no longer
# the wrapper objects created by from_pretrained -- confirm that bmt.save()
# and bmt.load() are applied to models that went through the same unwrapping.
for _, v in model.named_modules():

    if isinstance(v, bmt.TransformerBlockList):

        def new_func(list_self, hidden_states, *args):
            # Call the children in index order ("0", "1", ...), feeding each
            # one the previous output plus the unchanged extra arguments.
            for i in range(len(list_self._modules)):
                hidden_states = list_self._modules[str(i)](hidden_states, *args)
            return hidden_states

        v.forward = types.MethodType(new_func, v)

        for k in v._modules.keys():
            # Snapshot the wrapper's state dict, then copy every matching
            # ``.weight`` / ``.bias`` entry onto the inner module's submodules
            # as a fresh tensor moved to the GPU with ``.cuda()``.
            state_dict = v._modules[k].state_dict()
            for kk, vv in v._modules[k]._module.named_modules():
                if kk+'.weight' in state_dict:
                    vv.weight.data = state_dict[kk+'.weight'].clone().cuda()
                if kk+'.bias' in state_dict:
                    vv.bias.data = state_dict[kk+'.bias'].clone().cuda()
            # Replace the wrapper with the module it wraps (keys are unchanged,
            # only the values of ``_modules`` are reassigned while iterating).
            v._modules[k] = v._modules[k]._module

# Each BMCook component below is driven by the same ``config`` object; the
# issue text states that all of them are disabled in --cook-config.

# for distillation
# Trainer.forward is replaced by the callable returned from set_forward.
Trainer.forward = BMDistill.set_forward(model, teacher, Trainer.forward, config)

# for pruning
BMPrune.compute_mask(model, config)
BMPrune.set_optim_for_pruning(optimizer)

# for quantization
BMQuant.quantize(model, config)

# for moefication
# Wraps Trainer.forward a second time, on top of the distillation wrapper.
Trainer.forward = BMMoE.get_hidden(model, config, Trainer.forward)

bmt.synchronize()

average_time = 0
average_time_shift = 0.9

# Dataset built from the memory-mapped file at args.data_path and dec_len.
dataset = Dataset(
    MMapIndexedDataset(args.data_path),
    dec_len
)

# MoEfication branch: run the model in eval mode over the first 100 batches,
# save the last element of each forward output per rank, then exit the script.
if config.get('MoEfication')['is_moefy']:
    os.makedirs(save_dir / 'hiddens', exist_ok=True)
    model.eval()

    for iteration, data in enumerate(Trainer.batch_iter(dataset, batch_size, bmt.rank(), bmt.world_size())):

        # Stop after iterations 0..99 have been processed.
        if iteration == 100:
            break

        dec_input = data["ctx"].int()
        dec_length = data["len_ctx"].int()
        # True for positions below each sample's length; built as a
        # (batch_size, dec_len) grid compared against dec_length per row.
        dec_mask = torch.arange(dec_len)[None, :].repeat(batch_size, 1) < dec_length[:, None]
        # Positions outside the mask get -100, the value used as ignore_index
        # when the loss function is constructed.
        targets = torch.where(dec_mask, data["target"].long(), torch.scalar_tensor(-100, dtype=torch.long))

        targets = targets.cuda()
        dec_input = dec_input.cuda()
        dec_length = dec_length.cuda()
        
        # No gradients are needed while collecting outputs.
        with torch.no_grad():
            outputs = Trainer.forward(model, None, None, dec_input, dec_length, targets, loss_func)
        
        # One file per (iteration, rank) pair: hiddens/<iteration>_<rank>.
        torch.save(outputs[-1], save_dir / 'hiddens' / '{}_{}'.format(iteration, bmt.rank()))
           
        bmt.print_rank("Iteration:", iteration)
    exit()

# Distillation counts as enabled unless its three loss scales sum to zero.
distill_config = config.get('distillation')
distill_scale_sum = (
    distill_config['ce_scale']
    + distill_config['mse_hidn_scale']
    + distill_config['mse_att_scale']
)
do_distill = not (distill_scale_sum == 0)

# Load the saved checkpoint into the student, then switch both networks to
# eval mode (the training-mode call is intentionally left disabled).
bmt.load(model, args.load_path)
# model.train()
teacher.eval()
model.eval()

Metadata

Metadata

Assignees

Labels

questionFurther information is requested

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions