Skip to content

Commit a9319a2

Browse files
authored
[Fix] Fix unit test test_stack_extension_api (#76157)
* fix ut * fix
1 parent 3e34774 commit a9319a2

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

test/legacy_test/test_stack_extension_api.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ def _test_static_api(
8989
names: list,
9090
):
9191
"""Test `static`, convert `Tensor` to `numpy array` before feed into graph"""
92+
# convert grad value to bool if dtype is bool
93+
grad_value = 123.0 if dtypes[0] != 'bool' else True
94+
if dtypes[0] == 'bfloat16':
95+
grad_value = paddle.to_tensor(grad_value, dtype=dtypes[0]).numpy()
9296
paddle.enable_static()
9397

9498
for device, place in PLACES:
@@ -130,8 +134,6 @@ def _test_static_api(
130134
exe = paddle.static.Executor(place)
131135
res, *res_grad = exe.run(feed=feed, fetch_list=fetch_list)
132136

133-
# convert grad value to bool if dtype is bool
134-
grad_value = 123.0 if dtypes[0] != 'bool' else True
135137
np.testing.assert_allclose(
136138
res_grad[0], np.ones(x[0].shape) * grad_value
137139
)

0 commit comments

Comments
 (0)