# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

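# Converts a PyTorch GPT-2 style checkpoint (Hugging Face key layout, e.g.
# "transformer.h.0.attn.c_attn.weight") into a PaddlePaddle state dict that
# follows PaddleNLP's GPT parameter naming.
#
# Illustrative invocation (script and checkpoint file names are placeholders):
#     python convert_gpt_torch_to_paddle.py pytorch_model.bin
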
import sys

import numpy as np
import paddle
import torch

# Keep the conversion on CPU; no GPU is needed to reshuffle weights.
paddle.set_device("cpu")

# sys.argv[1]: path to the torch checkpoint (a state dict saved with torch.save).
model = torch.load(sys.argv[1], map_location="cpu")

print("The original model keys:")
for x in sorted(model.keys()):
    print(x)

state = {}
for sub_name, sub_param in model.items():
    if sub_name.startswith("transformer."):
        # Drop the leading "transformer." prefix.
        sub_name = sub_name[len("transformer."):]
    if sub_name.startswith("h."):
        # Decoder blocks: "h.<idx>." -> "gpt.decoder.layers.<idx>."
        final_name = sub_name.replace("h.", "gpt.decoder.layers.", 1)
    else:
        final_name = sub_name
    state[final_name] = sub_param.numpy()


def trans_name(key):
    """Map remaining Hugging Face GPT-2 parameter names to PaddleNLP GPT names."""
    k = key
    # Feed-forward (MLP) projections.
    k = k.replace("mlp.c_fc", "linear1")
    k = k.replace("mlp.c_proj", "linear2")
    # Attention output projection.
    k = k.replace("attn.c_proj", "self_attn.out_proj")
    # Per-block layer norms and the final layer norm.
    k = k.replace("ln_1", "norm1")
    k = k.replace("ln_2", "norm2")
    k = k.replace("ln_f", "gpt.decoder.norm")
    # Token and position embeddings.
    k = k.replace("wte", "gpt.embeddings.word_embeddings")
    k = k.replace("wpe", "gpt.embeddings.position_embeddings")
    return k

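# For example, "transformer.h.0.mlp.c_fc.weight" becomes
# "gpt.decoder.layers.0.linear1.weight" after the prefix handling above plus
# trans_name, and "transformer.ln_f.bias" becomes "gpt.decoder.norm.bias".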

new_state_dict = {}
all_num = 0
for key in sorted(state.keys()):
    all_num += state[key].size
    new_key = trans_name(key)
    if "attn.c_attn" in key:
        # Split the fused QKV projection into separate q/k/v parameters.
        # Note: Hugging Face GPT-2 stores Conv1D weights with shape
        # (in_features, out_features), which already matches paddle.nn.Linear,
        # so no transpose is needed.
        if "weight" in key:
            # (hidden_size, 3 * hidden_size) -> three (hidden_size, hidden_size)
            q, k, v = np.split(state[key], 3, axis=1)
        else:
            # (3 * hidden_size,) -> three (hidden_size,) vectors; flatten in
            # case the bias carries an extra leading dimension.
            q, k, v = np.split(state[key], 3, axis=-1)
            q = q.reshape(-1)
            k = k.reshape(-1)
            v = v.reshape(-1)
        q_name = new_key.replace("attn.c_attn", "self_attn.q_proj")
        k_name = new_key.replace("attn.c_attn", "self_attn.k_proj")
        v_name = new_key.replace("attn.c_attn", "self_attn.v_proj")
        new_state_dict[q_name] = paddle.to_tensor(q, dtype="float32")
        new_state_dict[k_name] = paddle.to_tensor(k, dtype="float32")
        new_state_dict[v_name] = paddle.to_tensor(v, dtype="float32")
        continue
    new_state_dict[new_key] = paddle.to_tensor(state[key], dtype="float32")

print("all shape numel: {}".format(all_num))
for key, value in new_state_dict.items():
    print("key: {}, shape: {}, dtype: {}".format(key, value.shape, value.dtype))

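# Illustrative sanity check: the converted state dict should hold exactly as
# many elements as the original tensors, since the QKV split is size-preserving.
converted_num = sum(int(np.prod(v.shape)) for v in new_state_dict.values())
assert converted_num == all_num, "element count changed during conversion"
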
origin_path = sys.argv[1]
if ".bin" in origin_path:
    save_path = origin_path.replace(".bin", ".pdparams")
else:
    save_path = origin_path + ".pdparams"
paddle.save(new_state_dict, save_path)
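
# A minimal sketch of how the converted weights could be loaded back for
# verification. Assumptions: paddlenlp is installed, "gpt2-en" is a valid
# PaddleNLP pretrained name whose config matches this checkpoint, and the
# "gpt." key prefix corresponds to a wrapper such as GPTForPretraining,
# whose GPT backbone attribute is named "gpt".
#
#     from paddlenlp.transformers import GPTForPretraining
#     model = GPTForPretraining.from_pretrained("gpt2-en")
#     model.set_state_dict(paddle.load(save_path))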