Skip to content

Commit 217cc54

Browse files
authored
【PIR API adaptor No.258、295、299、307】 Migrate glu/rank/sgn/take into pir (#59535)
1 parent 866819a commit 217cc54

File tree

6 files changed

+121
-43
lines changed

6 files changed

+121
-43
lines changed

python/paddle/tensor/attribute.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def rank(input):
4949
>>> print(rank.numpy())
5050
3
5151
"""
52-
check_type(input, 'input', (Variable), 'input')
52+
check_type(input, 'input', (Variable, paddle.pir.Value), 'input')
5353
ndims = len(input.shape)
5454
out = assign(np.array(ndims, 'int32'))
5555

@@ -163,12 +163,16 @@ def is_complex(x):
163163
>>> print(paddle.is_complex(x))
164164
False
165165
"""
166-
if not isinstance(x, (paddle.Tensor, paddle.static.Variable)):
166+
if not isinstance(
167+
x, (paddle.Tensor, paddle.static.Variable, paddle.pir.Value)
168+
):
167169
raise TypeError(f"Expected Tensor, but received type of x: {type(x)}")
168170
dtype = x.dtype
169171
is_complex_dtype = (
170172
dtype == core.VarDesc.VarType.COMPLEX64
171173
or dtype == core.VarDesc.VarType.COMPLEX128
174+
or dtype == core.DataType.COMPLEX64
175+
or dtype == core.DataType.COMPLEX128
172176
)
173177
return is_complex_dtype
174178

python/paddle/tensor/math.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6230,6 +6230,11 @@ def sgn(x, name=None):
62306230
paddle.float64,
62316231
paddle.complex64,
62326232
paddle.complex128,
6233+
DataType.FLOAT16,
6234+
DataType.FLOAT32,
6235+
DataType.FLOAT64,
6236+
DataType.COMPLEX64,
6237+
DataType.COMPLEX128,
62336238
]:
62346239
raise TypeError(
62356240
f"The data type of input must be one of ['float16', 'float32', 'float64', 'complex64', 'complex128'], but got {x.dtype}"
@@ -6317,12 +6322,17 @@ def take(x, index, mode='raise', name=None):
63176322
f"'mode' in 'take' should be 'raise', 'wrap', 'clip', but received {mode}."
63186323
)
63196324

6320-
if in_dynamic_mode():
6321-
if not isinstance(index, (paddle.Tensor, Variable)):
6325+
if in_dynamic_or_pir_mode():
6326+
if not isinstance(index, (paddle.Tensor, Variable, paddle.pir.Value)):
63226327
raise TypeError(
63236328
f"The type of 'index' must be Tensor, but got {type(index)}"
63246329
)
6325-
if index.dtype not in [paddle.int32, paddle.int64]:
6330+
if index.dtype not in [
6331+
paddle.int32,
6332+
paddle.int64,
6333+
DataType.INT32,
6334+
DataType.INT64,
6335+
]:
63266336
raise TypeError(
63276337
"The data type of 'index' must be one of ['int32', 'int64'], but got {}".format(
63286338
index.dtype

test/dygraph_to_static/test_function_spec.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,7 @@ def test_args_to_input_spec(self):
106106
if in_pir_mode():
107107
self.assertEqual(input_with_spec[1].shape, [4, 10]) # b.shape
108108
else:
109-
self.assertTupleEqual(
110-
tuple(input_with_spec[1].shape), (4, 10)
111-
) # b.shape
109+
self.assertTupleEqual(input_with_spec[1].shape, (4, 10)) # b.shape
112110

113111
self.assertEqual(input_with_spec[1].name, 'b_var') # b.name
114112

test/legacy_test/test_glu.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import paddle.base.dygraph as dg
2121
from paddle import base, nn
2222
from paddle.nn import functional as F
23+
from paddle.pir_utils import test_with_pir_api
2324

2425

2526
def sigmoid(x):
@@ -58,6 +59,7 @@ def glu_axis_size(self):
5859
x = paddle.static.data(name='x', shape=[1, 2, 3], dtype='float32')
5960
paddle.nn.functional.glu(x, axis=256)
6061

62+
@test_with_pir_api
6163
def test_errors(self):
6264
self.assertRaises(ValueError, self.glu_axis_size)
6365

@@ -92,6 +94,7 @@ def glu_axis_size(self):
9294
act = nn.GLU(256)
9395
act(x)
9496

97+
@test_with_pir_api
9598
def test_errors(self):
9699
self.assertRaises(ValueError, self.glu_axis_size)
97100
act = nn.GLU(256)

test/legacy_test/test_sgn.py

Lines changed: 84 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
import unittest
1616

1717
import numpy as np
18+
from utils import static_guard
1819

1920
import paddle
21+
from paddle.pir_utils import test_with_pir_api
2022

2123

2224
def np_sgn(x: np.ndarray):
@@ -31,7 +33,7 @@ def np_sgn(x: np.ndarray):
3133

3234

3335
class TestSgnError(unittest.TestCase):
34-
def test_errors(self):
36+
def test_errors_dynamic(self):
3537
# The input dtype of sgn must be float16, float32, float64,complex64,complex128.
3638
input2 = paddle.to_tensor(
3739
np.random.randint(-10, 10, size=[12, 20]).astype('int32')
@@ -43,33 +45,28 @@ def test_errors(self):
4345
self.assertRaises(TypeError, paddle.sgn, input2)
4446
self.assertRaises(TypeError, paddle.sgn, input3)
4547

48+
@test_with_pir_api
49+
def test_errors_static_and_pir(self):
50+
paddle.enable_static()
51+
main_program = paddle.static.Program()
52+
startup_program = paddle.static.Program()
4653

47-
class TestSignAPI(unittest.TestCase):
48-
def setUp(self) -> None:
49-
self.support_dtypes = [
50-
'float16',
51-
'float32',
52-
'float64',
53-
'complex64',
54-
'complex128',
55-
]
56-
if paddle.device.get_device() == 'cpu':
57-
self.support_dtypes = [
58-
'float32',
59-
'float64',
60-
'complex64',
61-
'complex128',
62-
]
63-
64-
def test_dtype(self):
65-
for dtype in self.support_dtypes:
66-
x = paddle.to_tensor(
67-
np.random.randint(-10, 10, size=[12, 20, 2]).astype(dtype)
54+
with paddle.static.program_guard(main_program, startup_program):
55+
# The input dtype of sgn must be float16, float32, float64,complex64,complex128.
56+
input2 = paddle.to_tensor(
57+
np.random.randint(-10, 10, size=[12, 20]).astype('int32')
58+
)
59+
input3 = paddle.to_tensor(
60+
np.random.randint(-10, 10, size=[12, 20]).astype('int64')
6861
)
6962

70-
paddle.sgn(x)
63+
self.assertRaises(TypeError, paddle.sgn, input2)
64+
self.assertRaises(TypeError, paddle.sgn, input3)
65+
paddle.disable_static()
66+
7167

72-
def test_complex(self):
68+
class TestSignAPI(unittest.TestCase):
69+
def test_complex_dynamic(self):
7370
for dtype in ['complex64', 'complex128']:
7471
np_x = np.array(
7572
[[3 + 4j, 7 - 24j, 0, 1 + 2j], [6 + 8j, 3, 0, -2]], dtype=dtype
@@ -80,15 +77,76 @@ def test_complex(self):
8077
z_expected = np_sgn(np_x)
8178
np.testing.assert_allclose(np_z, z_expected, rtol=1e-05)
8279

83-
def test_float(self):
84-
for dtype in self.support_dtypes:
80+
@test_with_pir_api
81+
def test_complex_static_and_pir(self):
82+
with static_guard():
83+
for dtype in ['complex64', 'complex128']:
84+
exe = paddle.static.Executor()
85+
86+
train_program = paddle.static.Program()
87+
startup_program = paddle.static.Program()
88+
with paddle.static.program_guard(
89+
train_program, startup_program
90+
):
91+
x = paddle.static.data(name='X', shape=[2, 4], dtype=dtype)
92+
z = paddle.sgn(x)
93+
94+
# Run the startup program once and only once.
95+
# Not need to optimize/compile the startup program.
96+
exe.run(startup_program)
97+
98+
# Run the main program directly without compile.
99+
x = np.array(
100+
[[3 + 4j, 7 - 24j, 0, 1 + 2j], [6 + 8j, 3, 0, -2]],
101+
dtype=dtype,
102+
)
103+
(z,) = exe.run(train_program, feed={"X": x}, fetch_list=[z])
104+
z_expected = np_sgn(x)
105+
np.testing.assert_allclose(z, z_expected, rtol=1e-05)
106+
107+
def test_float_dynamic(self):
108+
dtype_list = ['float32', 'float64']
109+
if paddle.is_compiled_with_cuda():
110+
dtype_list.append('float16')
111+
for dtype in dtype_list:
85112
np_x = np.random.randint(-10, 10, size=[12, 20, 2]).astype(dtype)
86113
x = paddle.to_tensor(np_x)
87114
z = paddle.sgn(x)
88115
np_z = z.numpy()
89116
z_expected = np_sgn(np_x)
90117
np.testing.assert_allclose(np_z, z_expected, rtol=1e-05)
91118

119+
@test_with_pir_api
120+
def test_float_static_and_pir(self):
121+
dtype_list = ['float32', 'float64']
122+
if paddle.is_compiled_with_cuda():
123+
dtype_list.append('float16')
124+
with static_guard():
125+
for dtype in dtype_list:
126+
exe = paddle.static.Executor()
127+
128+
train_program = paddle.static.Program()
129+
startup_program = paddle.static.Program()
130+
with paddle.static.program_guard(
131+
train_program, startup_program
132+
):
133+
np_x = np.random.randint(-10, 10, size=[12, 20, 2]).astype(
134+
dtype
135+
)
136+
x = paddle.static.data(
137+
name='X', shape=[12, 20, 2], dtype=dtype
138+
)
139+
z = paddle.sgn(x)
140+
141+
# Run the startup program once and only once.
142+
# Not need to optimize/compile the startup program.
143+
exe.run(startup_program)
144+
145+
# Run the main program directly without compile.
146+
(z,) = exe.run(train_program, feed={"X": np_x}, fetch_list=[z])
147+
z_expected = np_sgn(np_x)
148+
np.testing.assert_allclose(z, z_expected, rtol=1e-05)
149+
92150

93151
if __name__ == "__main__":
94152
unittest.main()

test/legacy_test/test_take.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
import paddle
2020
from paddle import base
21-
from paddle.base import Program, core, program_guard
21+
from paddle.base import core
22+
from paddle.pir_utils import test_with_pir_api
2223

2324

2425
class TestTakeAPI(unittest.TestCase):
@@ -49,11 +50,12 @@ def setUp(self):
4950
else base.CPUPlace()
5051
)
5152

53+
@test_with_pir_api
5254
def test_static_graph(self):
5355
paddle.enable_static()
54-
startup_program = Program()
55-
train_program = Program()
56-
with program_guard(startup_program, train_program):
56+
startup_program = paddle.static.Program()
57+
train_program = paddle.static.Program()
58+
with paddle.static.program_guard(startup_program, train_program):
5759
x = paddle.static.data(
5860
name='input', dtype=self.input_dtype, shape=self.input_shape
5961
)
@@ -62,9 +64,9 @@ def test_static_graph(self):
6264
)
6365
out = paddle.take(x, index, mode=self.mode)
6466

65-
exe = base.Executor(self.place)
67+
exe = paddle.static.Executor(self.place)
6668
st_result = exe.run(
67-
base.default_main_program(),
69+
paddle.static.default_main_program(),
6870
feed={'input': self.input_np, 'index': self.index_np},
6971
fetch_list=out,
7072
)
@@ -111,10 +113,11 @@ def set_dtype(self):
111113
class TestTakeTypeError(TestTakeAPI):
112114
"""Test take Type Error"""
113115

116+
@test_with_pir_api
114117
def test_static_type_error(self):
115118
"""Argument 'index' must be Tensor"""
116119
paddle.enable_static()
117-
with program_guard(Program()):
120+
with paddle.static.program_guard(paddle.static.Program()):
118121
x = paddle.static.data(
119122
name='input', dtype=self.input_dtype, shape=self.input_shape
120123
)
@@ -127,10 +130,11 @@ def test_dygraph_type_error(self):
127130
x = paddle.to_tensor(self.input_np)
128131
self.assertRaises(TypeError, paddle.take, x, self.index_np, self.mode)
129132

133+
@test_with_pir_api
130134
def test_static_dtype_error(self):
131135
"""Data type of argument 'index' must be in [paddle.int32, paddle.int64]"""
132136
paddle.enable_static()
133-
with program_guard(Program()):
137+
with paddle.static.program_guard(paddle.static.Program()):
134138
x = paddle.static.data(
135139
name='input', dtype='float64', shape=self.input_shape
136140
)
@@ -178,11 +182,12 @@ def setUp(self):
178182
else base.CPUPlace()
179183
)
180184

185+
@test_with_pir_api
181186
def test_static_index_error(self):
182187
"""When the index is out of range,
183188
an error is reported directly through `paddle.index_select`"""
184189
paddle.enable_static()
185-
with program_guard(Program()):
190+
with paddle.static.program_guard(paddle.static.Program()):
186191
x = paddle.static.data(
187192
name='input', dtype=self.input_dtype, shape=self.input_shape
188193
)

0 commit comments

Comments
 (0)