
Commit b456110

【AutoParallel】Support 'master_grad' in Llama in static auto-parallelism (PaddlePaddle#7658)
* add master_grad
* add llama_fp16_test
* polish
* merge develop
1 parent b50db1c commit b456110

3 files changed (+73 −4 lines)

paddlenlp/trainer/training_args.py

Lines changed: 1 addition & 3 deletions
@@ -1174,9 +1174,6 @@ def is_segment_parallel_supported():
             pipeline.micro_batch_size = self.per_device_train_batch_size
             pipeline.schedule_mode = self.pipeline_schedule_mode

-            if self.amp_master_grad:
-                warnings.warn("`amp_master_grad` is not supported NOW in AutoParallel!")
-                self.amp_master_grad = False
             logger.info(f"PP configs:{strategy.pipeline}, use master_grad: {self.amp_master_grad}")

             if self.do_eval:
@@ -1260,6 +1257,7 @@ def is_segment_parallel_supported():
             amp.enable = True
             amp.dtype = "bfloat16" if self.bf16 else "float16"
             amp.level = self.fp16_opt_level.lower()
+            amp.use_master_grad = self.amp_master_grad
             amp.init_loss_scaling = self.scale_loss
             amp.custom_black_list = self.amp_custom_black_list if self.amp_custom_black_list is not None else []
             amp.custom_white_list = self.amp_custom_white_list if self.amp_custom_white_list is not None else []
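For context, a minimal sketch (not the trainer's actual code path) of the amp configuration this hunk wires up, assuming Paddle's paddle.distributed.fleet.auto.Strategy exposes the attributes shown in the diff; the literal values stand in for the corresponding training arguments (--fp16_opt_level "O2", --amp_master_grad 1, --scale_loss 1024):

# Minimal sketch mirroring the amp settings above; attribute names are taken
# from the diff, the import path is an assumption about the targeted Paddle
# version.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()

amp = strategy.amp
amp.enable = True
amp.dtype = "float16"            # "bfloat16" when --bf16 is set
amp.level = "o2"                 # fp16_opt_level, lower-cased
amp.use_master_grad = True       # newly plumbed through from --amp_master_grad
amp.init_loss_scaling = 1024.0   # --scale_loss
amp.custom_black_list = []       # --amp_custom_black_list, if provided
amp.custom_white_list = []       # --amp_custom_white_list, if provided

Under O2 mixed precision, enabling use_master_grad keeps FP32 gradient copies for the optimizer update instead of updating against FP16 gradients, which is the behavior the new fp16 CI case below is meant to exercise.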

scripts/distribute/ci_case_auto.sh

Lines changed: 70 additions & 1 deletion
@@ -52,6 +52,7 @@ function llama_case_list_auto() {
     llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1
     llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2
     llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2
+    llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2
 }

 function gpt_case_list_auto_pir() {
@@ -1168,6 +1169,75 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2
     echo "=========== $FUNCNAME run end ==========="
 }

+function llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2() {
+    echo "=========== $FUNCNAME run begin ==========="
+    export PYTHONPATH=$root_path/:$PYTHONPATH
+    export FLAGS_call_stack_level=2
+
+    task_name="llama_auto_bs16_fp16_dp2mp2pp2vpp2sharding2"
+    case_out_dir="output/$task_name"
+    case_log_dir="output/$task_name""_log"
+    rm -rf $case_out_dir
+    rm -rf $case_log_dir
+
+    python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" --log_dir $case_log_dir run_pretrain_auto.py \
+        --model_type "llama" \
+        --model_name_or_path "facebook/llama-7b" \
+        --tokenizer_name_or_path "facebook/llama-7b" \
+        --hidden_size 1024 \
+        --intermediate_size 3072 \
+        --num_hidden_layers 8 \
+        --num_attention_heads 32 \
+        --input_dir "./data" \
+        --output_dir $case_out_dir \
+        --split 949,50,1 \
+        --max_seq_length 2048 \
+        --per_device_train_batch_size 1 \
+        --per_device_eval_batch_size 8 \
+        --gradient_accumulation_steps 8 \
+        --use_flash_attention 0 \
+        --use_fused_rms_norm 0 \
+        --fp16 1 \
+        --fp16_opt_level "O2" \
+        --amp_master_grad 1 \
+        --scale_loss 1024 \
+        --tensor_parallel_degree 2 \
+        --pipeline_parallel_degree 2 \
+        --virtual_pp_degree 2 \
+        --pipeline_schedule_mode "VPP" \
+        --sharding_parallel_degree 2 \
+        --sharding "stage2" \
+        --learning_rate 0.0001 \
+        --min_learning_rate 0.00001 \
+        --max_steps 10 \
+        --save_steps 5000 \
+        --weight_decay 0.01 \
+        --warmup_ratio 0.01 \
+        --max_grad_norm 1.0 \
+        --logging_steps 1 \
+        --dataloader_num_workers 1 \
+        --eval_steps 1000 \
+        --report_to "visualdl" \
+        --disable_tqdm true \
+        --continue_training 0 \
+        --recompute 1 \
+        --do_train \
+        --do_eval \
+        --device "gpu" \
+        --data_impl "mmap" \
+        --parallel_mode "auto" \
+        >>${log_path}/$FUNCNAME 2>&1
+    loss=`cat $case_log_dir/workerlog.3 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    ips=-1
+    mem=-1
+    echo "result: loss=$loss ips=$ips mem=$mem"
+    loss_base=10.0859375
+    ips_base=-1
+    mem_base=-1
+    check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
+    echo "=========== $FUNCNAME run end ==========="
+}
+
 function llama_dygraph_auto_bs4_fp32_DP2-MP2-PP2() {
     echo "=========== $FUNCNAME run begin ==========="
     export PYTHONPATH=$root_path/:$PYTHONPATH
@@ -1233,7 +1303,6 @@ function llama_dygraph_auto_bs4_fp32_DP2-MP2-PP2() {
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
-
 ############ case end ############

 function check_result() {

scripts/distribute/run_ci.sh

Lines changed: 2 additions & 0 deletions
@@ -22,11 +22,13 @@ export case_list=()

 target_lists_for_gpt=(
     "model_zoo/gpt-3"
+    "scripts/distribute"
 )

 target_lists_for_llama=(
     "llm/llama/auto_parallel"
     "paddlenlp/transformers/llama/modeling_auto.py"
+    "scripts/distribute"
 )

 target_path_for_ci_scripts="scripts/distribute"
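run_ci.sh itself implements this gating in Bash; purely as an illustration of the effect of the added entries (hypothetical names, not the script's real logic), a change under "scripts/distribute" now selects the corresponding case lists as well:

# Hypothetical Python illustration of prefix-based CI gating.
target_lists_for_llama = [
    "llm/llama/auto_parallel",
    "paddlenlp/transformers/llama/modeling_auto.py",
    "scripts/distribute",  # new: CI-script changes also trigger the Llama cases
]

def should_run_llama_cases(changed_files):
    """True if any changed file sits under one of the Llama target prefixes."""
    return any(
        path.startswith(prefix)
        for path in changed_files
        for prefix in target_lists_for_llama
    )

# A change to the CI case script now schedules the Llama case list too.
print(should_run_llama_cases(["scripts/distribute/ci_case_auto.sh"]))  # True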
