
Commit 2ac8599

Fix loss for Align reduce precision with PyTorch 2.9.1. (#11199)

1 parent dcbdeaa commit 2ac8599

File tree

2 files changed: +20 −20 lines changed

scripts/distribute/ci_case_auto.sh

Lines changed: 18 additions & 18 deletions
@@ -107,27 +107,27 @@ function llama_case_list_auto() {
         # The test name must have "llama_" as a prefix, which will
         # be used for tracking the execution status of the case.
         llama_dygraph_auto_bs4_bf16_SD2
-        # llama_dygraph_auto_bs8_fp32_DP2
-        # llama_dygraph_auto_bs8_fp32_DP2-MP2
+        llama_dygraph_auto_bs8_fp32_DP2
+        llama_dygraph_auto_bs8_fp32_DP2-MP2
         llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2
-        # llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2
+        llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2
         llama_dygraph_auto_bs8_fp16_DP2-MP2-CP2
         #llama_dygraph_auto_bs8_fp16_DP2-MP2-CP2_intermediate
         llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2_hybrid_pp
         # llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2_intermediate
         llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2-VPP3_split_bw
         llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2
-        # llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1
-        # llama_pir_auto_fuse_ffn_attention_qkv_MP2
+        llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1
+        llama_pir_auto_fuse_ffn_attention_qkv_MP2
         # llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1
         llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP
-        # llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP2-SP
+        llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP2-SP
         llama_align_dygraph_dy2st_pir_auto_grad_merge_bs2_fp32_DP1-MP1-PP1
-        # llama_align_dy2st_fthenb_and_vpp_auto_bs2_fp32_DP1-MP1-PP4
+        llama_align_dy2st_fthenb_and_vpp_auto_bs2_fp32_DP1-MP1-PP4
         llama_align_dygraph_dy2st_pir_auto_pp_bs2_bf16_DP1-MP1-PP4
         llama_baichuan_pir_auto_fuse_ffn_attention_qkv_DP2_MP2_PP2
         # llama_baichuan_pir_auto_fuse_ffn_attention_qkv_DP2_MP2_PP2_intermediate
-        # llama_dy2st_auto_bs2_bf16_DP2-MP1-PP1-CINN
+        llama_dy2st_auto_bs2_bf16_DP2-MP1-PP1-CINN
         llama_lora_static_graph_auto_bs_2_bf16_DP2-TP2-PP1
         llama_dpo_dy2st_auto_bs2_bf16_MP8_intermediate
         llama_baichuan_dygraph_auto_sp_async_reduce_scatter_bs8_bf16_DP4-MP2-SP
@@ -171,7 +171,7 @@ function llm_gpt_case_list_auto() {
     fun_list=(
         # The test name must have "llm_gpt_dygraph_auto_" as a prefix,
         # which will be used for tracking the execution status of the case.
-        # llm_gpt_dygraph_auto_bs8_fp32_DP2
+        llm_gpt_dygraph_auto_bs8_fp32_DP2
         llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2
         llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2-PP2
         llm_gpt_dygraph_auto_bs8_fp16_DP2-MP2-PP2
@@ -406,7 +406,7 @@ function llama_dygraph_auto_bs8_fp32_DP2() {
     ips=-1
     mem=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'`
     echo "result: loss=$loss ips=$ips mem=$mem"
-    loss_base=9.4992733
+    loss_base=9.49927235
     if [ $IS_A100 -ne 0 ];then
         loss_base=9.50651741
     fi
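
Note: the mem value above is extracted from workerlog.0 with a plain grep/awk pipeline. A self-contained way to sanity-check that pipeline, using a made-up log line (the field layout is an assumption about what the trainer prints; the grep/awk chain is copied verbatim from the script):

    # Hypothetical sample line for illustration only.
    log_line="global_step: 10, loss: 9.49927235, max_memory_reserved: 8904, ips: 4500"
    mem=$(echo "$log_line" | grep 'global_step: 10' \
        | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}')
    echo "mem=$mem"   # -> mem=8904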
@@ -478,7 +478,7 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2() {
     ips=-1
     mem=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'`
     echo "result: loss=$loss ips=$ips mem=$mem"
-    loss_base=9.3507843
+    loss_base=9.35078526
     if [ $IS_A100 -ne 0 ];then
         if [ $IS_CUDA123 -ne 0 ];then
             loss_base=9.38577747
@@ -636,7 +636,7 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
-    loss_base=9.35162163
+    loss_base=9.35163498
     if [ $IS_A100 -ne 0 ];then
         if [ $IS_CUDA123 -ne 0 ];then
             loss_base=9.39367676
@@ -1563,7 +1563,7 @@ function llama_pir_auto_fuse_ffn_attention_qkv_MP2() {
             loss_base_10=9.4961319
         else
             loss_base_2=10.53477287
-            loss_base_10=9.49613285
+            loss_base_10=9.4961319
         fi
     fi
     check_result $FUNCNAME ${loss_base_2} ${auto_loss_2} ${ips_base} ${auto_ips} ${mem_base} ${auto_mem}
@@ -1658,7 +1658,7 @@ function llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP2-SP() {
     mem=-1
     echo "result: to_static=$to_static loss=$loss loss_md5=$loss_md5 ips=$ips mem=$mem"
     if [ $to_static -eq 0 ];then
-        loss_base=9.25199432
+        loss_base=9.2519928
     elif [ $to_static -eq 1 ];then
         loss_base=9.25199356
     fi
@@ -1768,7 +1768,7 @@ function llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1() {
     mem=-1
     echo "result: to_static=$to_static loss=$loss ips=$ips mem=$mem"
     if [ $to_static -eq 0 ];then
-        loss_base=9.99302597
+        loss_base=9.99302673
     elif [ $to_static -eq 1 ];then
         loss_base=9.99302673
     fi
@@ -1882,7 +1882,7 @@ function llama_dy2st_auto_bs2_bf16_DP2-MP1-PP1-CINN() {
     ips=-1
     mem=-1
     echo "result: to_static=$to_static loss=$loss ips=$ips mem=$mem"
-    loss_base=9.99302597
+    loss_base=9.99302521
     if [ $IS_A100 -ne 0 ];then
         if [ $IS_CUDA123 -ne 0 ];then
             loss_base=10.20989532
@@ -2203,7 +2203,7 @@ function llama_align_dy2st_fthenb_and_vpp_auto_bs2_fp32_DP1-MP1-PP4() {
     if [ $IS_A100 -ne 0 ];then
         check_result $FUNCNAME ${loss1} ${loss2} ${ips_base} ${ips} ${mem_base} ${mem}
     else
-        loss_base_fthenb=10.24240494
+        loss_base_fthenb=10.24240398
         loss_base_vpp=10.24149513 # Paddle PR#74530
         echo "FThenB check"
         check_result $FUNCNAME ${loss_base_fthenb} ${loss1} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -2690,7 +2690,7 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5"
-    loss_base=10.55727577 # output of dropout is different after supporting spmd
+    loss_base=10.55727673 # output of dropout is different after supporting spmd
     ips_base=-1
     mem_base=-1
     if [ $IS_A100 -ne 0 ];then
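
Each loss_base updated above feeds check_result, whose definition is not part of this diff. As a rough illustration only (the function name check_loss, the relative tolerance, and the comparison logic are assumptions, not the repository's code), a baseline check of this kind can be sketched as:

    # Illustrative stand-in for the baseline comparison; not the real check_result.
    check_loss() {
        local name=$1 base=$2 actual=$3 rtol=${4:-0.000001}
        if awk -v b="$base" -v a="$actual" -v t="$rtol" 'BEGIN {
            d = (a > b) ? a - b : b - a        # absolute difference
            ref = (b < 0) ? -b : b             # |baseline| as the reference scale
            exit (d <= t * ref) ? 0 : 1
        }'; then
            echo "$name: loss ok (base=$base actual=$actual)"
        else
            echo "$name: loss mismatch (base=$base actual=$actual)"; return 1
        fi
    }
    check_loss llama_dygraph_auto_bs8_fp32_DP2 9.49927235 9.49927235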

scripts/distribute/ci_case_dy.sh

Lines changed: 2 additions & 2 deletions
@@ -133,7 +133,7 @@ function llm_gpt_case_list_dygraph() {
     fun_list=(
         # The test name must have "llm_gpt_" as a prefix, which will
         # be used for tracking the execution status of the case.
-        # llm_gpt_recompute_bs32_bf16_MP2-SD4-stage1
+        llm_gpt_recompute_bs32_bf16_MP2-SD4-stage1
     )
     if [ $1 = "prepare_case" ]; then
         restore_func $fun_list
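
fun_list and restore_func belong to the surrounding harness, not to this diff. Assuming the usual pattern in these CI scripts (a bash array of case-function names executed one by one, so re-enabling a case just means uncommenting its entry), a minimal self-contained sketch of such a runner looks like this; the stub body is invented so the sketch runs on its own:

    # Stub case so the sketch is runnable; the real case is defined in this script.
    llm_gpt_recompute_bs32_bf16_MP2-SD4-stage1() { echo "running recompute case"; }
    fun_list=( llm_gpt_recompute_bs32_bf16_MP2-SD4-stage1 )
    for case_name in "${fun_list[@]}"; do
        echo "=========== $case_name begin ==========="
        "$case_name"
        echo "=========== $case_name end ==========="
    done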
@@ -549,7 +549,7 @@ function llm_gpt_recompute_bs32_bf16_MP2-SD4-stage1() {
     if [ $IS_CUDA123 -ne 0 ];then
         loss_base=8.93676758
     else
-        loss_base=8.93362617
+        loss_base=8.93362999
     fi
     ips_base=64.75564390065037
     mem_base=8904
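
The IS_CUDA123 and IS_A100 branches select toolkit- and hardware-specific baselines; how those flags are populated is outside this diff. One plausible derivation, purely an assumption for illustration:

    # Assumption: flags derived from the local CUDA toolkit and GPU model.
    cuda_ver=$(nvcc --version | grep -o 'release [0-9.]*' | awk '{print $2}')
    IS_CUDA123=0
    [ "$cuda_ver" = "12.3" ] && IS_CUDA123=1
    IS_A100=$(nvidia-smi --query-gpu=name --format=csv,noheader | grep -c 'A100')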
