|
| 1 | +import paddle |
| 2 | +import torch |
| 3 | + |
| 4 | + |
| 5 | +def save_checkpoint(checkpoint_path, model): |
| 6 | + """Save model and optimizer to checkpoint""" |
| 7 | + |
| 8 | + paddle.save( |
| 9 | + {"model_state": model}, |
| 10 | + checkpoint_path, |
| 11 | + ) |
| 12 | + |
| 13 | + |
| 14 | +def torch2paddle(): |
| 15 | + torch_path = "/home/aistudio/data_efficient_nopt/data/possion_64_inference/finetune_b01_m0_n8192.tar" |
| 16 | + paddle_path = "./data/pd_finetune_b01_m0_n8192.tar" |
| 17 | + |
| 18 | + torch_state_dict = torch.load(torch_path)["model_state"] |
| 19 | + # model.set_state_dict(checkpoint["model_state"]) |
| 20 | + fc_names = ["classifier", "fc"] |
| 21 | + paddle_state_dict = {} |
| 22 | + import pdb |
| 23 | + |
| 24 | + pdb.set_trace() |
| 25 | + for k in torch_state_dict: |
| 26 | + if "num_batches_tracked" in k: # 飞桨中无此参数,无需保存 |
| 27 | + continue |
| 28 | + v = torch_state_dict[k].detach().cpu().numpy() |
| 29 | + flag = [i in k for i in fc_names] |
| 30 | + if any(flag) and "weight" in k: |
| 31 | + new_shape = [1, 0] + list(range(2, v.ndim)) |
| 32 | + print( |
| 33 | + f"name: {k}, ori shape: {v.shape}, new shape: {v.transpose(new_shape).shape}" |
| 34 | + ) |
| 35 | + v = v.transpose(new_shape) # 转置 Linear 层的 weight 参数 |
| 36 | + # 将 torch.nn.BatchNorm2d 的参数名称改成 paddle.nn.BatchNorm2D 对应的参数名称 |
| 37 | + k = k.replace("running_var", "_variance") |
| 38 | + k = k.replace("running_mean", "_mean") |
| 39 | + k = k.replace("module.", "") |
| 40 | + # 添加到飞桨权重字典中 |
| 41 | + # k = k |
| 42 | + print(f"k: {k}") |
| 43 | + paddle_state_dict[k] = v |
| 44 | + print(f"paddle_state_dict: {paddle_state_dict.keys()}") |
| 45 | + save_checkpoint(paddle_path, paddle_state_dict) |
| 46 | + |
| 47 | + |
| 48 | +if __name__ == "__main__": |
| 49 | + torch2paddle() |
0 commit comments