
Commit 915abce

[API] Remove dtype check in static branch to allow pass bf16 data to outer (PaddlePaddle#76019)
1 parent 69887cd commit 915abce
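To make the user-visible effect concrete, here is a minimal sketch (illustrative only, not taken from the commit; whether paddle.static.data accepts 'bfloat16' as a dtype string is an assumption): the dtype whitelist removed below previously rejected bfloat16 inputs to paddle.outer while a static program was being built.

import paddle

# Sketch of the newly permitted path: bfloat16 placeholders flowing into
# paddle.outer under static graph mode. Names and shapes are illustrative.
paddle.enable_static()
with paddle.static.program_guard(
    paddle.static.Program(), paddle.static.Program()
):
    x = paddle.static.data(name='x', shape=[100], dtype='bfloat16')  # dtype string assumed
    y = paddle.static.data(name='y', shape=[100], dtype='bfloat16')
    out = paddle.outer(x, y)  # previously rejected by the static-branch dtype check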

File tree

python/paddle/tensor/math.py
test/legacy_test/test_outer.py

2 files changed: 20 additions, 30 deletions

python/paddle/tensor/math.py

Lines changed: 1 addition & 15 deletions

@@ -2883,21 +2883,7 @@ def outer(
     else:
         ny = y.reshape((1, -1))
 
-    if in_dynamic_mode():
-        return _C_ops.multiply(nx, ny, out=out)
-
-    def __check_input(x, y):
-        var_names = {'x': x, 'y': y}
-        for name, val in var_names.items():
-            check_variable_and_dtype(
-                val,
-                name,
-                ['float16', 'float32', 'float64', 'int32', 'int64'],
-                'outer',
-            )
-
-    __check_input(nx, ny)
-    if in_pir_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.multiply(nx, ny, out=out)
     else:
         helper = LayerHelper('outer', **locals())
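For orientation (an assumption about Paddle internals, not shown in this commit): in_dynamic_or_pir_mode() behaves like the disjunction of the two checks it replaces, so a single branch now covers both dygraph and the PIR static graph. A sketch, with the import path assumed:

from paddle.base.framework import (  # import path assumed for illustration
    in_dynamic_mode,
    in_dynamic_or_pir_mode,
    in_pir_mode,
)

def merged_check_equivalent() -> bool:
    # Conceptually, the single check used in the new branch above is
    # equivalent to testing either of the two former branches.
    return in_dynamic_mode() or in_pir_mode()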

test/legacy_test/test_outer.py

Lines changed: 19 additions & 15 deletions

@@ -15,7 +15,11 @@
 import unittest
 
 import numpy as np
-from op_test import get_device_place
+from op_test import (
+    convert_float_to_uint16,
+    convert_uint16_to_float,
+    get_device_place,
+)
 
 import paddle
 
@@ -161,16 +165,6 @@ def test_multiply_dynamic(self):
 
 
 class TestMultiplyError(unittest.TestCase):
-    def test_errors_static(self):
-        # test static computation graph: dtype can not be int8
-        paddle.enable_static()
-        with paddle.static.program_guard(
-            paddle.static.Program(), paddle.static.Program()
-        ):
-            x = paddle.static.data(name='x', shape=[100], dtype=np.int8)
-            y = paddle.static.data(name='y', shape=[100], dtype=np.int8)
-            self.assertRaises(TypeError, paddle.outer, x, y)
-
     def test_errors_dynamic(self):
         np.random.seed(7)
 
@@ -318,6 +312,8 @@ def test_outer_alias(self):
             "int32",
             "int64",
         ]
+        if paddle.is_compiled_with_cuda():
+            dtype_cases.extend(["float16", "bfloat16"])
 
         for shape in shape_cases:
             for dtype in dtype_cases:
@@ -332,14 +328,22 @@ def test_outer_alias(self):
                     {"input": x, "vec2": y},
                 ]
 
+                x_numpy = x.numpy()
+                y_numpy = y.numpy()
+
                 # Get baseline result
-                expected = np.outer(x.numpy(), y.numpy())
+                if dtype == "bfloat16":
+                    x_numpy = convert_uint16_to_float(x_numpy)
+                    y_numpy = convert_uint16_to_float(y_numpy)
+                expected = np.outer(x_numpy, y_numpy)
+                if dtype == "bfloat16":
+                    expected = convert_float_to_uint16(expected)
+
+                rtol = 1e-5 if dtype != "bfloat16" else 1e-4
 
                 for params in combinations:
                     out = paddle.outer(**params)
-                    np.testing.assert_allclose(
-                        out.numpy(), expected, rtol=1e-05
-                    )
+                    np.testing.assert_allclose(out.numpy(), expected, rtol=rtol)
 
 
 if __name__ == '__main__':
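The helpers convert_float_to_uint16 and convert_uint16_to_float imported above come from Paddle's op_test utilities; the test needs them because bfloat16 tensors are exposed to NumPy as uint16 bit patterns. The sketch below (an assumption about how such a round trip typically works; the real helpers may round rather than truncate, and the function names here are hypothetical) shows the idea: bfloat16 keeps the upper 16 bits of a float32.

import numpy as np

def float32_to_bf16_bits(x):
    # Keep the upper 16 bits of the float32 bit pattern (simple truncation).
    return (x.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float32(bits):
    # Shift the stored 16 bits back into the high half of a float32 pattern.
    return (bits.astype(np.uint32) << 16).view(np.float32)

x = np.random.rand(4).astype(np.float32)
roundtrip = bf16_bits_to_float32(float32_to_bf16_bits(x))
np.testing.assert_allclose(roundtrip, x, rtol=1e-2)  # bf16 carries only ~8 mantissa bits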
