|
121 | 121 | from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler |
122 | 122 | from ..utils.env import ( |
123 | 123 | DISLORA_WEIGHTS_NAME, |
| 124 | + EMA_STATE_DIC, |
124 | 125 | LOKR_WEIGHTS_NAME, |
125 | 126 | LORA_WEIGHTS_NAME, |
| 127 | + MASTER_WEIGHT_DIC, |
126 | 128 | MODEL_META_NAME, |
| 129 | + MODEL_STATE_DIC, |
| 130 | + OPTIMIZER_STATE_DIC, |
127 | 131 | PADDLE_MASTER_WEIGHTS_INDEX_NAME, |
128 | 132 | PADDLE_OPTIMIZER_NAME, |
129 | 133 | PADDLE_PEFT_WEIGHTS_INDEX_NAME, |
|
185 | 189 | from .unified_checkpoint import UnifiedCheckpointHandler |
186 | 190 | from .utils import reshard as reshard_util |
187 | 191 | from .utils.async_save import AsyncSaver |
| 192 | +from .utils.reshard import SHARDING_STRATEGY_V1, split_opt_state |
| 193 | +from .utils.sharding_io import GroupGetter, to_device |
188 | 194 |
|
189 | 195 | try: |
190 | 196 | from .utils.zero_cost_checkpoint import ( |
@@ -673,6 +679,248 @@ def _load_from_peft_checkpoint(self, resume_from_checkpoint=None): |
673 | 679 | elif resume_from_checkpoint is not None: |
674 | 680 | logger.info(f"not loading ckpt :{self.args.dataset_rank}") |
675 | 681 |
|
| 682 | + def _load_flex_checkpoint(self, resume_from_checkpoint): |
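| | + # Flex-checkpoint resume: model weights, optimizer state and master weights
| | + # are loaded from their own sub-directories; when load_from_hf is set, model
| | + # weights are instead read directly from safetensors files.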
| 683 | + def get_metadata_file_name(path): |
| 684 | + files = os.listdir(path) |
| 685 | + metadata_files = [f for f in files if f.endswith(".metadata")] |
| 686 | + assert len(metadata_files) > 0, f"Found no metadata files in {path}" |
| 687 | + assert len(metadata_files) == 1, f"Found multiple metadata files in {path}" |
| 688 | + return metadata_files[0] |
| 689 | + |
| 690 | + model_sharded_state_dict = self.model.sharded_state_dict() |
| 691 | + hf_aoa_config = self.model._gen_aoa_config(self.model.config) |
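| | + # Sub-directories of the flex checkpoint, one per state type; each holds the
| | + # distributed shards plus exactly one .metadata file (see get_metadata_file_name).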
| 692 | + master_weights_path = os.path.join(resume_from_checkpoint, MASTER_WEIGHT_DIC) |
| 693 | + opt_states_path = os.path.join(resume_from_checkpoint, OPTIMIZER_STATE_DIC) |
| 694 | + model_states_path = os.path.join(resume_from_checkpoint, MODEL_STATE_DIC) |
| 695 | + |
| 696 | + if self.args.load_from_hf: |
| 697 | + hcg = dist.fleet.get_hybrid_communicate_group() |
| 698 | + assert ( |
| 699 | + self.args.ignore_load_lr_and_optim |
| 700 | + ), "Loading from HuggingFace format is only allowed when learning rate and optimizer state are ignored." |
| 701 | + try: |
| 702 | + moe_sharding_group = hcg.get_moe_sharding_parallel_group() |
| 703 | + except Exception: |
| 704 | + moe_sharding_group = None |
| 705 | + |
| 706 | + if moe_sharding_group is None or moe_sharding_group.nranks <= 1: |
| 707 | + # When moe_sharding_group is absent or has a single rank, load with the default (global) process group.
| 708 | + logger.info(f"Loading model weights from '{resume_from_checkpoint}' in safetensors format.") |
| 709 | + dist.load_state_dict( |
| 710 | + model_sharded_state_dict, |
| 711 | + resume_from_checkpoint, |
| 712 | + aoa_config=hf_aoa_config, |
| 713 | + offload=self.args.load_via_cpu, |
| 714 | + safetensors=True, |
| 715 | + process_group=None, |
| 716 | + comm_method=self.args.comm_method, |
| 717 | + ) |
| 718 | + else: |
| 719 | + try: |
| 720 | + pp_group = hcg.get_pipe_parallel_group() |
| 721 | + if pp_group is None or pp_group.nranks < 1: |
| 722 | + raise NotImplementedError("Only supported when pp_group is not None.")
| 723 | + except Exception: |
| 724 | + raise RuntimeError("Only supported when pp_group is not None.")
| 725 | + |
| 726 | + try: |
| 727 | + moe_group = hcg.get_expert_parallel_group() |
| 728 | + if moe_group is None or moe_group.nranks < 1: |
| 729 | + raise NotImplementedError("Only supported when moe_group is not None.")
| 730 | + except Exception: |
| 731 | + raise RuntimeError("Only supported when moe_group is not None.")
| 732 | + moe_sharding_rank = moe_sharding_group.rank |
| 733 | + cur_rank = dist.get_rank() |
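| | + # Collect the global ranks of all processes with moe_sharding_rank == 0 (across
| | + # the pp and expert-parallel groups); only this temporary group loads the weights.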
| 734 | + if moe_sharding_rank == 0: |
| 735 | + moe_group_ranks = [] |
| 736 | + dist.all_gather_object(moe_group_ranks, cur_rank, group=moe_group) |
| 737 | + pp_group_ranks = [] |
| 738 | + dist.all_gather_object(pp_group_ranks, moe_group_ranks, group=pp_group) |
| 739 | + process_group_ranks = [rank for ranks in pp_group_ranks for rank in ranks] |
| 740 | + else: |
| 741 | + process_group_ranks = [0] * (pp_group.nranks * moe_group.nranks) |
| 742 | + src_rank = hcg.get_moe_sharding_parallel_group_src_rank() |
| 743 | + dist.broadcast_object_list(process_group_ranks, src=src_rank, group=moe_sharding_group) |
| 744 | + assert any(process_group_ranks), "process_group_ranks should not be all 0" |
| 745 | + logger.info(f"Creating a temporary process group with ranks: {process_group_ranks}") |
| 746 | + process_group = dist.new_group(process_group_ranks) |
| 747 | + |
| 748 | + if moe_sharding_rank == 0: |
| 749 | + logger.info(f"Loading model weights from '{resume_from_checkpoint}' in safetensors format.") |
| 750 | + # Only the first rank of each moe_sharding group (moe_sharding_rank == 0) loads the model weights.
| 751 | + dist.load_state_dict( |
| 752 | + model_sharded_state_dict, |
| 753 | + resume_from_checkpoint, |
| 754 | + aoa_config=hf_aoa_config, |
| 755 | + offload=self.args.load_via_cpu, |
| 756 | + safetensors=True, |
| 757 | + process_group=process_group, |
| 758 | + comm_method=self.args.comm_method, |
| 759 | + ) |
| 760 | + |
| 761 | + dist.barrier() |
| 762 | + logger.info("Destroying the temporary process group.") |
| 763 | + dist.destroy_process_group(process_group) |
| 764 | + # The first rank of each moe_sharding group has loaded the model weights; broadcast them to the remaining ranks in that group.
| 765 | + logger.info( |
| 766 | + "First shard (moe_sharding_group) has loaded safetensors weights, starting broadcast on moe_sharding_groups." |
| 767 | + ) |
| 768 | + for param_name, param in self.model.state_dict().items(): |
| 769 | + dist.broadcast(param, src=src_rank, group=moe_sharding_group) |
| 770 | + logger.info("Safetensors format weights have been loaded successfully.") |
| 771 | + return |
| 772 | + |
| 773 | + if not self.args.ignore_load_lr_and_optim: |
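| | + # Merge the .metadata files of the three sub-checkpoints so the optimizer
| | + # state can be materialized via init_optimizer before it is loaded.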
| 774 | + state_dict_metadata = {} |
| 775 | + metadata_paths = [ |
| 776 | + os.path.join(model_states_path, get_metadata_file_name(model_states_path)), |
| 777 | + os.path.join(opt_states_path, get_metadata_file_name(opt_states_path)), |
| 778 | + os.path.join(master_weights_path, get_metadata_file_name(master_weights_path)), |
| 779 | + ] |
| 780 | + |
| 781 | + for metadata_file in metadata_paths: |
| 782 | + if not os.path.exists(metadata_file): |
| 783 | + raise FileNotFoundError(f"Metadata file not found: {metadata_file}") |
| 784 | + metadata = paddle.load(metadata_file) |
| 785 | + state_dict_metadata.update(metadata.state_dict_metadata) |
| 786 | + |
| 787 | + init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata) |
| 788 | + |
| 789 | + optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) |
| 790 | + |
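| | + # Keys ending in ".w_0" are treated as master weights; they are loaded from a
| | + # different sub-directory than the remaining optimizer states.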
| 791 | + opt_states = {} |
| 792 | + master_weights = {} |
| 793 | + for k, v in optimizer_sharded_state_dict.items(): |
| 794 | + if k.endswith(".w_0"): |
| 795 | + master_weights[k] = v |
| 796 | + else: |
| 797 | + opt_states[k] = v |
| 798 | + |
| 799 | + dist.load_state_dict( |
| 800 | + opt_states, |
| 801 | + opt_states_path, |
| 802 | + aoa_config=self.args.aoa_config, |
| 803 | + offload=self.args.load_via_cpu, |
| 804 | + comm_method=self.args.comm_method, |
| 805 | + ) |
| 806 | + |
| 807 | + if not self.args.sharded_model_from_ema: |
| 808 | + dist.load_state_dict( |
| 809 | + master_weights, |
| 810 | + master_weights_path, |
| 811 | + aoa_config=self.args.aoa_config, |
| 812 | + offload=self.args.load_via_cpu, |
| 813 | + ) |
| 814 | + |
| 815 | + self._load_scheduler(resume_from_checkpoint) |
| 816 | + |
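| | + # When resuming a sharding-stage1 run from EMA, the per-rank EMA file supplies
| | + # both the master weights (copied into the optimizer) and the model weights
| | + # (all-gathered across the sharding group); otherwise load the model states directly.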
| 817 | + should_load_stage1 = self.args.should_load_sharding_stage1_model |
| 818 | + if should_load_stage1 and self.args.sharded_model_from_ema: |
| 819 | + ema_states_path = os.path.join(resume_from_checkpoint, EMA_STATE_DIC, f"{dist.get_rank()}_0.distcp") |
| 820 | + ema_state_dict = paddle.load(ema_states_path) |
| 821 | + ema_master_weights = ema_state_dict.pop("master_weights", None) |
| 822 | + opt_master_weights = self.optimizer.state_dict()["master_weights"] |
| 823 | + for k, v in opt_master_weights.items(): |
| 824 | + assert ( |
| 825 | + k in ema_master_weights |
| 826 | + ), f"{k} not in ema_master_weights, emas_master_weight keys {ema_master_weights.keys()}" |
| 827 | + paddle.assign(ema_master_weights[k], opt_master_weights[k]) |
| 828 | + |
| 829 | + ema_state_dict = reshard_util.all_gather_state_dict(ema_state_dict, lambda x: True, self.sharding_group) |
| 830 | + self.model.set_state_dict(ema_state_dict) |
| 831 | + else: |
| 832 | + dist.load_state_dict( |
| 833 | + model_sharded_state_dict, |
| 834 | + model_states_path, |
| 835 | + aoa_config=self.args.aoa_config, |
| 836 | + offload=self.args.load_via_cpu, |
| 837 | + ) |
| 838 | + |
| 839 | + if self.args.bf16 and (not self.args.ignore_load_lr_and_optim) and should_load_stage1: |
| 840 | + opt_state_dict = self.optimizer.state_dict() |
| 841 | + |
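| | + # With bf16 and sharding stage1, the bf16 parameters are rebuilt from the master
| | + # weights after resharding them to the current sharding layout.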
| 842 | + def recover_params_from_master_weight(opt_state_dict, group): |
| 843 | + master_weights = opt_state_dict["master_weights"] |
| 844 | + tmp = OrderedDict() |
| 845 | + (master_weights, tmp) = (tmp, master_weights) |
| 846 | + # tmp now holds the original master weights; rebuild them as bf16 copies on CPU (opt_state_dict itself is not modified)
| 847 | + for (k, v) in tmp.items(): |
| 848 | + name = v.name |
| 849 | + master_weights[k] = paddle.cast(to_device(v), paddle.bfloat16).cpu() |
| 850 | + master_weights[k].name = name |
| 851 | + |
| 852 | + structure_name_map = {k: v.name for (k, v) in self.model.state_dict().items()} |
| 853 | + node_model_state = reshard_util.NodeModelState(group=group) |
| 854 | + node_model_state_tmp = reshard_util.NodeModelState(group=group) |
| 855 | + node_model_state_tmp.add_master_weights(master_weights) |
| 856 | + node_model_state_tmp.pack_keys(structure_name_map) |
| 857 | + node_model_state.merge_from(node_model_state_tmp, max(group.rank, 0)) |
| 858 | + del node_model_state_tmp |
| 859 | + sharding_strategy = reshard_util.get_sharding_strategy(self.optimizer) |
| 860 | + logger.debug(f"sharding_strategy: {sharding_strategy}") |
| 861 | + restore_func = ( |
| 862 | + reshard_util.sharding_v1.restore |
| 863 | + if sharding_strategy == SHARDING_STRATEGY_V1 |
| 864 | + else reshard_util.sharding_v2.restore |
| 865 | + ) |
| 866 | + node_model_state = restore_func(node_model_state, self.model, self.optimizer) |
| 867 | + node_model_state.unpack_keys() |
| 868 | + master_weights = node_model_state.master_weights |
| 869 | + |
| 870 | + master_weights = reshard_util.all_gather_state_dict(master_weights, lambda x: True, group) |
| 871 | + |
| 872 | + model_state_dict = self.model.state_dict() |
| 873 | + for key, param in model_state_dict.items(): |
| 874 | + if param.name in master_weights: |
| 875 | + logger.debug( |
| 876 | + f"key {key}, convert master weights {param.name} shape {master_weights[param.name].shape} to param {param.name} shape{param.shape}" |
| 877 | + ) |
| 878 | + assert ( |
| 879 | + param.shape == master_weights[param.name].shape |
| 880 | + ), f"got {param.shape} vs {master_weights[param.name].shape}" |
| 881 | + master_weight = paddle.reshape(master_weights[param.name], param.shape) |
| 882 | + paddle.assign(paddle.cast(to_device(master_weight), paddle.bfloat16), model_state_dict[key]) |
| 883 | + |
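| | + # Split the optimizer state by communication group (GroupGetter) and run the
| | + # recovery within each group's communicator.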
| 884 | + group_getter = GroupGetter(self.model) |
| 885 | + opt_state_dict = split_opt_state(opt_state_dict, group_getter) |
| 886 | + for gid in group_getter.get_group_ids(): |
| 887 | + sub_opt_state_dict = opt_state_dict[gid] |
| 888 | + group = group_getter.get_group_by_id(gid) |
| 889 | + if self.args.bf16: |
| 890 | + recover_params_from_master_weight(sub_opt_state_dict, group) |
| 891 | + |
| 892 | + def _save_flex_model_state(self, output_dir): |
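| | + # Save the model's sharded state dict into the MODEL_STATE_DIC sub-directory.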
| 893 | + model_sharded_state_dict = self.model.sharded_state_dict() |
| 894 | + model_state_dict_path = os.path.join(output_dir, MODEL_STATE_DIC) |
| 895 | + os.makedirs(model_state_dict_path, exist_ok=True) |
| 896 | + dist.save_state_dict( |
| 897 | + model_sharded_state_dict, |
| 898 | + model_state_dict_path, |
| 899 | + ) |
| 900 | + |
| 901 | + def _save_flex_optimizer_state(self, output_dir): |
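| | + # Save optimizer states and master weights (keys ending in ".w_0") into separate
| | + # sub-directories, mirroring the layout read by _load_flex_checkpoint.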
| 902 | + optimizer_state_dict_path = os.path.join(output_dir, OPTIMIZER_STATE_DIC) |
| 903 | + optimizer_states = {} |
| 904 | + master_weights = {} |
| 905 | + model_sharded_state_dict = self.model.sharded_state_dict() |
| 906 | + optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) |
| 907 | + for k, v in optimizer_sharded_state_dict.items(): |
| 908 | + if k.endswith(".w_0"): |
| 909 | + master_weights[k] = v |
| 910 | + else: |
| 911 | + optimizer_states[k] = v |
| 912 | + |
| 913 | + dist.save_state_dict( |
| 914 | + optimizer_states, |
| 915 | + optimizer_state_dict_path, |
| 916 | + ) |
| 917 | + |
| 918 | + master_weights_path = os.path.join(output_dir, MASTER_WEIGHT_DIC) |
| 919 | + dist.save_state_dict( |
| 920 | + master_weights, |
| 921 | + master_weights_path, |
| 922 | + ) |
| 923 | + |
676 | 924 | def _load_from_checkpoint(self, resume_from_checkpoint=None): |
677 | 925 | """load state_dict from_checkpoint, Only load model state dict. |
678 | 926 |
|
@@ -1048,27 +1296,7 @@ def train( |
1048 | 1296 | if delay_optimizer_creation: |
1049 | 1297 | self.create_optimizer_and_scheduler(num_training_steps=max_steps) |
1050 | 1298 |
|
1051 | | - if resume_from_checkpoint is not None: |
1052 | | - if not self.args.ignore_load_lr_and_optim: |
1053 | | - model_sharded_state_dict = self.model.sharded_state_dict() |
1054 | | - accessible_files = os.listdir(resume_from_checkpoint) |
1055 | | - metadata_files = [file for file in accessible_files if file.endswith(".metadata")] |
1056 | | - assert len(metadata_files) == 1, "Only support one metadata file now." |
1057 | | - metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0])) |
1058 | | - state_dict_metadata = metadata.state_dict_metadata |
1059 | | - init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata) |
1060 | | - optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) |
1061 | | - sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict} |
1062 | | - dist.load_state_dict( |
1063 | | - sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config |
1064 | | - ) |
1065 | | - self._load_scheduler(resume_from_checkpoint) |
1066 | | - else: |
1067 | | - model_sharded_state_dict = self.model.sharded_state_dict() |
1068 | | - sharded_state_dict = model_sharded_state_dict |
1069 | | - dist.load_state_dict( |
1070 | | - sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config |
1071 | | - ) |
| 1299 | + self._load_flex_checkpoint(resume_from_checkpoint) |
1072 | 1300 | else: |
1073 | 1301 | model = self._wrap_model(self.model_wrapped) |
1074 | 1302 | # for the rest of this function `model` is the outside model, whether it was wrapped or not |
@@ -2867,8 +3095,7 @@ def _save_checkpoint(self, model, metrics=None): |
2867 | 3095 | self.save_model(output_dir) |
2868 | 3096 |
|
2869 | 3097 | if self.args.save_checkpoint_format == "flex_checkpoint": |
2870 | | - model_sharded_state_dict = self.model.sharded_state_dict() |
2871 | | - os.makedirs(output_dir, exist_ok=True) |
| 3098 | + self._save_flex_model_state(output_dir) |
2872 | 3099 |
|
2873 | 3100 | # Determine the new best metric / best model checkpoint |
2874 | 3101 | if metrics is not None and self.args.metric_for_best_model is not None: |
@@ -2932,11 +3159,7 @@ def _save_checkpoint(self, model, metrics=None): |
2932 | 3159 | ) |
2933 | 3160 | else: |
2934 | 3161 | if self.args.save_checkpoint_format == "flex_checkpoint": |
2935 | | - optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) |
2936 | | - dist.save_state_dict( |
2937 | | - {**model_sharded_state_dict, **optimizer_sharded_state_dict}, |
2938 | | - output_dir, |
2939 | | - ) |
| 3162 | + self._save_flex_optimizer_state(output_dir) |
2940 | 3163 | if self.args.should_save: |
2941 | 3164 | if self.tokenizer is not None and self.args.save_tokenizer: |
2942 | 3165 | self.tokenizer.save_pretrained(output_dir) |
@@ -2992,11 +3215,7 @@ def _save_checkpoint(self, model, metrics=None): |
2992 | 3215 | signal_dir, |
2993 | 3216 | ) |
2994 | 3217 | elif self.args.save_checkpoint_format == "flex_checkpoint": |
2995 | | - optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) |
2996 | | - dist.save_state_dict( |
2997 | | - {**model_sharded_state_dict, **optimizer_sharded_state_dict}, |
2998 | | - output_dir, |
2999 | | - ) |
| 3218 | + self._save_flex_optimizer_state(output_dir) |
3000 | 3219 | if self.args.should_save: |
3001 | 3220 | if self.tokenizer is not None and self.args.save_tokenizer: |
3002 | 3221 | self.tokenizer.save_pretrained(output_dir) |
@@ -3039,10 +3258,7 @@ def _save_checkpoint(self, model, metrics=None): |
3039 | 3258 | self._offload_optimizer() |
3040 | 3259 | else: |
3041 | 3260 | if self.args.save_checkpoint_format == "flex_checkpoint": |
3042 | | - dist.save_state_dict( |
3043 | | - model_sharded_state_dict, |
3044 | | - output_dir, |
3045 | | - ) |
| 3261 | + self._save_flex_model_state(output_dir) |
3046 | 3262 | if self.args.should_save: |
3047 | 3263 | if self.tokenizer is not None and self.args.save_tokenizer: |
3048 | 3264 | self.tokenizer.save_pretrained(output_dir) |
|