67 changes: 67 additions & 0 deletions llm/run_pretrain_llm.sh
@@ -0,0 +1,67 @@
#!/bin/bash

# Set environment variables
export PYTHONPATH=../:$PYTHONPATH
export FLAGS_call_stack_level=3
export NVIDIA_TF32_OVERRIDE=0
export FLAGS_cudnn_deterministic=True
export FLAGS_embedding_deterministic=1

# Set output directories
task_name="llama_pretrain"
case_out_dir="output/${task_name}"
case_log_dir="output/${task_name}_log"

# Clean up old output directories
rm -rf "$case_out_dir"
rm -rf "$case_log_dir"

# Launch training
python -u -m paddle.distributed.launch \
--gpus "0,1,2,3" \
--log_dir "$case_log_dir" \
run_pretrain.py \
--model_name_or_path "meta-llama/Llama-2-7b" \
--tokenizer_name_or_path "meta-llama/Llama-2-7b" \
--input_dir "./data" \
--split "949,50,1" \
--num_hidden_layers 4 \
--output_dir "$case_out_dir" \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 8 \
--per_device_eval_batch_size 8 \
--tensor_parallel_degree 4 \
--pipeline_parallel_degree 1 \
--tensor_parallel_config "enable_delay_scale_loss enable_mp_async_allreduce enable_mp_skip_c_identity" \
--pipeline_parallel_config "enable_delay_scale_loss enable_release_grads disable_partial_send_recv enable_overlap_p2p_comm" \
--virtual_pp_degree 1 \
--sequence_parallel 0 \
--use_flash_attention 0 \
--use_fused_rms_norm 0 \
--enable_linear_fused_grad_add 0 \
--learning_rate 3e-05 \
--logging_steps 1 \
--max_steps 10 \
--save_steps 11 \
--eval_steps 1000 \
--weight_decay 0.01 \
--fp16 1 \
--fp16_opt_level "O2" \
--amp_master_grad 1 \
--max_grad_norm 1.0 \
--dataloader_num_workers 1 \
--continue_training 0 \
--do_train true \
--do_eval false \
--do_predict false \
--disable_tqdm true \
--skip_profile_timer true \
--recompute 0 \
--save_total_limit 2 \
--device "gpu" \
--save_sharded_model 0 \
--unified_checkpoint 0 \
--using_flex_checkpoint 1 \
--fuse_attention_qkv true \
--fuse_attention_ffn true \
# --resume_from_checkpoint "./output/llama_pretrain/checkpoint-1"
128 changes: 87 additions & 41 deletions paddlenlp/trainer/trainer.py
@@ -159,6 +159,7 @@
get_last_checkpoint,
get_scheduler,
has_length,
init_optimizer,
set_seed,
should_skip_data,
speed_metrics,
@@ -197,7 +198,6 @@
if is_datasets_available():
import datasets


try:
from paddle.distributed.fleet.utils import mix_precision_utils
except:
@@ -914,7 +914,7 @@ def train(
self._memory_tracker.start()

if not self.args.enable_auto_parallel:
if not self.args.should_load_sharding_stage1_model:
if not self.args.should_load_sharding_stage1_model and not self.args.using_flex_checkpoint:
self._load_from_checkpoint(resume_from_checkpoint)

if self.args.should_load_sharding_stage1_model:
@@ -934,14 +934,32 @@ def train(
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self._load_optimizer_and_scheduler(resume_from_checkpoint)
else:
elif not self.args.using_flex_checkpoint:
model = self._wrap_model(self.model_wrapped)
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self._load_optimizer_and_scheduler(resume_from_checkpoint)
else:
assert self.args.using_flex_checkpoint, "default using flex_checkpoint!"

model = self._wrap_model(self.model_wrapped)
if model is not self.model:
self.model_wrapped = model

if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

if resume_from_checkpoint is not None:
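                    # Flex-checkpoint resume: build sharded state dicts for the model and
                    # optimizer, materialize optimizer state via init_optimizer, then load
                    # and reshard from the checkpoint directory.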
model_sharded_state_dict = self.model.sharded_state_dict()
self.optimizer.sharded_state_dict(model_sharded_state_dict)
init_optimizer(self.optimizer)
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
dist.load_state_dict(sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config)
self._load_scheduler(resume_from_checkpoint)
else:
model = self.model_wrapped
if delay_optimizer_creation:
@@ -1342,6 +1360,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
logger.warning(
f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}"
)

elif isinstance(self.optimizer, HybridParallelOptimizer):
self.optimizer._step(parameters_list)
else:
@@ -1597,8 +1616,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,

logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
logs["global_step"] = int(self.state.global_step)
if in_auto_parallel_align_mode():
logs["loss_md5"] = avg_loss._md5sum()
# if in_auto_parallel_align_mode():
logs["loss_md5"] = avg_loss._md5sum()

divisor = 2**30
# TODO(@gexiao): replace these codes with unified APIs in Paddle
@@ -1968,7 +1987,6 @@ def apply_decay_param_fun(x):
grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm) if self.args.max_grad_norm > 0 else None,
**optimizer_kwargs,
)

return self.optimizer

def _apply_to_optimizer(self, action):
@@ -2234,7 +2252,6 @@ def _wrap_model(self, model, training=True):
mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
assert self.optimizer is not None, "optimizer is empty!"
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)

# Pipeline mode
if in_pipeline_parallel_mode:
if self.args.amp_master_grad:
@@ -2284,15 +2301,13 @@ def get_expected_keys(inputs, keys):
if self.args.amp_master_grad:
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
self.optimizer = fleet.distributed_optimizer(self.optimizer)

if (
hasattr(self.args, "enable_sharding_comm_overlap")
and self.args.enable_sharding_comm_overlap
and self.args.unified_checkpoint
and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
):
model.register_sharding_comm_overlap_hook(self.optimizer)

# No pipeline mode, sharding only
if not in_pipeline_parallel_mode and in_sharding_parallel_mode:
# Sharded DDP!
@@ -2306,7 +2321,6 @@ def get_expected_keys(inputs, keys):
model = paddle.distributed.fleet.meta_parallel.TensorParallel(
model, hcg, strategy=fleet.fleet._user_defined_strategy
)

if ShardingOption.SHARD_OP in self.args.sharding:
if self.args.amp_master_grad:
mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) # return value has no use
@@ -2348,6 +2362,7 @@ def get_expected_keys(inputs, keys):
offload=cpu_offload,
**extra_kwargs,
)

if ShardingOption.SHARD_GRAD_OP in self.args.sharding and self.args.amp_master_grad:
assert hasattr(optimizer, "use_main_grad"), (
"Current installed paddle doesn't support sharding stage 2 with main grad, "
@@ -2373,7 +2388,6 @@ def get_expected_keys(inputs, keys):
if self.args.amp_master_grad:
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
self.optimizer = fleet.distributed_optimizer(self.optimizer)

# stage1 has v1 and v2 version
if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
if "split_param" in self.args.sharding_parallel_config:
@@ -2388,7 +2402,6 @@ def get_expected_keys(inputs, keys):
and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
):
self.optimizer._set_broadcast_overlap(True, model)

return model

def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
@@ -2700,6 +2713,10 @@ def _save_checkpoint(self, model, metrics=None):
else:
self.save_model(output_dir)

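        # Collect the model's sharded state dict up front; the flex-checkpoint save
        # branches below merge it with the optimizer's sharded state dict.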
model_sharded_state_dict = self.model.sharded_state_dict()
if self.args.using_flex_checkpoint:
os.makedirs(output_dir, exist_ok=True)

# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model
@@ -2763,23 +2780,32 @@ def _save_checkpoint(self, model, metrics=None):
signal_dir,
)
else:
if self.dp_group.rank > 0: # this should only work for MoE saving
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(),
os.path.join(output_dir, optimizer_name),
saved_signal_path,
)

else:
state_dict = self.optimizer.state_dict()
save_path = os.path.join(output_dir, optimizer_name)
if self.args.use_async_save:
assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
self._async_optimizer_saver.run(
state_dict, save_path, saved_signal_path=saved_signal_path
if not self.args.using_flex_checkpoint:
if self.dp_group.rank > 0: # this should only work for MoE saving
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(),
os.path.join(output_dir, optimizer_name),
saved_signal_path,
)

else:
self._save_ckpt_func(state_dict, save_path, saved_signal_path)
state_dict = self.optimizer.state_dict()
save_path = os.path.join(output_dir, optimizer_name)
if self.args.use_async_save:
assert not strtobool(
os.getenv("FLAG_LLM_PDC", "False")
), "Dont support FLAG_LLM_PDC"
self._async_optimizer_saver.run(
state_dict, save_path, saved_signal_path=saved_signal_path
)
else:
self._save_ckpt_func(state_dict, save_path, saved_signal_path)
else:
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
dist.save_state_dict(
{**model_sharded_state_dict, **optimizer_sharded_state_dict},
output_dir,
)
else:
if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
@@ -2790,7 +2816,12 @@ def _save_checkpoint(self, model, metrics=None):
or "remove_master_weight" not in self.args.unified_checkpoint_config
):
paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}"))
if self.args.should_save or self.args.use_expert_parallel:

if (
self.args.should_save
or self.args.use_expert_parallel
or (self.args.data_parallel_degree > 1 and not self.args.use_hybrid_parallel)
):
if not self.args.use_hybrid_parallel:
logger.info("Saving optimizer files.")
if self.args.unified_checkpoint:
@@ -2800,7 +2831,7 @@ def _save_checkpoint(self, model, metrics=None):
output_dir,
signal_dir,
)
else:
elif not self.args.using_flex_checkpoint:
if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(),
@@ -2814,6 +2845,13 @@ def _save_checkpoint(self, model, metrics=None):
saved_signal_path,
)

else:
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
dist.save_state_dict(
{**model_sharded_state_dict, **optimizer_sharded_state_dict},
output_dir,
)

# FIXME: maybe only save one copy
paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))

@@ -3077,6 +3115,24 @@ def _save(
with open(path, "w") as f:
json.dump(model_meta, f)

def _load_scheduler(self, checkpoint):
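        """Load LR scheduler (and grad scaler) state from the checkpoint, if present."""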
if checkpoint is None:
self.runtime_timer.stop()
return

if not self.args.ignore_load_lr_and_optim:
if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
self.lr_scheduler.set_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
)
else:
raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")

if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
self.scaler.load_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
)

def _load_optimizer_and_scheduler(self, checkpoint):
"""If optimizer and scheduler states exist, load them."""
self.runtime_timer.start("checkpoint loading time")
@@ -3118,6 +3174,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
):
model = self.model_wrapped

opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
model=model,
optimizer=self.optimizer,
@@ -3149,18 +3206,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.")

if not self.args.ignore_load_lr_and_optim:
if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
self.lr_scheduler.set_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
)
else:
raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")

if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
self.scaler.load_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
)
self._load_scheduler(checkpoint)

if self.args.offload_optim:
logger.info("Offloading optimizer state...")
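For orientation, below is a minimal sketch of the flex-checkpoint save/resume flow that the trainer changes above wire up. It relies only on calls that appear in the diff (model/optimizer `sharded_state_dict()`, `init_optimizer`, `dist.save_state_dict`, `dist.load_state_dict` with `aoa_config`); the two helper function names are illustrative, not part of the PR.

import paddle.distributed as dist


def save_flex_checkpoint(model, optimizer, output_dir):
    # Merge the model's and the optimizer's sharded state dicts and write a
    # single resharding-friendly checkpoint, as _save_checkpoint does above.
    model_sd = model.sharded_state_dict()
    optim_sd = optimizer.sharded_state_dict(model_sd)
    dist.save_state_dict({**model_sd, **optim_sd}, output_dir)


def load_flex_checkpoint(model, optimizer, checkpoint_dir, aoa_config=None):
    # Mirror of the resume path in train(); the trainer additionally calls
    # init_optimizer() first to materialize optimizer state before loading.
    model_sd = model.sharded_state_dict()
    optim_sd = optimizer.sharded_state_dict(model_sd)
    dist.load_state_dict({**model_sd, **optim_sd}, checkpoint_dir, aoa_config=aoa_config)

In run_pretrain_llm.sh, resuming through this path is what the commented-out --resume_from_checkpoint "./output/llama_pretrain/checkpoint-1" flag is for.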