Commit 5f6998f

add amp unit test case for auto_parallel ci. (#8966)
1 parent 30a2ac6 commit 5f6998f

1 file changed

scripts/distribute/ci_case_auto.sh

Lines changed: 94 additions & 0 deletions
@@ -52,6 +52,8 @@ function llama_case_list_auto() {
     # llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2
     # llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2
     llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2
+
+    llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1
 }
 
 function llm_gpt_case_list_auto() {
@@ -968,6 +970,98 @@ function llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2-VPP3_split_bw() {
     echo "=========== $FUNCNAME run end ==========="
 }
 
+function llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1() {
+    echo "=========== $FUNCNAME run begin ==========="
+    export PYTHONPATH=$root_path/:$PYTHONPATH
+    export FLAGS_call_stack_level=3
+    export NVIDIA_TF32_OVERRIDE=0
+    export FLAGS_enable_pir_api=1
+    export FLAGS_max_inplace_grad_add=3
+
+    task_name="llama_align_dygraph_dy2st_auto_bs2_bf16_dp2"
+    case_out_dir="output/$task_name"
+    case_log_dir="output/$task_name""_log"
+
+    for to_static in "0" "1"; do
+        rm -rf $case_out_dir
+        rm -rf $case_log_dir
+        python -u -m paddle.distributed.launch \
+            --gpus "0,1" \
+            --log_dir $case_log_dir \
+            run_pretrain_auto.py \
+            --model_type "llama" \
+            --model_name_or_path "facebook/llama-7b" \
+            --tokenizer_name_or_path "facebook/llama-7b" \
+            --input_dir "./data" \
+            --output_dir $case_out_dir \
+            --split 949,50,1 \
+            --weight_decay 0.01 \
+            --warmup_ratio 0.01 \
+            --warmup_steps 30 \
+            --max_grad_norm 0.0 \
+            --learning_rate 3e-05 \
+            --min_learning_rate 3e-06 \
+            --max_steps 10 \
+            --logging_steps 10 \
+            --eval_steps 1000 \
+            --save_steps 50000 \
+            --continue_training 0 \
+            --do_train true \
+            --do_eval false \
+            --do_predict false \
+            --disable_tqdm true \
+            --skip_profile_timer true \
+            --save_total_limit 2 \
+            --device gpu \
+            --disable_tqdm true \
+            --dataloader_num_workers 1 \
+            --distributed_dataloader 0 \
+            --enable_auto_parallel 1 \
+            --per_device_train_batch_size 1 \
+            --gradient_accumulation_steps 1 \
+            --per_device_eval_batch_size 2 \
+            --recompute false \
+            --recompute_use_reentrant true \
+            --recompute_granularity full \
+            --pp_recompute_interval 0 \
+            --bf16 1 \
+            --fp16_opt_level "O2" \
+            --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
+            --amp_custom_white_list "lookup_table" "lookup_table_v2" \
+            --amp_master_grad 1 \
+            --fuse_attention_ffn true \
+            --fuse_attention_qkv false \
+            --fuse_sequence_parallel_allreduce false \
+            --use_flash_attention 0 \
+            --use_fused_rope false \
+            --use_fused_rms_norm 0 \
+            --max_seq_length 4096 \
+            --sep_parallel_degree 1 \
+            --sequence_parallel false \
+            --pipeline_parallel_degree 1 \
+            --sharding_parallel_degree 1 \
+            --tensor_parallel_degree 1 \
+            --virtual_pp_degree 1 \
+            --pipeline_schedule_mode "VPP" \
+            --sharding "" \
+            --to_static ${to_static} \
+            --num_hidden_layers 2 \
+            >>${log_path}/$FUNCNAME 2>&1
+        loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+        ips=-1
+        mem=-1
+        echo "result: to_static=$to_static loss=$loss ips=$ips mem=$mem"
+        loss_base=10.06303482
+        if [ $IS_A100 -ne 0 ]; then
+            loss_base=10.056101989746093
+        fi
+        ips_base=-1
+        mem_base=-1
+        check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
+    done
+    echo "=========== $FUNCNAME run end ==========="
+}
+
 function llm_gpt_dygraph_auto_bs8_fp32_DP2() {
     echo "=========== $FUNCNAME run begin ==========="
     export PYTHONPATH=$root_path/:$PYTHONPATH
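
Note on the loss check: the case scrapes workerlog.0 for the step-10 record and isolates the loss value with a grep/awk chain. A minimal standalone sketch of that extraction, assuming a hypothetical log line; only the 'global_step: 10' and 'loss: <value>,' fields are implied by the script, everything else about the line format is an assumption:

    # Hypothetical trainer log line; only the 'global_step: 10' and
    # 'loss: <value>,' fields are implied by the grep/awk chain above.
    line='global_step: 10, loss: 10.06303482, lr: 3.0e-05'

    # Split on 'loss: ', then cut at the first comma to isolate the number.
    loss=$(echo "$line" | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}')
    echo "$loss"    # prints 10.06303482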

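check_result is a helper defined elsewhere in ci_case_auto.sh and is not part of this diff. A rough sketch of the comparison it presumably performs on the (case, loss_base, loss, ips_base, ips, mem_base, mem) arguments; the name check_result_sketch, the tolerance, and the failure handling below are assumptions, not the repository's implementation:

    # Hypothetical simplified stand-in for the script's check_result helper:
    # pass the case when the observed loss matches the baseline within a tolerance.
    check_result_sketch() {
        local case_name=$1 loss_base=$2 loss=$3    # ips/mem baselines would follow
        if awk -v a="$loss_base" -v b="$loss" \
            'BEGIN { d = a - b; if (d < 0) d = -d; exit !(d <= 1e-6) }'; then
            echo "$case_name ... pass (loss=$loss)"
        else
            echo "$case_name ... FAIL (loss=$loss, expected $loss_base)"
            exit 1
        fi
    }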
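Once registered in llama_case_list_auto, the case runs with the rest of the auto-parallel suite; the function itself loops over to_static=0 and to_static=1, checking the dygraph and dy2st runs against the same loss_base, which is the "align" in its name. To exercise it in isolation, something along these lines should work; the checkout path and the root_path/log_path/IS_A100 variables are assumptions about the CI harness, matching how the other cases in this script consume them:

    # Assumed setup; paths and variables mirror what the CI harness provides.
    cd PaddleNLP/scripts/distribute
    export root_path=../.. log_path=$PWD/logs IS_A100=0
    mkdir -p $log_path
    source ci_case_auto.sh                                 # load the case functions
    llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1    # runs both to_static=0 and to_static=1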