Commit 8077d79
[Cherry-Pick] Modify the bf16 accuracy checking framework in OpTest (#54658)
* modify the bf16 accuracy checking framework in OpTest
1 parent 0abd9ff commit 8077d79

File tree

3 files changed: +100 -32 lines changed

  test/legacy_test/eager_op_test.py
  test/legacy_test/testsuite.py
  test/white_list/op_accuracy_white_list.py

test/legacy_test/eager_op_test.py

Lines changed: 91 additions & 31 deletions
@@ -550,8 +550,17 @@ def is_fp16_compared_with_fp32(self):
             not in op_accuracy_white_list.NO_FP16_COMPARED_WITH_FP32_OP_LIST
         )

+    def is_bf16_compared_with_fp32(self):
+        return self.is_bfloat16_op() and (
+            self.op_type
+            not in op_accuracy_white_list.NO_BF16_COMPARED_WITH_FP32_OP_LIST
+        )
+
     def enable_cal_ref_output(self):
-        self.is_calc_ref = self.is_fp16_compared_with_fp32()
+        self.is_calc_ref = (
+            self.is_fp16_compared_with_fp32()
+            or self.is_bf16_compared_with_fp32()
+        )

     def disable_cal_ref_output(self):
         self.is_calc_ref = False
@@ -652,46 +661,86 @@ def feed_var(self, input_vars, place):
                     if isinstance(np_value, tuple):
                         tensor.set(np_value[0], place)
                         dtype = np.array(np_value[1]).dtype
-                        if self.is_calc_ref and dtype == np.float16:
-                            if isinstance(np_value[1], list):
-                                tensor.set_recursive_sequence_lengths(
-                                    np.array(np_value[1]).astype(np.float32)
-                                )
+
+                        if self.is_calc_ref:
+                            # convert the float16 to float by numpy.astype
+                            if dtype == np.float16:
+                                if isinstance(np_value[1], list):
+                                    tensor.set_recursive_sequence_lengths(
+                                        np.array(np_value[1]).astype(np.float32)
+                                    )
+                                else:
+                                    tensor.set_recursive_sequence_lengths(
+                                        np_value[1].astype(np.float32)
+                                    )
+                            # convert the bfloat16 to float by convert_uint16_to_float
+                            # provided in this file
+                            elif dtype == np.uint16:
+                                if isinstance(np_value[1], list):
+                                    tensor.set_recursive_sequence_lengths(
+                                        convert_uint16_to_float(
+                                            np.array(np_value[1])
+                                        )
+                                    )
+                                else:
+                                    tensor.set_recursive_sequence_lengths(
+                                        convert_uint16_to_float(np_value[1])
+                                    )
                             else:
                                 tensor.set_recursive_sequence_lengths(
-                                    np_value[1].astype(np.float32)
+                                    np_value[1]
                                 )
                         else:
                             tensor.set_recursive_sequence_lengths(np_value[1])
                     else:
-                        if self.is_calc_ref and np_value.dtype == np.float16:
-                            tensor.set(np_value.astype(np.float32), place)
+                        if self.is_calc_ref:
+                            if np_value.dtype == np.float16:
+                                tensor.set(np_value.astype(np.float32), place)
+                            elif np_value.dtype == np.uint16:
+                                tensor.set(
+                                    convert_uint16_to_float(np_value), place
+                                )
+                            else:
+                                tensor.set(np_value, place)
                         else:
                             tensor.set(np_value, place)
                     feed_map[name] = tensor
            else:
                 tensor = core.LoDTensor()
                 if isinstance(self.inputs[var_name], tuple):
                     tensor.set(self.inputs[var_name][0], place)
-                    if (
-                        self.is_calc_ref
-                        and self.inputs[var_name][1].dtype == np.float16
-                    ):
-                        tensor.set_recursive_sequence_lengths(
-                            self.inputs[var_name][1].astype(np.float32)
-                        )
+                    if self.is_calc_ref:
+                        if self.inputs[var_name][1].dtype == np.float16:
+                            tensor.set_recursive_sequence_lengths(
+                                self.inputs[var_name][1].astype(np.float32)
+                            )
+                        elif self.inputs[var_name][1].dtype == np.uint16:
+                            tensor.set_recursive_sequence_lengths(
+                                convert_uint16_to_float(
+                                    self.inputs[var_name][1]
+                                )
+                            )
+                        else:
+                            tensor.set_recursive_sequence_lengths(
+                                self.inputs[var_name][1]
+                            )
                     else:
                         tensor.set_recursive_sequence_lengths(
                             self.inputs[var_name][1]
                         )
                 else:
-                    if (
-                        self.is_calc_ref
-                        and self.inputs[var_name].dtype == np.float16
-                    ):
-                        tensor.set(
-                            self.inputs[var_name].astype(np.float32), place
-                        )
+                    if self.is_calc_ref:
+                        if self.inputs[var_name].dtype == np.float16:
+                            tensor.set(
+                                self.inputs[var_name].astype(np.float32), place
+                            )
+                        elif self.inputs[var_name].dtype == np.uint16:
+                            tensor.set(
+                                convert_uint16_to_float(self.inputs[var_name]),
+                                place,
+                            )
+                        else:
+                            tensor.set(self.inputs[var_name], place)
                     else:
                         tensor.set(self.inputs[var_name], place)
                 feed_map[var_name] = tensor
@@ -1761,7 +1810,10 @@ def _compare_list(self, name, actual, expect):
         def compare_single_output_with_expect(self, name, expect):
             actual, actual_np = self.find_actual_value(name)
             # expect_np = expect[0] if isinstance(expect, tuple) else expect
-            if self.op_test.is_fp16_compared_with_fp32():
+            if (
+                self.op_test.is_fp16_compared_with_fp32()
+                or self.op_test.is_bf16_compared_with_fp32()
+            ):
                 expect, expect_np = self.find_expect_value(name)
             else:
                 expect_np = (
@@ -1816,7 +1868,10 @@ def calculate_output(self):
             )
             self.outputs = outs
             self.fetch_list = fetch_list
-            if self.op_test.is_fp16_compared_with_fp32():
+            if (
+                self.op_test.is_fp16_compared_with_fp32()
+                or self.op_test.is_bf16_compared_with_fp32()
+            ):
                 self.op_test.enable_cal_ref_output()
                 ref_outs, ref_fetch_list = self.op_test._calc_output(
                     place, no_check_set=no_check_set
@@ -1883,7 +1938,10 @@ def calculate_output(self):
                 place, no_check_set=no_check_set
             )
             self.outputs = dygraph_outs
-            if self.op_test.is_fp16_compared_with_fp32():
+            if (
+                self.op_test.is_fp16_compared_with_fp32()
+                or self.op_test.is_bf16_compared_with_fp32()
+            ):
                 self.op_test.enable_cal_ref_output()
                 self.is_python_api_test = True
                 self.ref_outputs = self.op_test._calc_python_api_output(
@@ -2228,9 +2286,8 @@ def _assert_is_close(
                 atol=atol,
                 equal_nan=False,
                 err_msg=(
-                    "Operator %s error, %s variable %s (shape: %s, dtype: %s) max gradient diff over limit"
-                )
-                % (
+                    "Operator {} error, {} variable {} (shape: {}, dtype: {}) max gradient diff over limit"
+                ).format(
                     self.op_type,
                     msg_prefix,
                     name,
@@ -2486,7 +2543,10 @@ def check_grad_with_place(
         if numeric_place is None:
             numeric_place = place

-        if user_defined_grads is None and self.is_fp16_compared_with_fp32():
+        if user_defined_grads is None and (
+            self.is_fp16_compared_with_fp32()
+            or self.is_bf16_compared_with_fp32()
+        ):
             self.enable_cal_ref_output()
             numeric_grads = self._get_gradient(
                 inputs_to_check,
@@ -2769,7 +2829,7 @@ def _get_gradient(
        feed_dict = self.feed_var(inputs, place)

        if user_defined_grad_outputs is None:
-            if self.dtype == np.uint16:
+            if self.dtype == np.uint16 and not self.is_calc_ref:
                cast_inputs = list(map(block.var, output_names))
                if self.op_type in ["broadcast_tensors", "meshgrid"]:
                    output_names = self.cast_bf16_output(block, cast_inputs)
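
The feed-path comments above refer to convert_uint16_to_float, a helper that
eager_op_test.py already provides. OpTest stores bfloat16 tensors as np.uint16
bit patterns, and since bfloat16 is just the upper 16 bits of an IEEE-754
float32, widening is a 16-bit left shift. A minimal sketch of that conversion
(the idea, not the verbatim Paddle implementation):

    import struct

    import numpy as np

    def convert_uint16_to_float(in_list):
        # Each uint16 element holds the upper 16 bits of a float32, so
        # shifting left by 16 reconstructs the full float32 bit pattern.
        in_list = np.asarray(in_list)
        out = np.vectorize(
            lambda x: struct.unpack("<f", struct.pack("<I", np.uint32(x) << 16))[0],
            otypes=[np.float32],
        )(in_list.flat)
        return np.reshape(out, in_list.shape)

For example, convert_uint16_to_float(np.uint16(0x3F80)) yields 1.0, because
0x3F800000 is the float32 bit pattern of 1.0.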

test/legacy_test/testsuite.py

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ def create_var(block, name, np_list, var_proto, is_calc_ref=False):
     if is_input:
         shape = list(np_value.shape)
         lod_level = 0
-        if is_calc_ref and dtype == np.float16:
+        if is_calc_ref and (dtype == np.float16 or dtype == np.uint16):
             dtype = np.float32
     return block.create_var(
         dtype=dtype, shape=shape, lod_level=lod_level, name=name
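
This mirrors the feed-side change above: during the reference pass, variables
backing fp16 or bf16 (np.uint16) inputs are declared as float32, so the whole
reference program runs in float32. The end-to-end idea, as a self-contained
sketch with hypothetical names (the doubling stands in for any op under test):

    import numpy as np

    def float_to_bf16_bits(x):
        # keep the upper 16 bits of each float32 (truncation; Paddle's
        # helper also applies rounding)
        return (np.asarray(x, np.float32).view(np.uint32) >> 16).astype(np.uint16)

    def bf16_bits_to_float(u):
        return (u.astype(np.uint32) << 16).view(np.float32)

    x = np.random.rand(4, 4).astype(np.float32)
    bf16_out = bf16_bits_to_float(float_to_bf16_bits(x)) * 2.0  # bf16 run
    ref_out = x * 2.0                                           # float32 reference
    np.testing.assert_allclose(bf16_out, ref_out, rtol=1e-2)

Comparing the bf16 result against a float32 reference run, rather than against
a stored bf16 expectation, is exactly what enable_cal_ref_output() switches on.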

test/white_list/op_accuracy_white_list.py

Lines changed: 8 additions & 0 deletions
@@ -94,3 +94,11 @@
     'fake_quantize_moving_average_abs_max',
     'p_norm',
 ]
+
+
+NO_BF16_COMPARED_WITH_FP32_OP_LIST = [
+    'unique',
+    'fusion_gru',
+    'fusion_lstm',
+    'dequantize',
+]
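
As with the existing NO_FP16_COMPARED_WITH_FP32_OP_LIST, ops named here are
exempt from the fresh float32 reference run and keep their precomputed
expectations. Roughly, the gating added in eager_op_test.py reduces to the
following (a paraphrased free-function sketch, not the verbatim method):

    NO_BF16_COMPARED_WITH_FP32_OP_LIST = [
        'unique', 'fusion_gru', 'fusion_lstm', 'dequantize'
    ]

    def needs_fp32_reference(op_type, is_bfloat16_op):
        # restates OpTest.is_bf16_compared_with_fp32() outside the class
        return is_bfloat16_op and (
            op_type not in NO_BF16_COMPARED_WITH_FP32_OP_LIST
        )

    assert needs_fp32_reference('matmul', True)      # compared with fp32
    assert not needs_fp32_reference('unique', True)  # white-listed, skipped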
