File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -89,6 +89,10 @@ def _test_static_api(
8989 names : list ,
9090 ):
9191 """Test `static`, convert `Tensor` to `numpy array` before feed into graph"""
92+ # convert grad value to bool if dtype is bool
93+ grad_value = 123.0 if dtypes [0 ] != 'bool' else True
94+ if dtypes [0 ] == 'bfloat16' :
95+ grad_value = paddle .to_tensor (grad_value , dtype = dtypes [0 ]).numpy ()
9296 paddle .enable_static ()
9397
9498 for device , place in PLACES :
@@ -130,8 +134,6 @@ def _test_static_api(
130134 exe = paddle .static .Executor (place )
131135 res , * res_grad = exe .run (feed = feed , fetch_list = fetch_list )
132136
133- # convert grad value to bool if dtype is bool
134- grad_value = 123.0 if dtypes [0 ] != 'bool' else True
135137 np .testing .assert_allclose (
136138 res_grad [0 ], np .ones (x [0 ].shape ) * grad_value
137139 )
You can’t perform that action at this time.
0 commit comments