diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index cb9b300b6d624f..117b63fe11587d 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -24,7 +24,7 @@ import paddle from paddle import _C_ops -from paddle.utils.decorator_utils import ParamAliasDecorator +from paddle.utils.decorator_utils import ParamAliasDecorator, SizeArgsDecorator from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only from ..base.data_feeder import ( @@ -1241,6 +1241,7 @@ def fill_constant( return out +@SizeArgsDecorator() def ones( shape: ShapeLike, dtype: DTypeLike | None = None, name: str | None = None ) -> paddle.Tensor: diff --git a/python/paddle/utils/decorator_utils.py b/python/paddle/utils/decorator_utils.py index 79ec73937ec8c1..21c16212560941 100644 --- a/python/paddle/utils/decorator_utils.py +++ b/python/paddle/utils/decorator_utils.py @@ -89,3 +89,29 @@ def process( f"Cannot specify both '{original}' and its alias '{alias}'" ) return args, processed_kwargs + + +# *size => shape decorator +class SizeArgsDecorator(DecoratorBase): + """ + Usage Example: + + paddle.ones(1, dtype=paddle.float32) + paddle.ones(1, 2, 3, dtype=paddle.float32) + paddle.ones([1, 2, 3], dtype=paddle.float32) + paddle.ones(size=[1, 2, 3], dtype=paddle.float32) + + paddle.ones([1, 2, 3], paddle.float32) + paddle.ones(shape=[1, 2, 3], dtype=paddle.float32) + """ + + def process( + self, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[tuple[Any, ...], dict[str, Any]]: + if 'size' in kwargs: + kwargs['shape'] = kwargs.pop('size') + elif len(args) >= 1 and isinstance(args[0], int): + kwargs['shape'] = list(args) + args = () + + return args, kwargs diff --git a/test/legacy_test/test_ones_op.py b/test/legacy_test/test_ones_op.py index 3394bc611e7bfe..63ea2930633414 100644 --- a/test/legacy_test/test_ones_op.py +++ b/test/legacy_test/test_ones_op.py @@ -20,38 +20,121 @@ class ApiOnesTest(unittest.TestCase): - def test_paddle_ones(self): + def test_static_ones(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + ones = paddle.ones(10, dtype=paddle.float32) + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + (result,) = exe.run(fetch_list=[ones]) + expect = np.ones([10], dtype="float32") + np.testing.assert_equal(result, expect) + + with paddle.static.program_guard(paddle.static.Program()): + ones = paddle.ones(10, 2, 3, dtype=paddle.float32) + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + (result,) = exe.run(fetch_list=[ones]) + expect = np.ones([10, 2, 3], dtype="float32") + np.testing.assert_equal(result, expect) + + with paddle.static.program_guard(paddle.static.Program()): + ones = paddle.ones([10, 2, 3], dtype=paddle.float32) + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + (result,) = exe.run(fetch_list=[ones]) + expect = np.ones([10, 2, 3], dtype="float32") + np.testing.assert_equal(result, expect) + + with paddle.static.program_guard(paddle.static.Program()): + ones = paddle.ones(size=[10, 2, 3], dtype=paddle.float32) + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + (result,) = exe.run(fetch_list=[ones]) + expect = np.ones([10, 2, 3], dtype="float32") + np.testing.assert_equal(result, expect) + + with paddle.static.program_guard(paddle.static.Program()): + ones = paddle.ones([10, 2, 3], paddle.float32) + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + (result,) = exe.run(fetch_list=[ones]) + expect = np.ones([10, 2, 3], dtype="float32") + np.testing.assert_equal(result, expect) + + with paddle.static.program_guard(paddle.static.Program()): + ones = paddle.ones([10, 2, 3], paddle.float32) + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + (result,) = exe.run(fetch_list=[ones]) + expect = np.ones([10, 2, 3], dtype="float32") + np.testing.assert_equal(result, expect) + + with paddle.static.program_guard(paddle.static.Program()): + ones = paddle.ones(shape=[10, 2, 3], dtype=paddle.float32) + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + (result,) = exe.run(fetch_list=[ones]) + expect = np.ones([10, 2, 3], dtype="float32") + np.testing.assert_equal(result, expect) + with paddle.static.program_guard(paddle.static.Program()): ones = paddle.ones(shape=[10]) place = paddle.CPUPlace() exe = paddle.static.Executor(place) (result,) = exe.run(fetch_list=[ones]) - expected_result = np.ones(10, dtype="float32") - self.assertEqual((result == expected_result).all(), True) + expect = np.ones(10, dtype="float32") + np.testing.assert_equal(result, expect) with paddle.static.program_guard(paddle.static.Program()): ones = paddle.ones(shape=[10], dtype="float64") place = paddle.CPUPlace() exe = paddle.static.Executor(place) (result,) = exe.run(fetch_list=[ones]) - expected_result = np.ones(10, dtype="float64") - self.assertEqual((result == expected_result).all(), True) + expect = np.ones(10, dtype="float64") + np.testing.assert_equal(result, expect) with paddle.static.program_guard(paddle.static.Program()): ones = paddle.ones(shape=[10], dtype="int64") place = paddle.CPUPlace() exe = paddle.static.Executor(place) (result,) = exe.run(fetch_list=[ones]) - expected_result = np.ones(10, dtype="int64") - self.assertEqual((result == expected_result).all(), True) + expect = np.ones(10, dtype="int64") + np.testing.assert_equal(result, expect) with paddle.static.program_guard(paddle.static.Program()): ones = paddle.ones(shape=10, dtype="int64") place = paddle.CPUPlace() exe = paddle.static.Executor(place) (result,) = exe.run(fetch_list=[ones]) - expected_result = np.ones(10, dtype="int64") - self.assertEqual((result == expected_result).all(), True) + expect = np.ones(10, dtype="int64") + np.testing.assert_equal(result, expect) + paddle.disable_static() + + def test_dygraph_ones(self): + paddle.disable_static() + result = paddle.ones(10, dtype=paddle.float32) + expect = np.ones([10], dtype="float32") + np.testing.assert_equal(result, expect) + + result = paddle.ones(10, 2, 3, dtype=paddle.float32) + expect = np.ones([10, 2, 3], dtype="float32") + np.testing.assert_equal(result, expect) + + result = paddle.ones([10, 2, 3], dtype=paddle.float32) + np.testing.assert_equal(result, expect) + + result = paddle.ones(size=[10, 2, 3], dtype=paddle.float32) + np.testing.assert_equal(result, expect) + + result = paddle.ones([10, 2, 3], paddle.float32) + np.testing.assert_equal(result, expect) + + result = paddle.ones([10, 2, 3], "float32") + np.testing.assert_equal(result, expect) + + result = paddle.ones(shape=[10, 2, 3], dtype=paddle.float32) + np.testing.assert_equal(result, expect) if __name__ == "__main__":