@@ -550,8 +550,17 @@ def is_fp16_compared_with_fp32(self):
             not in op_accuracy_white_list.NO_FP16_COMPARED_WITH_FP32_OP_LIST
         )
 
+    def is_bf16_compared_with_fp32(self):
+        return self.is_bfloat16_op() and (
+            self.op_type
+            not in op_accuracy_white_list.NO_BF16_COMPARED_WITH_FP32_OP_LIST
+        )
+
     def enable_cal_ref_output(self):
-        self.is_calc_ref = self.is_fp16_compared_with_fp32()
+        self.is_calc_ref = (
+            self.is_fp16_compared_with_fp32()
+            or self.is_bf16_compared_with_fp32()
+        )
 
     def disable_cal_ref_output(self):
         self.is_calc_ref = False
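
Note on this hunk: the new `is_bf16_compared_with_fp32()` gate mirrors the existing fp16 one, and the later hunks branch on `dtype == np.uint16` because bfloat16 has no native NumPy dtype, so the test harness carries bf16 tensors as uint16 bit patterns. A minimal sketch of that representation (illustration only, assuming the usual bfloat16 layout: the upper 16 bits of an IEEE-754 float32):

```python
import numpy as np

# bf16 values travel through the tests as np.uint16 bit patterns; a shift
# by 16 bits converts to and from float32 (truncating the low mantissa bits).
x = np.array([1.5, -2.0, 0.15625], dtype=np.float32)
bf16_bits = (x.view(np.uint32) >> 16).astype(np.uint16)          # float32 -> bf16 bits
restored = (bf16_bits.astype(np.uint32) << 16).view(np.float32)  # bf16 bits -> float32
print(bf16_bits.dtype, restored)  # uint16 [ 1.5     -2.       0.15625]
```

Real conversions usually round to nearest even rather than truncate; the point here is only why `np.uint16` is the dtype the test code checks for.
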
@@ -652,46 +661,86 @@ def feed_var(self, input_vars, place):
                     if isinstance(np_value, tuple):
                         tensor.set(np_value[0], place)
                         dtype = np.array(np_value[1]).dtype
-                        if self.is_calc_ref and dtype == np.float16:
-                            if isinstance(np_value[1], list):
-                                tensor.set_recursive_sequence_lengths(
-                                    np.array(np_value[1]).astype(np.float32)
-                                )
+
+                        if self.is_calc_ref:
+                            # convert the float16 to float by numpy.astype
+                            if dtype == np.float16:
+                                if isinstance(np_value[1], list):
+                                    tensor.set_recursive_sequence_lengths(
+                                        np.array(np_value[1]).astype(np.float32)
+                                    )
+                                else:
+                                    tensor.set_recursive_sequence_lengths(
+                                        np_value[1].astype(np.float32)
+                                    )
+                            # convert the bfloat16 to float by convert_uint16_to_float
+                            # provided in this file
+                            elif dtype == np.uint16:
+                                if isinstance(np_value[1], list):
+                                    tensor.set_recursive_sequence_lengths(
+                                        convert_uint16_to_float(
+                                            np.array(np_value[1])
+                                        )
+                                    )
+                                else:
+                                    tensor.set_recursive_sequence_lengths(
+                                        convert_uint16_to_float(np_value[1])
+                                    )
                             else:
                                 tensor.set_recursive_sequence_lengths(
-                                    np_value[1].astype(np.float32)
+                                    np_value[1]
                                 )
                         else:
                             tensor.set_recursive_sequence_lengths(np_value[1])
                     else:
-                        if self.is_calc_ref and np_value.dtype == np.float16:
-                            tensor.set(np_value.astype(np.float32), place)
+                        if self.is_calc_ref:
+                            if np_value.dtype == np.float16:
+                                tensor.set(np_value.astype(np.float32), place)
+                            elif np_value.dtype == np.uint16:
+                                tensor.set(
+                                    convert_uint16_to_float(np_value), place
+                                )
+                            else:
+                                tensor.set(np_value, place)
                         else:
                             tensor.set(np_value, place)
                     feed_map[name] = tensor
             else:
                 tensor = core.LoDTensor()
                 if isinstance(self.inputs[var_name], tuple):
                     tensor.set(self.inputs[var_name][0], place)
-                    if (
-                        self.is_calc_ref
-                        and self.inputs[var_name][1].dtype == np.float16
-                    ):
-                        tensor.set_recursive_sequence_lengths(
-                            self.inputs[var_name][1].astype(np.float32)
-                        )
+                    if self.is_calc_ref:
+                        if self.inputs[var_name][1].dtype == np.float16:
+                            tensor.set_recursive_sequence_lengths(
+                                self.inputs[var_name][1].astype(np.float32)
+                            )
+                        elif self.inputs[var_name][1].dtype == np.uint16:
+                            tensor.set_recursive_sequence_lengths(
+                                convert_uint16_to_float(
+                                    self.inputs[var_name][1]
+                                )
+                            )
+                        else:
+                            tensor.set_recursive_sequence_lengths(
+                                self.inputs[var_name][1]
+                            )
                     else:
                         tensor.set_recursive_sequence_lengths(
                             self.inputs[var_name][1]
                         )
                 else:
-                    if (
-                        self.is_calc_ref
-                        and self.inputs[var_name].dtype == np.float16
-                    ):
-                        tensor.set(
-                            self.inputs[var_name].astype(np.float32), place
-                        )
+                    if self.is_calc_ref:
+                        if self.inputs[var_name].dtype == np.float16:
+                            tensor.set(
+                                self.inputs[var_name].astype(np.float32), place
+                            )
+                        elif self.inputs[var_name].dtype == np.uint16:
+                            tensor.set(
+                                convert_uint16_to_float(self.inputs[var_name]),
+                                place,
+                            )
+                        else:
+                            tensor.set(self.inputs[var_name], place)
                     else:
                         tensor.set(self.inputs[var_name], place)
                 feed_map[var_name] = tensor
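
The comments in this hunk refer to a `convert_uint16_to_float` helper defined elsewhere in the same file; its exact implementation is not shown in the diff. A self-contained stand-in, assuming bf16 values are stored as uint16 bit patterns, would look roughly like this (sketch, not the file's code):

```python
import numpy as np

def convert_uint16_to_float_sketch(data):
    """Widen bfloat16 values stored as uint16 bit patterns to float32."""
    arr = np.asarray(data, dtype=np.uint16)
    return (arr.astype(np.uint32) << 16).view(np.float32)

# feed_var passes either a plain ndarray (np_value[1]) or, for list-typed
# LoD metadata, np.array(np_value[1]) first, as the hunk above shows.
```
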
@@ -1761,7 +1810,10 @@ def _compare_list(self, name, actual, expect):
             def compare_single_output_with_expect(self, name, expect):
                 actual, actual_np = self.find_actual_value(name)
                 # expect_np = expect[0] if isinstance(expect, tuple) else expect
-                if self.op_test.is_fp16_compared_with_fp32():
+                if (
+                    self.op_test.is_fp16_compared_with_fp32()
+                    or self.op_test.is_bf16_compared_with_fp32()
+                ):
                     expect, expect_np = self.find_expect_value(name)
                 else:
                     expect_np = (
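
For orientation (not shown in this hunk): once `find_expect_value` has produced the float32 reference, the low-precision actual output is checked against it with relaxed tolerances. A hedged sketch of that idea, with tolerance values chosen purely for illustration:

```python
import numpy as np

def compare_bf16_against_fp32_reference(actual_bits, expect_fp32, rtol=1e-2, atol=1e-2):
    # Widen the bf16 (uint16 bit-pattern) output to float32, then compare it
    # to the float32 reference with looser tolerances than a pure-fp32 test.
    actual = (np.asarray(actual_bits, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)
    np.testing.assert_allclose(actual, np.asarray(expect_fp32, dtype=np.float32), rtol=rtol, atol=atol)
```
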
@@ -1816,7 +1868,10 @@ def calculate_output(self):
                 )
                 self.outputs = outs
                 self.fetch_list = fetch_list
-                if self.op_test.is_fp16_compared_with_fp32():
+                if (
+                    self.op_test.is_fp16_compared_with_fp32()
+                    or self.op_test.is_bf16_compared_with_fp32()
+                ):
                     self.op_test.enable_cal_ref_output()
                     ref_outs, ref_fetch_list = self.op_test._calc_output(
                         place, no_check_set=no_check_set
@@ -1883,7 +1938,10 @@ def calculate_output(self):
                     place, no_check_set=no_check_set
                 )
                 self.outputs = dygraph_outs
-                if self.op_test.is_fp16_compared_with_fp32():
+                if (
+                    self.op_test.is_fp16_compared_with_fp32()
+                    or self.op_test.is_bf16_compared_with_fp32()
+                ):
                     self.op_test.enable_cal_ref_output()
                     self.is_python_api_test = True
                     self.ref_outputs = self.op_test._calc_python_api_output(
@@ -2228,9 +2286,8 @@ def _assert_is_close(
                 atol=atol,
                 equal_nan=False,
                 err_msg=(
-                    "Operator %s error, %s variable %s (shape: %s, dtype: %s) max gradient diff over limit"
-                )
-                % (
+                    "Operator {} error, {} variable {} (shape: {}, dtype: {}) max gradient diff over limit"
+                ).format(
                     self.op_type,
                     msg_prefix,
                     name,
@@ -2486,7 +2543,10 @@ def check_grad_with_place(
         if numeric_place is None:
             numeric_place = place
 
-        if user_defined_grads is None and self.is_fp16_compared_with_fp32():
+        if user_defined_grads is None and (
+            self.is_fp16_compared_with_fp32()
+            or self.is_bf16_compared_with_fp32()
+        ):
             self.enable_cal_ref_output()
             numeric_grads = self._get_gradient(
                 inputs_to_check,
@@ -2769,7 +2829,7 @@ def _get_gradient(
            feed_dict = self.feed_var(inputs, place)
 
            if user_defined_grad_outputs is None:
-               if self.dtype == np.uint16:
+               if self.dtype == np.uint16 and not self.is_calc_ref:
                    cast_inputs = list(map(block.var, output_names))
                    if self.op_type in ["broadcast_tensors", "meshgrid"]:
                        output_names = self.cast_bf16_output(block, cast_inputs)