Skip to content

Commit 17b20dd

Browse files
zhwesky2010Enigmatisms
authored andcommitted
some create api support more usage (PaddlePaddle#74494)
1 parent fbd9f59 commit 17b20dd

File tree

3 files changed

+111
-30
lines changed

3 files changed

+111
-30
lines changed

python/paddle/tensor/creation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import paddle
2626
from paddle import _C_ops
27-
from paddle.utils.decorator_utils import ParamAliasDecorator
27+
from paddle.utils.decorator_utils import ParamAliasDecorator, SizeArgsDecorator
2828
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
2929

3030
from ..base.data_feeder import (
@@ -1241,6 +1241,7 @@ def fill_constant(
12411241
return out
12421242

12431243

1244+
@SizeArgsDecorator()
12441245
def ones(
12451246
shape: ShapeLike, dtype: DTypeLike | None = None, name: str | None = None
12461247
) -> paddle.Tensor:

python/paddle/utils/decorator_utils.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -91,30 +91,27 @@ def process(
9191
return args, processed_kwargs
9292

9393

94-
class ForbidKeywordsDecorator(DecoratorBase):
95-
"""A decorator that hints users to use the correct `compat` functions, when erroneous keyword arguments are detected"""
94+
# *size => shape decorator
95+
class SizeArgsDecorator(DecoratorBase):
96+
"""
97+
Usage Example:
9698
97-
def __init__(
98-
self, illegal_keys: list[str], func_name: str, correct_name: str
99-
) -> None:
100-
super().__init__()
101-
self.illegal_keys = (
102-
[illegal_keys] if isinstance(illegal_keys, str) else illegal_keys
103-
)
104-
self.func_name = func_name
105-
self.correct_name = correct_name
99+
paddle.ones(1, dtype=paddle.float32)
100+
paddle.ones(1, 2, 3, dtype=paddle.float32)
101+
paddle.ones([1, 2, 3], dtype=paddle.float32)
102+
paddle.ones(size=[1, 2, 3], dtype=paddle.float32)
103+
104+
paddle.ones([1, 2, 3], paddle.float32)
105+
paddle.ones(shape=[1, 2, 3], dtype=paddle.float32)
106+
"""
106107

107108
def process(
108109
self, args: tuple[Any, ...], kwargs: dict[str, Any]
109110
) -> tuple[tuple[Any, ...], dict[str, Any]]:
110-
found_keys = [key for key in self.illegal_keys if key in kwargs]
111-
112-
if found_keys:
113-
keys_str = ", ".join(f"'{key}'" for key in found_keys)
114-
plural = "s" if len(found_keys) > 1 else ""
111+
if 'size' in kwargs:
112+
kwargs['shape'] = kwargs.pop('size')
113+
elif len(args) >= 1 and isinstance(args[0], int):
114+
kwargs['shape'] = list(args)
115+
args = ()
115116

116-
raise TypeError(
117-
f"{self.func_name}() received unexpected keyword argument{plural} {keys_str}. "
118-
f"\nDid you mean to use {self.correct_name}() instead?"
119-
)
120117
return args, kwargs

test/legacy_test/test_ones_op.py

Lines changed: 92 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,38 +20,121 @@
2020

2121

2222
class ApiOnesTest(unittest.TestCase):
23-
def test_paddle_ones(self):
23+
def test_static_ones(self):
24+
paddle.enable_static()
25+
with paddle.static.program_guard(paddle.static.Program()):
26+
ones = paddle.ones(10, dtype=paddle.float32)
27+
place = paddle.CPUPlace()
28+
exe = paddle.static.Executor(place)
29+
(result,) = exe.run(fetch_list=[ones])
30+
expect = np.ones([10], dtype="float32")
31+
np.testing.assert_equal(result, expect)
32+
33+
with paddle.static.program_guard(paddle.static.Program()):
34+
ones = paddle.ones(10, 2, 3, dtype=paddle.float32)
35+
place = paddle.CPUPlace()
36+
exe = paddle.static.Executor(place)
37+
(result,) = exe.run(fetch_list=[ones])
38+
expect = np.ones([10, 2, 3], dtype="float32")
39+
np.testing.assert_equal(result, expect)
40+
41+
with paddle.static.program_guard(paddle.static.Program()):
42+
ones = paddle.ones([10, 2, 3], dtype=paddle.float32)
43+
place = paddle.CPUPlace()
44+
exe = paddle.static.Executor(place)
45+
(result,) = exe.run(fetch_list=[ones])
46+
expect = np.ones([10, 2, 3], dtype="float32")
47+
np.testing.assert_equal(result, expect)
48+
49+
with paddle.static.program_guard(paddle.static.Program()):
50+
ones = paddle.ones(size=[10, 2, 3], dtype=paddle.float32)
51+
place = paddle.CPUPlace()
52+
exe = paddle.static.Executor(place)
53+
(result,) = exe.run(fetch_list=[ones])
54+
expect = np.ones([10, 2, 3], dtype="float32")
55+
np.testing.assert_equal(result, expect)
56+
57+
with paddle.static.program_guard(paddle.static.Program()):
58+
ones = paddle.ones([10, 2, 3], paddle.float32)
59+
place = paddle.CPUPlace()
60+
exe = paddle.static.Executor(place)
61+
(result,) = exe.run(fetch_list=[ones])
62+
expect = np.ones([10, 2, 3], dtype="float32")
63+
np.testing.assert_equal(result, expect)
64+
65+
with paddle.static.program_guard(paddle.static.Program()):
66+
ones = paddle.ones([10, 2, 3], paddle.float32)
67+
place = paddle.CPUPlace()
68+
exe = paddle.static.Executor(place)
69+
(result,) = exe.run(fetch_list=[ones])
70+
expect = np.ones([10, 2, 3], dtype="float32")
71+
np.testing.assert_equal(result, expect)
72+
73+
with paddle.static.program_guard(paddle.static.Program()):
74+
ones = paddle.ones(shape=[10, 2, 3], dtype=paddle.float32)
75+
place = paddle.CPUPlace()
76+
exe = paddle.static.Executor(place)
77+
(result,) = exe.run(fetch_list=[ones])
78+
expect = np.ones([10, 2, 3], dtype="float32")
79+
np.testing.assert_equal(result, expect)
80+
2481
with paddle.static.program_guard(paddle.static.Program()):
2582
ones = paddle.ones(shape=[10])
2683
place = paddle.CPUPlace()
2784
exe = paddle.static.Executor(place)
2885
(result,) = exe.run(fetch_list=[ones])
29-
expected_result = np.ones(10, dtype="float32")
30-
self.assertEqual((result == expected_result).all(), True)
86+
expect = np.ones(10, dtype="float32")
87+
np.testing.assert_equal(result, expect)
3188

3289
with paddle.static.program_guard(paddle.static.Program()):
3390
ones = paddle.ones(shape=[10], dtype="float64")
3491
place = paddle.CPUPlace()
3592
exe = paddle.static.Executor(place)
3693
(result,) = exe.run(fetch_list=[ones])
37-
expected_result = np.ones(10, dtype="float64")
38-
self.assertEqual((result == expected_result).all(), True)
94+
expect = np.ones(10, dtype="float64")
95+
np.testing.assert_equal(result, expect)
3996

4097
with paddle.static.program_guard(paddle.static.Program()):
4198
ones = paddle.ones(shape=[10], dtype="int64")
4299
place = paddle.CPUPlace()
43100
exe = paddle.static.Executor(place)
44101
(result,) = exe.run(fetch_list=[ones])
45-
expected_result = np.ones(10, dtype="int64")
46-
self.assertEqual((result == expected_result).all(), True)
102+
expect = np.ones(10, dtype="int64")
103+
np.testing.assert_equal(result, expect)
47104

48105
with paddle.static.program_guard(paddle.static.Program()):
49106
ones = paddle.ones(shape=10, dtype="int64")
50107
place = paddle.CPUPlace()
51108
exe = paddle.static.Executor(place)
52109
(result,) = exe.run(fetch_list=[ones])
53-
expected_result = np.ones(10, dtype="int64")
54-
self.assertEqual((result == expected_result).all(), True)
110+
expect = np.ones(10, dtype="int64")
111+
np.testing.assert_equal(result, expect)
112+
paddle.disable_static()
113+
114+
def test_dygraph_ones(self):
115+
paddle.disable_static()
116+
result = paddle.ones(10, dtype=paddle.float32)
117+
expect = np.ones([10], dtype="float32")
118+
np.testing.assert_equal(result, expect)
119+
120+
result = paddle.ones(10, 2, 3, dtype=paddle.float32)
121+
expect = np.ones([10, 2, 3], dtype="float32")
122+
np.testing.assert_equal(result, expect)
123+
124+
result = paddle.ones([10, 2, 3], dtype=paddle.float32)
125+
np.testing.assert_equal(result, expect)
126+
127+
result = paddle.ones(size=[10, 2, 3], dtype=paddle.float32)
128+
np.testing.assert_equal(result, expect)
129+
130+
result = paddle.ones([10, 2, 3], paddle.float32)
131+
np.testing.assert_equal(result, expect)
132+
133+
result = paddle.ones([10, 2, 3], "float32")
134+
np.testing.assert_equal(result, expect)
135+
136+
result = paddle.ones(shape=[10, 2, 3], dtype=paddle.float32)
137+
np.testing.assert_equal(result, expect)
55138

56139

57140
if __name__ == "__main__":

0 commit comments

Comments
 (0)