Skip to content

Commit 349550a

Browse files
[MLU] Fix test_accuracy_op, test_gather_op, test_logical_op & test_static_print (#1635)
1 parent ca322a2 commit 349550a

File tree

5 files changed

+142
-63
lines changed

5 files changed

+142
-63
lines changed

backends/mlu/kernels/gather_kernel.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,20 @@ void GatherKernel(const Context& dev_ctx,
2525
phi::DenseTensor* out) {
2626
dev_ctx.template Alloc<T>(out);
2727

28+
PADDLE_ENFORCE_EQ(
29+
axis.dtype() == DataType::INT32 || axis.dtype() == DataType::INT64,
30+
true,
31+
phi::errors::InvalidArgument(
32+
"The axis should be INT32 or INT64, but we get %s",
33+
DataTypeToString(axis.dtype())));
34+
35+
PADDLE_ENFORCE_EQ(
36+
index.dtype() == DataType::INT32 || index.dtype() == DataType::INT64,
37+
true,
38+
phi::errors::InvalidArgument(
39+
"The index should be INT32 or INT64, but we get %s",
40+
DataTypeToString(index.dtype())));
41+
2842
const auto index_dims = index.dims();
2943
if (index_dims.size() == 2) {
3044
PADDLE_ENFORCE_EQ(
@@ -116,11 +130,19 @@ PD_REGISTER_PLUGIN_KERNEL(gather,
116130
ALL_LAYOUT,
117131
custom_kernel::GatherKernel,
118132
float,
133+
int8_t,
134+
uint8_t,
135+
int16_t,
136+
int32_t,
119137
phi::dtype::float16) {}
120138

121139
PD_REGISTER_PLUGIN_KERNEL(gather_grad,
122140
mlu,
123141
ALL_LAYOUT,
124142
custom_kernel::GatherGradKernel,
125143
float,
144+
int8_t,
145+
uint8_t,
146+
int16_t,
147+
int32_t,
126148
phi::dtype::float16) {}

backends/mlu/tests/unittests/test_accuracy_op_mlu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def test_api(self):
8383
exe = paddle.static.Executor()
8484
(result,) = exe.run(
8585
feed={"predictions": self.input_predictions, "labels": self.input_labels},
86-
fetch_list=[self.result.name],
86+
fetch_list=[self.result],
8787
)
8888
self.assertEqual((result == self.expect_value).all(), True)
8989

backends/mlu/tests/unittests/test_gather_op_mlu.py

Lines changed: 109 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from tests.op_test import OpTest
1818

1919
import numpy as np
20-
import paddle.base as base
2120
import paddle
2221

2322
paddle.enable_static()
@@ -120,60 +119,119 @@ def test_zero_index(self):
120119

121120

122121
class TestGathertError(unittest.TestCase):
123-
def test_error1(self):
124-
with paddle.static.program_guard(
125-
paddle.static.Program(), paddle.static.Program()
126-
):
127-
paddle.set_device("mlu")
128-
129-
shape = [8, 9, 6]
130-
x = paddle.static.data(shape=shape, dtype="int8", name="x")
131-
axis = paddle.static.data(shape=[1], dtype="float32", name="axis")
132-
index = paddle.static.data(shape=shape, dtype="int32", name="index")
133-
index_float = paddle.static.data(
134-
shape=shape, dtype="float32", name="index_float"
135-
)
136-
137-
def test_x_type():
138-
paddle.gather(x, index)
139-
140-
self.assertRaises(TypeError, test_x_type)
141-
142-
def test_index_type():
143-
paddle.gather(x, index_float)
144-
145-
self.assertRaises(TypeError, test_index_type)
146-
147-
def test_axis_dtype():
148-
paddle.gather(x, index, axis=1.11)
149-
150-
self.assertRaises(TypeError, test_axis_dtype)
151-
152-
def test_axis_dtype1():
153-
paddle.gather(x, index, axis=axis)
154-
155-
self.assertRaises(TypeError, test_axis_dtype1)
122+
def setUp(self) -> None:
123+
self.place = paddle.CustomPlace("mlu", 0)
124+
paddle.set_device("mlu:0")
156125

157-
def test_error2(self):
158-
with base.program_guard(base.Program(), base.Program()):
126+
def test_error1(self):
127+
paddle.enable_static()
128+
if not paddle.framework.use_pir_api():
129+
with paddle.static.program_guard(
130+
paddle.static.Program(), paddle.static.Program()
131+
):
132+
133+
input_shape = [8, 9, 6]
134+
index_shape = [4]
135+
x_int8 = paddle.static.data(
136+
shape=input_shape, dtype="int8", name="x_int8"
137+
)
138+
x_float32 = paddle.static.data(
139+
shape=input_shape, dtype="float32", name="x_float32"
140+
)
141+
axis = paddle.static.data(shape=[1], dtype="float32", name="axis")
142+
index = paddle.static.data(
143+
shape=index_shape, dtype="int32", name="index"
144+
)
145+
index_float = paddle.static.data(
146+
shape=index_shape, dtype="float32", name="index_float"
147+
)
148+
149+
def test_x_type():
150+
paddle.gather(x_int8, index)
151+
152+
self.assertRaises(TypeError, test_x_type)
153+
154+
def test_index_type():
155+
paddle.gather(x_float32, index_float)
156+
157+
self.assertRaises(TypeError, test_index_type)
158+
159+
def test_axis_dtype():
160+
paddle.gather(x_float32, index, axis=1.11)
161+
162+
self.assertRaises(TypeError, test_axis_dtype)
163+
164+
def test_axis_dtype1():
165+
paddle.gather(x_float32, index, axis=axis)
166+
167+
self.assertRaises(TypeError, test_axis_dtype1)
168+
else:
159169
paddle.set_device("mlu")
160-
161-
shape = [8, 9, 6]
162-
x = paddle.static.data(shape=shape, dtype="int8", name="x")
163-
index = paddle.static.data(shape=shape, dtype="int32", name="mask")
164-
index_float = paddle.static.data(
165-
shape=shape, dtype="float32", name="index_float"
166-
)
167-
168-
def test_x_type():
169-
paddle.gather(x, index)
170-
171-
self.assertRaises(TypeError, test_x_type)
170+
input_shape = [8, 9, 6]
171+
index_shape = [4]
172172

173173
def test_index_type():
174-
paddle.gather(x, index_float)
175-
176-
self.assertRaises(TypeError, test_index_type)
174+
with paddle.static.program_guard(
175+
paddle.static.Program(), paddle.static.Program()
176+
):
177+
x = paddle.static.data(shape=input_shape, dtype="float32", name="x")
178+
index = paddle.static.data(
179+
shape=index_shape, dtype="float32", name="index_float"
180+
)
181+
out = paddle.gather(x, index)
182+
exe = paddle.static.Executor(place=self.place)
183+
exe.run(paddle.static.default_startup_program())
184+
self.assertRaises(
185+
ValueError,
186+
exe.run,
187+
paddle.static.default_main_program(),
188+
feed={
189+
"x": np.random.random(input_shape).astype("float32"),
190+
"index_float": np.random.random(index_shape).astype(
191+
"float32"
192+
),
193+
},
194+
)
195+
196+
def test_axis_scalar_dtype():
197+
with paddle.static.program_guard(
198+
paddle.static.Program(), paddle.static.Program()
199+
):
200+
x = paddle.static.data(shape=input_shape, dtype="float32", name="x")
201+
index = paddle.static.data(
202+
shape=index_shape, dtype="int32", name="index"
203+
)
204+
axis = paddle.static.data(shape=[1], dtype="int32", name="axis")
205+
self.assertRaises(TypeError, paddle.gather, x, index, axis=1.11)
206+
207+
def test_axis_tensor_dtype():
208+
with paddle.static.program_guard(
209+
paddle.static.Program(), paddle.static.Program()
210+
):
211+
x = paddle.static.data(shape=input_shape, dtype="float32", name="x")
212+
index = paddle.static.data(
213+
shape=index_shape, dtype="int32", name="index"
214+
)
215+
axis = paddle.static.data(shape=[1], dtype="float32", name="axis")
216+
y = paddle.gather(x, index, axis=axis)
217+
exe = paddle.static.Executor(place=self.place)
218+
exe.run(paddle.static.default_startup_program())
219+
self.assertRaises(
220+
ValueError,
221+
exe.run,
222+
paddle.static.default_main_program(),
223+
feed={
224+
"x": np.random.random(input_shape).astype("float32"),
225+
"index": np.random.randint(0, 8, index_shape).astype(
226+
"int32"
227+
),
228+
"axis": np.array([1.11]).astype("float32"),
229+
},
230+
)
231+
232+
test_index_type()
233+
test_axis_scalar_dtype()
234+
test_axis_tensor_dtype()
177235

178236

179237
if __name__ == "__main__":

backends/mlu/tests/unittests/test_logical_op_mlu.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -127,18 +127,17 @@ def test(unit_test, use_mlu=False, test_error=False):
127127

128128
def test_type_error(unit_test, use_mlu, type_str_map):
129129
def check_type(op_str, x, y, binary_op):
130-
op = getattr(paddle, op_str)
131-
error_type = ValueError
132-
if isinstance(x, np.ndarray):
133-
x = paddle.to_tensor(x)
134-
y = paddle.to_tensor(y)
135-
error_type = BaseException
136-
if binary_op:
137-
if not paddle.in_dynamic_mode():
130+
if not paddle.framework.in_dynamic_or_pir_mode():
131+
op = getattr(paddle, op_str)
132+
error_type = ValueError
133+
if isinstance(x, np.ndarray):
134+
x = paddle.to_tensor(x)
135+
y = paddle.to_tensor(y)
136+
error_type = BaseException
137+
if binary_op:
138138
error_type = TypeError
139139
unit_test.assertRaises(error_type, op, x=x, y=y, out=1)
140-
else:
141-
if not paddle.in_dynamic_mode():
140+
else:
142141
error_type = TypeError
143142
unit_test.assertRaises(error_type, op, x=x, out=1)
144143

backends/mlu/tests/unittests/test_static_print_mlu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_with_new_ir(self):
3939
z = x + y
4040
z = paddle.static.Print(z)
4141

42-
out = exe.run(main_program, {}, fetch_list=[z.name])
42+
out = exe.run(main_program, {}, fetch_list=[z])
4343

4444
gold_res = np.ones([2, 2], dtype="float32") * 2
4545

0 commit comments

Comments (0)