Skip to content

Commit f09714f

Browse files
[NPU][MLU] Fix test_cumxxx_op on NPU (#1608)
1 parent a781b7f commit f09714f

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

backends/npu/tests/unittests/test_cumprod_op_npu.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def run_static(self, use_custom_device=False):
6767
exe.run(base.default_startup_program())
6868
out = exe.run(
6969
feed={"X": data_np},
70-
fetch_list=[y.name, y2.name, y3.name, y4.name, y5.name, y6.name],
70+
fetch_list=[y, y2, y3, y4, y5, y6],
7171
)
7272

7373
z = np.cumprod(data_np, axis=0)
@@ -87,10 +87,11 @@ def test_npu(self):
8787
self.run_static(use_custom_device=True)
8888

8989
def test_name(self):
90-
with base.program_guard(base.Program()):
91-
x = paddle.static.data("x", [3, 4])
92-
y = paddle.cumprod(x, dim=0, name="out")
93-
self.assertTrue("out" in y.name)
90+
with paddle.pir_utils.OldIrGuard():
91+
with base.program_guard(base.Program()):
92+
x = paddle.static.data("x", [3, 4])
93+
y = paddle.cumprod(x, dim=0, name="out")
94+
self.assertTrue("out" in y.name)
9495

9596

9697
if __name__ == "__main__":

backends/npu/tests/unittests/test_cumsum_op_npu.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def run_static(self, use_custom_device=False):
6868
exe.run(base.default_startup_program())
6969
out = exe.run(
7070
feed={"X": data_np},
71-
fetch_list=[y.name, y2.name, y3.name, y4.name, y5.name, y6.name],
71+
fetch_list=[y, y2, y3, y4, y5, y6],
7272
)
7373

7474
z = np.cumsum(data_np)
@@ -88,10 +88,11 @@ def test_npu(self):
8888
self.run_static(use_custom_device=True)
8989

9090
def test_name(self):
91-
with base.program_guard(base.Program()):
92-
x = paddle.static.data("x", [3, 4])
93-
y = paddle.cumsum(x, name="out")
94-
self.assertTrue("out" in y.name)
91+
with paddle.pir_utils.OldIrGuard():
92+
with base.program_guard(base.Program()):
93+
x = paddle.static.data("x", [3, 4])
94+
y = paddle.cumsum(x, name="out")
95+
self.assertTrue("out" in y.name)
9596

9697

9798
class TestNPUCumSumOp1(OpTest):

0 commit comments

Comments
 (0)