Skip to content

Commit ca322a2

Browse files
[NPU][MLU] Fix test_elementwise_sub/add/mod (#1627)
1 parent e760c4e commit ca322a2

File tree

4 files changed

+67
-40
lines changed

4 files changed

+67
-40
lines changed

backends/mlu/tests/unittests/test_elementwise_add_op_mlu.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -391,23 +391,29 @@ def test_errors(self):
391391
)
392392
self.assertRaises(TypeError, paddle.add, x1, y1)
393393

394-
# the input dtype of elementwise_add must be float16 or float32
395-
x2 = paddle.static.data(name="x2", shape=[-1, 3, 4, 5, 6], dtype="uint8")
396-
y2 = paddle.static.data(name="y2", shape=[-1, 3, 4, 5, 6], dtype="uint8")
397-
self.assertRaises(TypeError, paddle.add, x2, y2)
394+
if not paddle.framework.in_pir_mode():
395+
# the input dtype of elementwise_add must be float16 or float32
396+
x2 = paddle.static.data(
397+
name="x2", shape=[-1, 3, 4, 5, 6], dtype="uint8"
398+
)
399+
y2 = paddle.static.data(
400+
name="y2", shape=[-1, 3, 4, 5, 6], dtype="uint8"
401+
)
402+
self.assertRaises(TypeError, paddle.add, x2, y2)
398403

399404

400405
class TestAddApi(unittest.TestCase):
401406
def _executed_api(self, x, y, name=None):
402407
return paddle.add(x, y, name)
403408

404409
def test_name(self):
405-
with base.program_guard(base.Program()):
406-
x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
407-
y = paddle.static.data(name="y", shape=[2, 3], dtype="float32")
410+
with paddle.pir_utils.OldIrGuard():
411+
with base.program_guard(base.Program()):
412+
x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
413+
y = paddle.static.data(name="y", shape=[2, 3], dtype="float32")
408414

409-
y_1 = self._executed_api(x, y, name="add_res")
410-
self.assertEqual(("add_res" in y_1.name), True)
415+
y_1 = self._executed_api(x, y, name="add_res")
416+
self.assertEqual(("add_res" in y_1.name), True)
411417

412418
def test_declarative(self):
413419
with base.program_guard(base.Program()):
@@ -424,7 +430,7 @@ def gen_data():
424430

425431
place = paddle.CustomPlace("mlu", 0)
426432
exe = base.Executor(place)
427-
z_value = exe.run(feed=gen_data(), fetch_list=[z.name])
433+
z_value = exe.run(feed=gen_data(), fetch_list=[z])
428434
z_expected = np.array([3.0, 8.0, 6.0])
429435
self.assertEqual((z_value == z_expected).all(), True)
430436

@@ -509,15 +515,27 @@ def test_static_add(self):
509515
a = 1.5
510516
b = paddle.full([4, 5, 6], True, dtype="bool")
511517
c = a + b
512-
self.assertTrue(c.dtype == core.VarDesc.VarType.FP32)
518+
519+
expected_type = (
520+
core.DataType.FLOAT32
521+
if paddle.framework.use_pir_api()
522+
else core.VarDesc.VarType.FP32
523+
)
524+
self.assertTrue(c.dtype == expected_type)
513525
paddle.enable_static()
514526

515527
def test_dygraph_add(self):
516528
paddle.disable_static()
517529
a = 1.5
518530
b = paddle.full([4, 5, 6], True, dtype="bool")
519531
c = a + b
520-
self.assertTrue(c.dtype == core.VarDesc.VarType.FP32)
532+
533+
expected_type = (
534+
core.DataType.FLOAT32
535+
if paddle.framework.use_pir_api()
536+
else core.VarDesc.VarType.FP32
537+
)
538+
self.assertTrue(c.dtype == expected_type)
521539

522540

523541
if __name__ == "__main__":

backends/npu/tests/unittests/test_elementwise_add_op_npu.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323

2424
paddle.enable_static()
2525

26+
# Initialize NPU device
27+
exe = paddle.static.Executor(paddle.CustomPlace("npu", 0))
28+
exe.run(paddle.static.default_startup_program())
29+
2630

2731
class TestElementwiseAddOp(OpTest):
2832
def setUp(self):
@@ -184,12 +188,13 @@ def init_input_output(self):
184188

185189
class TestAddAPI(unittest.TestCase):
186190
def test_name(self):
187-
with paddle.static.program_guard(paddle.static.Program()):
188-
x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
189-
y = paddle.static.data(name="y", shape=[2, 3], dtype="float32")
191+
with paddle.pir_utils.OldIrGuard():
192+
with paddle.static.program_guard(paddle.static.Program()):
193+
x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
194+
y = paddle.static.data(name="y", shape=[2, 3], dtype="float32")
190195

191-
y_1 = paddle.add(x, y, name="add_res")
192-
self.assertEqual(("add_res" in y_1.name), True)
196+
y_1 = paddle.add(x, y, name="add_res")
197+
self.assertEqual(("add_res" in y_1.name), True)
193198

194199
def test_static(self):
195200
with paddle.static.program_guard(paddle.static.Program()):
@@ -240,11 +245,11 @@ def test_errors(self):
240245
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], paddle.CustomPlace("npu", 0)
241246
)
242247
self.assertRaises(TypeError, paddle.add, x1, y1)
243-
244-
# the input dtype must be float16 or float32 or float64 or int32 or int64
245-
x2 = paddle.static.data(name="x2", shape=[3, 4, 5, 6], dtype="uint8")
246-
y2 = paddle.static.data(name="y2", shape=[3, 4, 5, 6], dtype="uint8")
247-
self.assertRaises(TypeError, paddle.add, x2, y2)
248+
if not paddle.framework.in_pir_mode():
249+
# the input dtype must be float16 or float32 or float64 or int32 or int64
250+
x2 = paddle.static.data(name="x2", shape=[3, 4, 5, 6], dtype="uint8")
251+
y2 = paddle.static.data(name="y2", shape=[3, 4, 5, 6], dtype="uint8")
252+
self.assertRaises(TypeError, paddle.add, x2, y2)
248253

249254

250255
class TestElementwiseAddOp_Vector(TestElementwiseAddOp):
@@ -507,12 +512,13 @@ def _executed_api(self, x, y, name=None):
507512
return paddle.add(x, y, name)
508513

509514
def test_name(self):
510-
with base.program_guard(base.Program()):
511-
x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
512-
y = paddle.static.data(name="y", shape=[2, 3], dtype="float32")
515+
with paddle.pir_utils.OldIrGuard():
516+
with base.program_guard(base.Program()):
517+
x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
518+
y = paddle.static.data(name="y", shape=[2, 3], dtype="float32")
513519

514-
y_1 = self._executed_api(x, y, name="add_res")
515-
self.assertEqual(("add_res" in y_1.name), True)
520+
y_1 = self._executed_api(x, y, name="add_res")
521+
self.assertEqual(("add_res" in y_1.name), True)
516522

517523
def test_declarative(self):
518524
with base.program_guard(base.Program()):
@@ -529,7 +535,7 @@ def gen_data():
529535

530536
place = paddle.CustomPlace("npu", 0)
531537
exe = base.Executor(place)
532-
z_value = exe.run(feed=gen_data(), fetch_list=[z.name])
538+
z_value = exe.run(feed=gen_data(), fetch_list=[z])
533539
z_expected = np.array([3.0, 8.0, 6.0])
534540
self.assertEqual((z_value == z_expected).all(), True)
535541

@@ -639,7 +645,7 @@ def test_api_static(self):
639645
exe.run(startup_program)
640646
result = exe.run(
641647
main_program,
642-
feed={x_data.name: self.x, y_data.name: self.y},
648+
feed={"data_x": self.x, "data_y": self.y},
643649
fetch_list=[out_expect, out_actual],
644650
)
645651

backends/npu/tests/unittests/test_elementwise_mod_op_npu.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,13 @@ def init_axis(self):
141141

142142
class TestRemainderOp(unittest.TestCase):
143143
def test_name(self):
144-
paddle.set_device("npu:0")
145-
with base.program_guard(base.Program()):
146-
x = paddle.static.data(name="x", shape=[2, 3], dtype="int64")
147-
y = paddle.static.data(name="y", shape=[2, 3], dtype="int64")
148-
y_1 = paddle.remainder(x, y, name="div_res")
149-
self.assertEqual(("div_res" in y_1.name), True)
144+
with paddle.pir_utils.OldIrGuard():
145+
paddle.set_device("npu:0")
146+
with base.program_guard(base.Program()):
147+
x = paddle.static.data(name="x", shape=[2, 3], dtype="int64")
148+
y = paddle.static.data(name="y", shape=[2, 3], dtype="int64")
149+
y_1 = paddle.remainder(x, y, name="div_res")
150+
self.assertEqual(("div_res" in y_1.name), True)
150151

151152
def test_dygraph(self):
152153
paddle.set_device("npu:0")
@@ -207,6 +208,7 @@ def test_dygraph_binary(self):
207208
out_cls = np.remainder(nx, ny)
208209
np.testing.assert_array_equal(out_cls, out.numpy())
209210
self.assertEqual(out.shape, [2, 3, 4])
211+
out.retain_grads()
210212
out.backward()
211213
if x.grad is not None:
212214
self.assertEqual(x.grad.shape, [2, 3, 4])

backends/npu/tests/unittests/test_elementwise_sub_op_npu.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,13 @@ def init_dtype(self):
134134

135135
class TestSubtractAPI(unittest.TestCase):
136136
def test_name(self):
137-
with paddle.static.program_guard(paddle.static.Program()):
138-
x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
139-
y = paddle.static.data(name="y", shape=[2, 3], dtype="float32")
137+
with paddle.pir_utils.OldIrGuard():
138+
with paddle.static.program_guard(paddle.static.Program()):
139+
x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
140+
y = paddle.static.data(name="y", shape=[2, 3], dtype="float32")
140141

141-
y_1 = paddle.subtract(x, y, name="add_res")
142-
self.assertEqual(("add_res" in y_1.name), True)
142+
y_1 = paddle.subtract(x, y, name="add_res")
143+
self.assertEqual(("add_res" in y_1.name), True)
143144

144145
def test_static(self):
145146
with paddle.static.program_guard(paddle.static.Program()):

0 commit comments

Comments
 (0)