Commit fb06426

[NPU][MLU] Fix test_compare_op on NPU & MLU (#1606)
1 parent f09714f commit fb06426

2 files changed: 165 additions & 73 deletions

backends/mlu/tests/unittests/test_compare_op_mlu.py

Lines changed: 97 additions & 36 deletions
@@ -15,53 +15,103 @@
 from __future__ import print_function
 
 import unittest
-from tests.op_test import OpTest
 
 import numpy as np
-import paddle.base as base
 import paddle
+import paddle.base as base
 from paddle.base import Program, program_guard
+from tests.op_test import OpTest
+
+with paddle.pir_utils.OldIrGuard():
+    from paddle.base import program_guard as old_program_guard, Program as OldProgram
 
 
 def create_test_class(op_type, typename, callback):
     class Cls(OpTest):
         def setUp(self):
-            self.set_mlu()
-            self.place = paddle.CustomPlace("mlu", 0)
-            x = np.random.random(size=(10, 7)).astype(typename)
-            y = np.random.random(size=(10, 7)).astype(typename)
-            out = callback(x, y)
+            self.device = "mlu:0"
+            self.op_type = op_type
+            self.python_api = eval("paddle." + op_type)
+            self.set_device()
+
+        def init_input_output(self, shape1, shape2):
+            x = np.random.random(size=shape1).astype(typename)
+            y = np.random.random(size=shape2).astype(typename)
             self.inputs = {"X": x, "Y": y}
+            out = callback(x, y)
             self.outputs = {"Out": out}
-            self.op_type = op_type
 
-        def set_mlu(self):
+        def set_device(self):
             self.__class__.use_custom_device = True
+            self.place = paddle.CustomPlace(
+                self.device.split(":")[0], int(self.device.split(":")[1])
+            )
+            paddle.set_device(self.device)
 
         def test_output(self):
+            self.init_input_output((10, 7), (10, 7))
+            self.check_output_with_place(place=self.place)
+
+        def test_output1(self):
+            self.init_input_output((8192,), (8192,))
+            self.check_output_with_place(place=self.place)
+
+        def test_output2(self):
+            self.init_input_output((2, 4096, 1), (2, 4096, 1))
             self.check_output_with_place(place=self.place)
 
         def test_errors(self):
             paddle.enable_static()
-            with program_guard(Program(), Program()):
-                a = paddle.static.data(name="a", shape=[-1, 2], dtype="float32")
-                b = paddle.static.data(name="b", shape=[-1, 2], dtype="float32")
-                c = paddle.static.data(name="c", shape=[-1, 2], dtype="int16")
-                d = base.create_lod_tensor(np.array([[-1]]), [[1]], self.place)
-
-                op = eval("paddle.%s" % self.op_type)
-                self.assertRaises(TypeError, op, x=a, y=b, axis=True)
-                self.assertRaises(TypeError, op, x=a, y=b, force_cpu=1)
-                self.assertRaises(TypeError, op, x=a, y=b, cond=1)
-                self.assertRaises(TypeError, op, x=a, y=c)
-                self.assertRaises(TypeError, op, x=c, y=a)
-                self.assertRaises(TypeError, op, x=a, y=d)
-                self.assertRaises(TypeError, op, x=d, y=a)
-                self.assertRaises(TypeError, op, x=c, y=d)
+            op = eval("paddle." + self.op_type)
+
+            # TODO(LittleHeroZZZX): Remove after CI switched to PIR
+            if not paddle.get_flags(["FLAGS_enable_pir_api"])["FLAGS_enable_pir_api"]:
+                cases = [
+                    {"x": "a", "y": "b", "args": {"axis": True}},
+                    {"x": "a", "y": "b", "args": {"force_cpu": 1}},
+                    {"x": "a", "y": "b", "args": {"cond": 1}},
+                    {"x": "a", "y": "c", "args": {}},
+                    {"x": "c", "y": "a", "args": {}},
+                    {"x": "a", "y": "d", "args": {}},
+                    {"x": "d", "y": "a", "args": {}},
+                    {"x": "c", "y": "d", "args": {}},
+                ]
+
+                def build_op(case):
+                    with old_program_guard(OldProgram(), OldProgram()):
+                        a = paddle.static.data(name="a", shape=[-1, 2], dtype="float32")
+                        b = paddle.static.data(name="b", shape=[-1, 2], dtype="float32")
+                        c = paddle.static.data(name="c", shape=[-1, 2], dtype="int16")
+                        d = base.create_lod_tensor(np.array([[-1]]), [[1]], self.place)
+                        inputs = {"a": a, "b": b, "c": c, "d": d}
+
+                        op(x=inputs[case["x"]], y=inputs[case["y"]], **case["args"])
+                        exe = paddle.static.Executor(self.place)
+
+                        exe.run(paddle.static.default_startup_program())
+                        exe.run(paddle.static.default_main_program())
+
+                for case in cases:
+                    self.assertRaises(TypeError, build_op, case)
+            else:
+                with program_guard(Program(), Program()):
+                    a = paddle.static.data(name="a", shape=[-1, 2], dtype="float32")
+                    b = paddle.static.data(name="b", shape=[-1, 2], dtype="float32")
+                    c = paddle.static.data(name="c", shape=[-1, 2], dtype="int16")
+                    d = base.create_lod_tensor(np.array([[-1]]), [[1]], self.place)
+
+                    self.assertRaises(TypeError, op, x=a, y=b, axis=True)
+                    self.assertRaises(TypeError, op, x=a, y=b, force_cpu=1)
+                    self.assertRaises(TypeError, op, x=a, y=b, cond=1)
+                    self.assertRaises(TypeError, op, x=a, y=c)
+                    self.assertRaises(TypeError, op, x=c, y=a)
+                    self.assertRaises(TypeError, op, x=a, y=d)
+                    self.assertRaises(TypeError, op, x=d, y=a)
+                    self.assertRaises(TypeError, op, x=c, y=d)
+            paddle.disable_static()
 
         def test_dynamic_api(self):
             paddle.disable_static()
-            paddle.set_device("mlu")
             x = np.random.random(size=(10, 7)).astype(typename)
             y = np.random.random(size=(10, 7)).astype(typename)
             real_result = callback(x, y)
@@ -71,9 +121,22 @@ def test_dynamic_api(self):
             out = op(x, y)
             self.assertEqual((out.numpy() == real_result).all(), True)
 
+        def test_dynamic_api_different_type(self):
+            if op_type != "equal":
+                return
+            paddle.disable_static()
+            y = np.random.random(size=(10, 7)).astype("int32")
+            x = np.random.random(size=(10, 7)).astype(typename)
+            real_result = callback(x, y)
+            x = paddle.to_tensor(x, dtype=typename).cast("float32")
+            y = paddle.to_tensor(y, dtype="float32")
+            op = eval("paddle.%s" % (self.op_type))
+            out = op(x, y)
+
+            self.assertEqual((out.numpy() == real_result).all(), True)
+
         def test_broadcast_api_1(self):
             paddle.enable_static()
-            paddle.set_device("mlu")
             with program_guard(Program(), Program()):
                 x = paddle.static.data(name="x", shape=[1, 2, 1, 3], dtype=typename)
                 y = paddle.static.data(name="y", shape=[1, 2, 3], dtype=typename)
@@ -88,7 +151,6 @@ def test_broadcast_api_1(self):
 
         def test_broadcast_api_2(self):
             paddle.enable_static()
-            paddle.set_device("mlu")
             with program_guard(Program(), Program()):
                 x = paddle.static.data(name="x", shape=[1, 2, 3], dtype=typename)
                 y = paddle.static.data(name="y", shape=[1, 2, 1, 3], dtype=typename)
@@ -103,7 +165,6 @@ def test_broadcast_api_2(self):
 
         def test_broadcast_api_3(self):
             paddle.enable_static()
-            paddle.set_device("mlu")
             with program_guard(Program(), Program()):
                 x = paddle.static.data(name="x", shape=[5], dtype=typename)
                 y = paddle.static.data(name="y", shape=[3, 1], dtype=typename)
@@ -118,20 +179,20 @@ def test_broadcast_api_3(self):
 
         def test_attr_name(self):
             paddle.enable_static()
-            paddle.set_device("mlu")
-            with program_guard(Program(), Program()):
-                x = paddle.static.data(name="x", shape=[-1, 4], dtype=typename)
-                y = paddle.static.data(name="y", shape=[-1, 4], dtype=typename)
-                op = eval("paddle.%s" % (self.op_type))
-                out = op(x=x, y=y, name="name_%s" % (self.op_type))
-                self.assertEqual("name_%s" % (self.op_type) in out.name, True)
+            with paddle.pir_utils.OldIrGuard():
+                with old_program_guard(OldProgram(), OldProgram()):
+                    x = paddle.static.data(name="x", shape=[-1, 4], dtype=typename)
+                    y = paddle.static.data(name="y", shape=[-1, 4], dtype=typename)
+                    op = eval("paddle.%s" % (self.op_type))
+                    out = op(x=x, y=y, name="name_%s" % (self.op_type))
+                    self.assertEqual("name_%s" % (self.op_type) in out.name, True)
 
     cls_name = "{0}_{1}".format(op_type, typename)
     Cls.__name__ = cls_name
     globals()[cls_name] = Cls
 
 
-for _type_name in {"float16", "float32", "int32", "bool", "int64"}:
+for _type_name in {"float16", "float32", "int32", "int64", "bool"}:
     create_test_class("equal", _type_name, lambda _a, _b: _a == _b)
     create_test_class("not_equal", _type_name, lambda _a, _b: _a != _b)
     create_test_class("less_than", _type_name, lambda _a, _b: _a < _b)

backends/npu/tests/unittests/test_compare_op_npu.py

Lines changed: 68 additions & 37 deletions
@@ -23,15 +23,20 @@
 from tests.op_test import OpTest, convert_float_to_uint16
 from npu_utils import get_cann_version
 
+with paddle.pir_utils.OldIrGuard():
+    from paddle.base import program_guard as old_program_guard, Program as OldProgram
+
+
 CANN_VERSION_CODE = get_cann_version()
 
 
 def create_test_class(op_type, typename, callback):
     class Cls(OpTest):
         def setUp(self):
-            self.set_npu()
-            self.place = paddle.CustomPlace("npu", 0)
+            self.device = "npu:0"
             self.op_type = op_type
+            self.python_api = eval("paddle." + op_type)
+            self.set_device()
 
         def init_input_output(self, shape1, shape2):
             if typename == "bfloat16":
@@ -48,8 +53,16 @@ def init_input_output(self, shape1, shape2):
             out = callback(x, y)
             self.outputs = {"Out": out}
 
-        def set_npu(self):
+        def set_device(self):
             self.__class__.use_custom_device = True
+            self.place = paddle.CustomPlace(
+                self.device.split(":")[0], int(self.device.split(":")[1])
+            )
+            paddle.set_device(self.device)
+            paddle.enable_static()
+            exe = paddle.static.Executor(self.place)
+            exe.run(paddle.static.default_startup_program())
+            paddle.disable_static()
 
         def test_output(self):
             self.init_input_output((10, 7), (10, 7))
@@ -65,38 +78,56 @@ def test_output2(self):
 
         def test_errors(self):
             paddle.enable_static()
-            with program_guard(Program(), Program()):
-                a = paddle.static.data(name="a", shape=[-1, 2], dtype="float32")
-                b = paddle.static.data(name="b", shape=[-1, 2], dtype="float32")
-                c = paddle.static.data(name="c", shape=[-1, 2], dtype="int16")
-                d = base.create_lod_tensor(np.array([[-1]]), [[1]], self.place)
-
-                op = eval("paddle.%s" % self.op_type)
-                self.assertRaises(TypeError, op, x=a, y=b, axis=True)
-                self.assertRaises(TypeError, op, x=a, y=b, force_cpu=1)
-                self.assertRaises(TypeError, op, x=a, y=b, cond=1)
-
-                try:
-                    result = op(x=a, y=c)
-                except TypeError:
-                    self.fail(
-                        "TypeError should not raised for float32 and int16 inputs"
-                    )
+            op = eval("paddle." + self.op_type)
 
-                try:
-                    result = op(x=c, y=a)
-                except TypeError:
-                    self.fail(
-                        "TypeError should not raised for int16 and float32 inputs"
-                    )
+            # TODO(LittleHeroZZZX): Remove after CI switched to PIR
+            if not paddle.get_flags(["FLAGS_enable_pir_api"])["FLAGS_enable_pir_api"]:
+                cases = [
+                    {"x": "a", "y": "b", "args": {"axis": True}},
+                    {"x": "a", "y": "b", "args": {"force_cpu": 1}},
+                    {"x": "a", "y": "b", "args": {"cond": 1}},
+                    {"x": "a", "y": "c", "args": {}},
+                    {"x": "c", "y": "a", "args": {}},
+                    {"x": "a", "y": "d", "args": {}},
+                    {"x": "d", "y": "a", "args": {}},
+                    {"x": "c", "y": "d", "args": {}},
+                ]
+
+                def build_op(case):
+                    with old_program_guard(OldProgram(), OldProgram()):
+                        a = paddle.static.data(name="a", shape=[-1, 2], dtype="float32")
+                        b = paddle.static.data(name="b", shape=[-1, 2], dtype="float32")
+                        c = paddle.static.data(name="c", shape=[-1, 2], dtype="int16")
+                        d = base.create_lod_tensor(np.array([[-1]]), [[1]], self.place)
+                        inputs = {"a": a, "b": b, "c": c, "d": d}
 
-                self.assertRaises(TypeError, op, x=a, y=d)
-                self.assertRaises(TypeError, op, x=d, y=a)
-                self.assertRaises(TypeError, op, x=c, y=d)
+                        op(x=inputs[case["x"]], y=inputs[case["y"]], **case["args"])
+                        exe = paddle.static.Executor(self.place)
+
+                        exe.run(paddle.static.default_startup_program())
+                        exe.run(paddle.static.default_main_program())
+
+                for case in cases:
+                    self.assertRaises(TypeError, build_op, case)
+            else:
+                with program_guard(Program(), Program()):
+                    a = paddle.static.data(name="a", shape=[-1, 2], dtype="float32")
+                    b = paddle.static.data(name="b", shape=[-1, 2], dtype="float32")
+                    c = paddle.static.data(name="c", shape=[-1, 2], dtype="int16")
+                    d = base.create_lod_tensor(np.array([[-1]]), [[1]], self.place)
+
+                    self.assertRaises(TypeError, op, x=a, y=b, axis=True)
+                    self.assertRaises(TypeError, op, x=a, y=b, force_cpu=1)
+                    self.assertRaises(TypeError, op, x=a, y=b, cond=1)
+                    self.assertRaises(TypeError, op, x=a, y=c)
+                    self.assertRaises(TypeError, op, x=c, y=a)
+                    self.assertRaises(TypeError, op, x=a, y=d)
+                    self.assertRaises(TypeError, op, x=d, y=a)
+                    self.assertRaises(TypeError, op, x=c, y=d)
+            paddle.disable_static()
 
         def test_dynamic_api(self):
             paddle.disable_static()
-            paddle.set_device("npu:0")
             if typename == "bfloat16":
                 x = np.random.random(size=(10, 7)).astype(np.float32)
                 y = np.random.random(size=(10, 7)).astype(np.float32)
@@ -114,7 +145,6 @@ def test_dynamic_api_different_type(self):
             if op_type != "equal":
                 return
             paddle.disable_static()
-            paddle.set_device("npu:0")
             y = np.random.random(size=(10, 7)).astype("int32")
             if typename == "bfloat16":
                 x = np.random.random(size=(10, 7)).astype(np.float32)
@@ -211,12 +241,13 @@ def test_broadcast_api_3(self):
 
         def test_attr_name(self):
             paddle.enable_static()
-            with program_guard(Program(), Program()):
-                x = paddle.static.data(name="x", shape=[-1, 4], dtype=typename)
-                y = paddle.static.data(name="y", shape=[-1, 4], dtype=typename)
-                op = eval("paddle.%s" % (self.op_type))
-                out = op(x=x, y=y, name="name_%s" % (self.op_type))
-                self.assertEqual("name_%s" % (self.op_type) in out.name, True)
+            with paddle.pir_utils.OldIrGuard():
+                with old_program_guard(OldProgram(), OldProgram()):
+                    x = paddle.static.data(name="x", shape=[-1, 4], dtype=typename)
+                    y = paddle.static.data(name="y", shape=[-1, 4], dtype=typename)
+                    op = eval("paddle.%s" % (self.op_type))
+                    out = op(x=x, y=y, name="name_%s" % (self.op_type))
+                    self.assertEqual("name_%s" % (self.op_type) in out.name, True)
 
     cls_name = "{0}_{1}".format(op_type, typename)
     Cls.__name__ = cls_name
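Both test_errors rewrites hinge on the same gate: when FLAGS_enable_pir_api is off, programs must be built with the legacy-IR Program/program_guard (bound once under OldIrGuard at import time), and each invalid case is constructed and executed in its own program so the expected TypeError surfaces inside assertRaises. A condensed sketch of that gating, assuming a PaddlePaddle build that provides paddle.pir_utils and the FLAGS_enable_pir_api flag:

import paddle

with paddle.pir_utils.OldIrGuard():
    # Bind the pre-PIR APIs once, as both test files do at import time.
    from paddle.base import program_guard as old_program_guard, Program as OldProgram

paddle.enable_static()
if not paddle.get_flags(["FLAGS_enable_pir_api"])["FLAGS_enable_pir_api"]:
    # Legacy IR: build the graph inside an old-IR program pair.
    with paddle.pir_utils.OldIrGuard():
        with old_program_guard(OldProgram(), OldProgram()):
            x = paddle.static.data(name="x", shape=[-1, 4], dtype="float32")
else:
    # PIR: the default static-graph API already targets the new IR.
    from paddle.base import Program, program_guard
    with program_guard(Program(), Program()):
        x = paddle.static.data(name="x", shape=[-1, 4], dtype="float32")
paddle.disable_static()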
