diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index a54fdb8d9c06..246c8a054890 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -34,6 +34,7 @@
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
+from copy import deepcopy
 import numpy as np
 import paddle
 import paddle.amp.auto_cast as autocast
@@ -54,6 +55,9 @@
 except:
     core = None
 from paddle.distributed import fleet
+from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
+    DygraphShardingOptimizerV2,
+)
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
     HybridParallelOptimizer,
 )
@@ -102,6 +106,8 @@
 except:
     pass
 
+from paddle.distributed.flex_checkpoint.dcp.sharded_weight import ShardedWeight
+
 from ..transformers.context_parallel_utils import split_inputs_sequence_dim_load_balance
 from ..transformers.model_utils import (
     PretrainedModel,
@@ -226,6 +232,11 @@ def in_auto_parallel_align_mode():
     return False
 
 
+MODEL_STATE_DIC = "model_state"
+OPTIMIZER_STATE_DIC = "optimizer_state"
+MASTER_WEIGHT_DIC = "master_weight"
+
+
 __all__ = ["Trainer"]
@@ -842,6 +853,140 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):
         logger.info("Create zero cost checkpoint manager done.")
 
+    def _load_flex_checkpoint(self, resume_from_checkpoint):
+        model_sharded_state_dict = self.model.sharded_state_dict()
+        master_weights_path = os.path.join(resume_from_checkpoint, MASTER_WEIGHT_DIC)
+        opt_states_path = os.path.join(resume_from_checkpoint, OPTIMIZER_STATE_DIC)
+        model_states_path = os.path.join(resume_from_checkpoint, MODEL_STATE_DIC)
+        if not self.args.ignore_load_lr_and_optim:
+            state_dict_metadata = {}
+            metadata_paths = [
+                os.path.join(model_states_path, "0.metadata"),
+                os.path.join(opt_states_path, "0.metadata"),
+                os.path.join(master_weights_path, "0.metadata"),
+            ]
+
+            for metadata_file in metadata_paths:
+                if not os.path.exists(metadata_file):
+                    raise FileNotFoundError(f"Metadata file not found: {metadata_file}")
+                metadata = paddle.load(metadata_file)
+                if hasattr(metadata, "state_dict_metadata"):
+                    state_dict_metadata.update(metadata.state_dict_metadata)
+                else:
+                    raise AttributeError(
+                        f"Loaded metadata from {metadata_file} does not have 'state_dict_metadata' attribute"
+                    )
+
+            init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
+
+            optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+            for k, v in optimizer_sharded_state_dict.items():
+                v.local_tensor._clear_to_zero_allocation()
+
+            if isinstance(self.optimizer.inner_opt, DygraphShardingOptimizerV2):
+                color_to_comm_buffer_list = self.optimizer._color_to_comm_buffer_list
+                for color, _comm_buffer_list in color_to_comm_buffer_list.items():
+                    for comm_buffer in _comm_buffer_list:
+                        comm_buffer._clear_param_storage()
+            else:
+                state_dict = self.model.state_dict()
+                for k, v in state_dict.items():
+                    v._clear_to_zero_allocation()
+
+            opt_states = {}
+            master_weights = {}
+            for k, v in optimizer_sharded_state_dict.items():
+                if k.endswith(".w_0"):
+                    master_weights[k] = v
+                else:
+                    opt_states[k] = v
+
+            for k, v in opt_states.items():
+                new_v = ShardedWeight(
+                    key=v.key,
+                    local_tensor=paddle.zeros_like(v.local_tensor),
+                    local_shape=deepcopy(v.local_shape),
+                    global_shape=deepcopy(v.global_shape),
+                    global_offset=deepcopy(v.global_offset),
+                    is_flattened=v.is_flattened,
+                    flattened_range=deepcopy(v.flattened_range),
+                )
+                opt_states[k] = new_v
+
+            dist.load_state_dict(
+                opt_states,
+                opt_states_path,
+                aoa_config=self.args.aoa_config,
+            )
+
+            optimizer_state_pin = {}
+
+            for k, v in opt_states.items():
+                tmp = v.local_tensor
+                optimizer_state_pin[k] = tmp.pin_memory()
+                tmp._clear_to_zero_allocation()
+                del tmp
+
+            for k, v in master_weights.items():
+                new_v = ShardedWeight(
+                    key=v.key,
+                    local_tensor=paddle.zeros_like(v.local_tensor),
+                    local_shape=deepcopy(v.local_shape),
+                    global_shape=deepcopy(v.global_shape),
+                    global_offset=deepcopy(v.global_offset),
+                    is_flattened=v.is_flattened,
+                    flattened_range=deepcopy(v.flattened_range),
+                )
+                master_weights[k] = new_v
+
+            dist.load_state_dict(
+                master_weights,
+                master_weights_path,
+                aoa_config=self.args.aoa_config,
+            )
+
+            master_weights_pin = {}
+
+            for k, v in master_weights.items():
+                tmp = v.local_tensor
+                master_weights_pin[k] = tmp.pin_memory()
+                tmp._clear_to_zero_allocation()
+                del tmp
+
+            optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+
+            optimizer_sharded_state_dict_pin = {**master_weights_pin, **optimizer_state_pin}
+
+            for k, v in optimizer_sharded_state_dict.items():
+                source_tensor = optimizer_sharded_state_dict_pin[k]
+                target_tensor = paddle.zeros_like(v.local_tensor)
+                if source_tensor.place != target_tensor.place:
+                    source_tensor = source_tensor.to(target_tensor.place)
+                paddle.assign(source_tensor, target_tensor)
+                target_tensor_pin = target_tensor.cpu()
+                del target_tensor
+                target_tensor_pin._share_buffer_to(v.local_tensor)
+                del source_tensor
+
+            if isinstance(self.optimizer.inner_opt, DygraphShardingOptimizerV2):
+                color_to_comm_buffer_list = self.optimizer._color_to_comm_buffer_list
+                for color, _comm_buffer_list in color_to_comm_buffer_list.items():
+                    for comm_buffer in _comm_buffer_list:
+                        comm_buffer._reset_param_storage()
+            else:
+                state_dict = self.model.state_dict()
+                for k, v in state_dict.items():
+                    new_v = paddle.zeros_like(v)
+                    new_v._share_buffer_to(v)
+
+            self._load_scheduler(resume_from_checkpoint)
+
+        dist.load_state_dict(
+            model_sharded_state_dict,
+            model_states_path,
+            aoa_config=self.args.aoa_config,
+        )
+
     def train(
         self,
         resume_from_checkpoint: Optional[Union[str, bool]] = None,
@@ -975,28 +1120,8 @@ def train(
             self.model_wrapped = model
             if delay_optimizer_creation:
                 self.create_optimizer_and_scheduler(num_training_steps=max_steps)
             if resume_from_checkpoint is not None:
-                if not self.args.ignore_load_lr_and_optim:
-                    model_sharded_state_dict = self.model.sharded_state_dict()
-                    accessible_files = os.listdir(resume_from_checkpoint)
-                    metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
-                    assert len(metadata_files) == 1, "Only support one metadata file now."
-                    metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
-                    state_dict_metadata = metadata.state_dict_metadata
-                    init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
-                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
-                    sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
-                    dist.load_state_dict(
-                        sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
-                    )
-                    self._load_scheduler(resume_from_checkpoint)
-                else:
-                    model_sharded_state_dict = self.model.sharded_state_dict()
-                    sharded_state_dict = model_sharded_state_dict
-                    dist.load_state_dict(
-                        sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
-                    )
+                self._load_flex_checkpoint(resume_from_checkpoint)
         else:
             model = self._wrap_model(self.model_wrapped)
             # for the rest of this function `model` is the outside model, whether it was wrapped or not
@@ -2794,7 +2919,12 @@ def _save_checkpoint(self, model, metrics=None):
 
         if self.args.save_checkpoint_format == "flex_checkpoint":
             model_sharded_state_dict = self.model.sharded_state_dict()
-            os.makedirs(output_dir, exist_ok=True)
+            model_state_dict_path = os.path.join(output_dir, MODEL_STATE_DIC)
+            os.makedirs(model_state_dict_path, exist_ok=True)
+            dist.save_state_dict(
+                model_sharded_state_dict,
+                model_state_dict_path,
+            )
 
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
@@ -2858,10 +2988,26 @@ def _save_checkpoint(self, model, metrics=None):
                     )
             else:
                 if self.args.save_checkpoint_format == "flex_checkpoint":
+                    optimizer_state_dict_path = os.path.join(output_dir, OPTIMIZER_STATE_DIC)
+                    optimizer_states = {}
+                    master_weights = {}
+
+                    model_sharded_state_dict = self.model.sharded_state_dict()
                     optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    for k, v in optimizer_sharded_state_dict.items():
+                        if k.endswith(".w_0"):
+                            master_weights[k] = v
+                        else:
+                            optimizer_states[k] = v
+
                     dist.save_state_dict(
-                        {**model_sharded_state_dict, **optimizer_sharded_state_dict},
-                        output_dir,
+                        optimizer_states,
+                        optimizer_state_dict_path,
+                    )
+                    master_weights_path = os.path.join(output_dir, MASTER_WEIGHT_DIC)
+                    dist.save_state_dict(
+                        master_weights,
+                        master_weights_path,
                     )
             if self.args.should_save:
                 if self.tokenizer is not None and self.args.save_tokenizer:
                     self.tokenizer.save_pretrained(output_dir)
@@ -2919,10 +3065,35 @@ def _save_checkpoint(self, model, metrics=None):
                     )
                 elif self.args.save_checkpoint_format == "flex_checkpoint":
                     optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    model_sharded_state_dict = self.model.sharded_state_dict()
+                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    model_state_dict_path = os.path.join(output_dir, MODEL_STATE_DIC)
+                    os.makedirs(model_state_dict_path, exist_ok=True)
                     dist.save_state_dict(
-                        {**model_sharded_state_dict, **optimizer_sharded_state_dict},
-                        output_dir,
+                        model_sharded_state_dict,
+                        model_state_dict_path,
                     )
+                    if not self.args.ignore_save_lr_and_optim:
+                        optimizer_state_dict_path = os.path.join(output_dir, OPTIMIZER_STATE_DIC)
+                        optimizer_states = {}
+                        master_weights = {}
+                        for k, v in optimizer_sharded_state_dict.items():
+                            if k.endswith(".w_0"):
+                                master_weights[k] = v
+                            else:
+                                optimizer_states[k] = v
+
+                        dist.save_state_dict(
+                            optimizer_states,
+                            optimizer_state_dict_path,
+                        )
+
+                        master_weights_path = os.path.join(output_dir, MASTER_WEIGHT_DIC)
+                        dist.save_state_dict(
+                            master_weights,
+                            master_weights_path,
+                        )
+
             if self.args.should_save:
                 if self.tokenizer is not None and self.args.save_tokenizer:
                     self.tokenizer.save_pretrained(output_dir)
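
For reviewers, a plain-Python sketch of the on-disk layout and key routing the save path introduces. The sub-directory names come from the `MODEL_STATE_DIC` / `OPTIMIZER_STATE_DIC` / `MASTER_WEIGHT_DIC` constants and the `.w_0` suffix check added above; the helper names (`split_optimizer_entries`, `expected_layout`), the example keys, and the `checkpoint-500` path are hypothetical and only for illustration:

```python
import os

# Directory names written under each checkpoint dir by the flex_checkpoint path
# (mirrors MODEL_STATE_DIC / OPTIMIZER_STATE_DIC / MASTER_WEIGHT_DIC in the diff).
MODEL_STATE_DIC = "model_state"
OPTIMIZER_STATE_DIC = "optimizer_state"
MASTER_WEIGHT_DIC = "master_weight"


def split_optimizer_entries(optimizer_sharded_state_dict):
    """Route master-weight entries (keys ending in ".w_0") away from the
    moment/accumulator entries, matching the suffix check used in the diff."""
    opt_states, master_weights = {}, {}
    for k, v in optimizer_sharded_state_dict.items():
        (master_weights if k.endswith(".w_0") else opt_states)[k] = v
    return opt_states, master_weights


def expected_layout(checkpoint_dir):
    """Sub-directories a flex checkpoint is expected to contain; each one also
    holds a `0.metadata` file that _load_flex_checkpoint reads back."""
    return [
        os.path.join(checkpoint_dir, MODEL_STATE_DIC),
        os.path.join(checkpoint_dir, OPTIMIZER_STATE_DIC),
        os.path.join(checkpoint_dir, MASTER_WEIGHT_DIC),
    ]


if __name__ == "__main__":
    # Hypothetical sharded-state keys, just to show the routing.
    fake_opt_state = {
        "linear_0.w_0": "master weight shard",
        "linear_0.w_0_moment1_0": "adam moment shard",
    }
    opt_states, master_weights = split_optimizer_entries(fake_opt_state)
    print(sorted(opt_states), sorted(master_weights))
    print(expected_layout("checkpoint-500"))
```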
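On the load side, the resume path now merges metadata from the three sub-directories instead of expecting a single top-level `.metadata` file. A minimal stand-in for that merge, with plain dicts in place of the objects returned by `paddle.load` and hypothetical tensor names:

```python
import os

# Stand-ins for the three sub-directories a flex checkpoint contains.
CHECKPOINT_SUBDIRS = ("model_state", "optimizer_state", "master_weight")


def merge_state_dict_metadata(loaded_metadata_objects):
    """Union the `state_dict_metadata` mappings gathered from each
    sub-directory's `0.metadata` file, mirroring the loop in
    _load_flex_checkpoint (which raises if a file is missing or malformed)."""
    merged = {}
    for obj in loaded_metadata_objects:
        if "state_dict_metadata" not in obj:
            raise AttributeError("metadata object lacks 'state_dict_metadata'")
        merged.update(obj["state_dict_metadata"])
    return merged


if __name__ == "__main__":
    # Hypothetical per-directory metadata contents.
    model_md = {"state_dict_metadata": {"linear_0.w_0": "shard info"}}
    opt_md = {"state_dict_metadata": {"linear_0.w_0_moment1_0": "shard info"}}
    master_md = {"state_dict_metadata": {"linear_0.w_0": "shard info"}}
    print(sorted(merge_state_dict_metadata([model_md, opt_md, master_md])))
    print([os.path.join("checkpoint-500", d, "0.metadata") for d in CHECKPOINT_SUBDIRS])
```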