 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
+import os
+
 import paddle
-from paddle.distributed import fleet

-from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer
+from paddlenlp.transformers import AutoConfig
+from paddlenlp.transformers.auto.modeling import MAPPING_NAMES
+from paddlenlp.utils.log import logger


 def parse_arguments():
     import argparse

     parser = argparse.ArgumentParser()
     parser.add_argument("--model_name_or_path", default=None, required=True, help="The directory of model.")
-    parser.add_argument("--merge_model_path", default=None, required=True, help="The directory of merged model.")
     parser.add_argument("--device", type=str, default="gpu", help="Device")
-    parser.add_argument("--dtype", type=str, default=None, required=True, help="Model dtype")
-    parser.add_argument("--with_tokenizer", type=bool, default=True, help="Save tokenizer at the same time")
     return parser.parse_args()


-def merge():
+def load_tp_params(tp_degree, path):
+    tp_state_dict_list = []
+    for tp in range(tp_degree):
+        tp_state_dict = {}
+        tmp = paddle.load(os.path.join(path, f"model_state.tp{tp:0>2d}.pdparams"), return_numpy=True)
+        for k, v in tmp.items():
+            tp_state_dict[k] = v
+        tp_state_dict_list.append(tp_state_dict)
+
+    return tp_state_dict_list
+
+
+def merge_tensor_parallel(model_class, state_dict_list, config) -> dict:
+    """Merge per-rank tensor-parallel state dicts into a single state dict.
+
+    Args:
+        model_class: the model class whose tensor-parallel mappings describe how to merge each weight
+        state_dict_list (list[dict]): one state dict per tensor-parallel rank
+        config (PretrainedConfig): the PretrainedConfig instance of the model
+
+    Returns:
+        dict: the merged state dict
+    """
+    name_action_mappings = model_class._get_tensor_parallel_mappings(config, is_split=False)
+    state_keys_map = model_class._resolve_prefix_keys(name_action_mappings.keys(), state_dict_list[0].keys())
+
+    for k, v in state_keys_map.items():
+        name_action_mappings[v] = name_action_mappings.pop(k)
+
+    state_dict_to_save = {}
+    for key in state_dict_list[0].keys():
+        tensor = state_dict_list[0][key]
+        if key in name_action_mappings:
+            # Gather this weight's shard from every rank and fuse them with the mapped action.
+            ret = [x[key] for x in state_dict_list]
+            action = name_action_mappings.pop(key)
+            tensor = action(ret)
+
+        state_dict_to_save[key] = tensor
+
+    if len(name_action_mappings) > 0:
+        for x in name_action_mappings.keys():
+            logger.warning(f"key <{x}> is marked for tensor-parallel merging but was not found in the model state.")
+
+    logger.info("The merged state dict contains the following tensors.")
+    for k, v in state_dict_to_save.items():
+        logger.info(f"{k}, {v.shape}, {v.dtype}")
+
+    return state_dict_to_save
+
+
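For intuition about the merge actions: _get_tensor_parallel_mappings(config, is_split=False) returns, per weight name, a callable that takes the list of per-rank shards and fuses them back into one tensor (that is what action(ret) does above). A minimal sketch of what such an action might look like for a column-parallel weight, assuming plain numpy arrays; this is an illustration, not PaddleNLP's actual implementation:

    import numpy as np

    def merge_column_parallel(shards):
        # Each rank holds a slice along the output dimension; merging is concatenation.
        return np.concatenate(shards, axis=-1)

    # Two ranks, each holding an (8, 4) shard of an (8, 8) weight.
    merged = merge_column_parallel([np.zeros((8, 4)), np.zeros((8, 4))])
    assert merged.shape == (8, 8)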
+def main():
     args = parse_arguments()
     paddle.set_device(args.device)
-    tensor_parallel_degree = paddle.distributed.get_world_size()
-    tensor_parallel_rank = 0
-    if tensor_parallel_degree > 1:
-        strategy = fleet.DistributedStrategy()
-        strategy.hybrid_configs = {
-            "dp_degree": 1,
-            "mp_degree": tensor_parallel_degree,
-            "pp_degree": 1,
-            "sharding_degree": 1,
-        }
-        fleet.init(is_collective=True, strategy=strategy)
-        hcg = fleet.get_hybrid_communicate_group()
-        tensor_parallel_rank = hcg.get_model_parallel_rank()
-
-    model = AutoModelForCausalLM.from_pretrained(
-        args.model_name_or_path,
-        dtype=args.dtype,
-        tensor_parallel_degree=tensor_parallel_degree,
-        tensor_parallel_rank=tensor_parallel_rank,
-    )
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
-    if tensor_parallel_rank == 0:
-        model.save_pretrained(args.merge_model_path, merge_tensor_parallel=tensor_parallel_degree > 1)
-        tokenizer.save_pretrained(args.merge_model_path)
+    config = AutoConfig.from_pretrained(args.model_name_or_path)
+    init_class = config["architectures"][0]
+    import_class = importlib.import_module(f"paddlenlp.transformers.{MAPPING_NAMES[init_class[:-11]]}.modeling")
+    model_class = getattr(import_class, init_class)
+
+    if config.tensor_parallel_degree > 1:
+        tp_state_dict_list = load_tp_params(config.tensor_parallel_degree, args.model_name_or_path)
+        state_dict_to_save = merge_tensor_parallel(
+            model_class=model_class, state_dict_list=tp_state_dict_list, config=config
+        )
+
+        logger.info("Saving")
+        paddle.save(state_dict_to_save, os.path.join(args.model_name_or_path, "model_state.pdparams"))
+    else:
+        logger.info("No need to merge since config.tensor_parallel_degree <= 1.")


 if __name__ == "__main__":
-    merge()
+    main()
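With these changes the script merges the checkpoint offline from the sharded files alone (no distributed launch or model instantiation needed) and writes the merged model_state.pdparams back into the same directory. Assuming the file is saved as merge_tp_params.py (the filename is not shown in this diff), a run looks like:

    python merge_tp_params.py --model_name_or_path ./checkpoints/my_tp_run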