Skip to content

Commit 7037807

Browse files
committed
some create api support more usage
1 parent 239dff3 commit 7037807

File tree

3 files changed

+125
-9
lines changed

3 files changed

+125
-9
lines changed

python/paddle/tensor/creation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import paddle
2626
from paddle import _C_ops
27+
from paddle.utils.decorator_utils import SizeArgsDecorator
2728
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
2829

2930
from ..base.data_feeder import (
@@ -1234,6 +1235,7 @@ def fill_constant(
12341235
return out
12351236

12361237

1238+
@SizeArgsDecorator()
12371239
def ones(
12381240
shape: ShapeLike, dtype: DTypeLike | None = None, name: str | None = None
12391241
) -> paddle.Tensor:

python/paddle/utils/decorator_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,35 @@ def process(
105105
f"Cannot specify both '{original}' and its alias '{alias}'"
106106
)
107107
return args, processed_kwargs
108+
109+
110+
# Size可变参数装饰器
111+
class SizeArgsDecorator(DecoratorBase[_P, _R]):
112+
"""
113+
Usage Example:
114+
115+
paddle.ones(1, dtype=paddle.float32)
116+
paddle.ones(1, 2, 3, dtype=paddle.float32)
117+
paddle.ones([1, 2, 3], dtype=paddle.float32)
118+
paddle.ones(size=[1, 2, 3], dtype=paddle.float32)
119+
120+
paddle.ones([1, 2, 3], paddle.float32)
121+
paddle.ones(shape=[1, 2, 3], dtype=paddle.float32)
122+
"""
123+
124+
def process(
125+
self, args: tuple[Any, ...], kwargs: dict[str, Any]
126+
) -> tuple[tuple[Any, ...], dict[str, Any]]:
127+
if 'size' in kwargs:
128+
kwargs['shape'] = kwargs.pop('size')
129+
elif len(args) >= 1:
130+
is_all_int = True
131+
for ele in args:
132+
if not isinstance(ele, int):
133+
is_all_int = False
134+
break
135+
if is_all_int:
136+
kwargs['shape'] = list(args)
137+
args = ()
138+
139+
return args, kwargs

test/legacy_test/test_ones_op.py

Lines changed: 91 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,38 +20,120 @@
2020

2121

2222
class ApiOnesTest(unittest.TestCase):
23-
def test_paddle_ones(self):
23+
def test_static_ones(self):
24+
paddle.enable_static()
25+
with paddle.static.program_guard(paddle.static.Program()):
26+
ones = paddle.ones(10, dtype=paddle.float32)
27+
place = paddle.CPUPlace()
28+
exe = paddle.static.Executor(place)
29+
(result,) = exe.run(fetch_list=[ones])
30+
expect = np.ones([10], dtype="float32")
31+
np.testing.assert_equal(result, expect)
32+
33+
with paddle.static.program_guard(paddle.static.Program()):
34+
ones = paddle.ones(10, 2, 3, dtype=paddle.float32)
35+
place = paddle.CPUPlace()
36+
exe = paddle.static.Executor(place)
37+
(result,) = exe.run(fetch_list=[ones])
38+
expect = np.ones([10, 2, 3], dtype="float32")
39+
np.testing.assert_equal(result, expect)
40+
41+
with paddle.static.program_guard(paddle.static.Program()):
42+
ones = paddle.ones([10, 2, 3], dtype=paddle.float32)
43+
place = paddle.CPUPlace()
44+
exe = paddle.static.Executor(place)
45+
(result,) = exe.run(fetch_list=[ones])
46+
expect = np.ones([10, 2, 3], dtype="float32")
47+
np.testing.assert_equal(result, expect)
48+
49+
with paddle.static.program_guard(paddle.static.Program()):
50+
ones = paddle.ones(size=[10, 2, 3], dtype=paddle.float32)
51+
place = paddle.CPUPlace()
52+
exe = paddle.static.Executor(place)
53+
(result,) = exe.run(fetch_list=[ones])
54+
expect = np.ones([10, 2, 3], dtype="float32")
55+
np.testing.assert_equal(result, expect)
56+
57+
with paddle.static.program_guard(paddle.static.Program()):
58+
ones = paddle.ones([10, 2, 3], paddle.float32)
59+
place = paddle.CPUPlace()
60+
exe = paddle.static.Executor(place)
61+
(result,) = exe.run(fetch_list=[ones])
62+
expect = np.ones([10, 2, 3], dtype="float32")
63+
np.testing.assert_equal(result, expect)
64+
65+
with paddle.static.program_guard(paddle.static.Program()):
66+
ones = paddle.ones([10, 2, 3], paddle.float32)
67+
place = paddle.CPUPlace()
68+
exe = paddle.static.Executor(place)
69+
(result,) = exe.run(fetch_list=[ones])
70+
expect = np.ones([10, 2, 3], dtype="float32")
71+
np.testing.assert_equal(result, expect)
72+
73+
with paddle.static.program_guard(paddle.static.Program()):
74+
ones = paddle.ones(shape=[10, 2, 3], dtype=paddle.float32)
75+
place = paddle.CPUPlace()
76+
exe = paddle.static.Executor(place)
77+
(result,) = exe.run(fetch_list=[ones])
78+
expect = np.ones([10, 2, 3], dtype="float32")
79+
np.testing.assert_equal(result, expect)
80+
2481
with paddle.static.program_guard(paddle.static.Program()):
2582
ones = paddle.ones(shape=[10])
2683
place = paddle.CPUPlace()
2784
exe = paddle.static.Executor(place)
2885
(result,) = exe.run(fetch_list=[ones])
29-
expected_result = np.ones(10, dtype="float32")
30-
self.assertEqual((result == expected_result).all(), True)
86+
expect = np.ones(10, dtype="float32")
87+
np.testing.assert_equal(result, expect)
3188

3289
with paddle.static.program_guard(paddle.static.Program()):
3390
ones = paddle.ones(shape=[10], dtype="float64")
3491
place = paddle.CPUPlace()
3592
exe = paddle.static.Executor(place)
3693
(result,) = exe.run(fetch_list=[ones])
37-
expected_result = np.ones(10, dtype="float64")
38-
self.assertEqual((result == expected_result).all(), True)
94+
expect = np.ones(10, dtype="float64")
95+
np.testing.assert_equal(result, expect)
3996

4097
with paddle.static.program_guard(paddle.static.Program()):
4198
ones = paddle.ones(shape=[10], dtype="int64")
4299
place = paddle.CPUPlace()
43100
exe = paddle.static.Executor(place)
44101
(result,) = exe.run(fetch_list=[ones])
45-
expected_result = np.ones(10, dtype="int64")
46-
self.assertEqual((result == expected_result).all(), True)
102+
expect = np.ones(10, dtype="int64")
103+
np.testing.assert_equal(result, expect)
47104

48105
with paddle.static.program_guard(paddle.static.Program()):
49106
ones = paddle.ones(shape=10, dtype="int64")
50107
place = paddle.CPUPlace()
51108
exe = paddle.static.Executor(place)
52109
(result,) = exe.run(fetch_list=[ones])
53-
expected_result = np.ones(10, dtype="int64")
54-
self.assertEqual((result == expected_result).all(), True)
110+
expect = np.ones(10, dtype="int64")
111+
np.testing.assert_equal(result, expect)
112+
paddle.disable_static()
113+
114+
def test_dygraph_ones(self):
115+
result = paddle.ones(10, dtype=paddle.float32)
116+
expect = np.ones([10], dtype="float32")
117+
np.testing.assert_equal(result, expect)
118+
119+
result = paddle.ones(10, 2, 3, dtype=paddle.float32)
120+
expect = np.ones([10, 2, 3], dtype="float32")
121+
np.testing.assert_equal(result, expect)
122+
123+
result = paddle.ones([10, 2, 3], dtype=paddle.float32)
124+
np.testing.assert_equal(result, expect)
125+
126+
result = paddle.ones(size=[10, 2, 3], dtype=paddle.float32)
127+
np.testing.assert_equal(result, expect)
128+
129+
result = paddle.ones([10, 2, 3], paddle.float32)
130+
np.testing.assert_equal(result, expect)
131+
132+
result = paddle.ones([10, 2, 3], "float32")
133+
np.testing.assert_equal(result, expect)
134+
135+
result = paddle.ones(shape=[10, 2, 3], dtype=paddle.float32)
136+
np.testing.assert_equal(result, expect)
55137

56138

57139
if __name__ == "__main__":

0 commit comments

Comments
 (0)