diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 4ebc15fdc9753..53680e172adcd 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -122,6 +122,7 @@
     _pir_ops as _pir_ops,
     _typing as _typing,
     callbacks as callbacks,
+    compat as compat,
     fft as fft,
     hub as hub,
     linalg as linalg,
diff --git a/python/paddle/compat.py b/python/paddle/compat.py
new file mode 100644
index 0000000000000..d42b733edccc8
--- /dev/null
+++ b/python/paddle/compat.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .tensor.compat import (
+    split,
+)
+
+__all__ = [
+    'split',
+]
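The `compat` namespace simply re-exports the PyTorch-style implementation added below. The key behavioral difference from `paddle.split` is what the integer argument means; a quick sketch of the two call styles (illustrative only, not part of the patch):

    import paddle

    x = paddle.rand([3, 8, 5])

    # paddle.compat.split: the int is the *size of each chunk* along `dim`;
    # the last chunk may be smaller (PyTorch semantics).
    a, b, c = paddle.compat.split(x, 3, dim=1)  # shapes [3,3,5], [3,3,5], [3,2,5]

    # paddle.split: the int is the *number of equal chunks*, so it must divide
    # the dimension size evenly.
    p, q = paddle.split(x, 2, axis=1)           # shapes [3,4,5], [3,4,5]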
diff --git a/python/paddle/tensor/compat.py b/python/paddle/tensor/compat.py
new file mode 100644
index 0000000000000..a6a755b702520
--- /dev/null
+++ b/python/paddle/tensor/compat.py
@@ -0,0 +1,213 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import paddle
+from paddle import _C_ops
+
+from ..base.framework import Variable
+from ..framework import (
+    in_dynamic_mode,
+)
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+    from paddle import Tensor
+
+from paddle.utils.decorator_utils import ForbidKeywordsDecorator
+
+__all__ = []
+
+
+@ForbidKeywordsDecorator(
+    illegal_keys=["x", "num_or_sections", "axis", "name"],
+    func_name="paddle.compat.split",
+    correct_name="paddle.split",
+)
+def split(
+    tensor: Tensor, split_size_or_sections: int | Sequence[int], dim: int = 0
+) -> tuple[Tensor, ...]:
+    """
+    (PyTorch-compatible API) Split the input tensor into multiple sub-Tensors.
+
+    Args:
+        tensor (Tensor): An N-D Tensor. The data type is bool, bfloat16, float16, float32, float64, uint8, int8, int32 or int64.
+        split_size_or_sections (int|list|tuple):
+            If split_size_or_sections is an integer, the tensor is split into equally sized chunks (if possible).
+            The last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size_or_sections.
+            If split_size_or_sections is a list, the tensor is split into len(split_size_or_sections) chunks whose sizes
+            along dim are given by split_size_or_sections. Negative entries are not allowed: for a dim with 9 channels,
+            [2, 3, -1] is not interpreted as [2, 3, 4]; it is rejected and an exception is raised.
+        dim (int|Tensor, optional): The dim along which to split; it can be an integer or a ``0-D Tensor``
+            with shape [] and data type ``int32`` or ``int64``.
+            If :math:`dim < 0`, the dim to split along is :math:`rank(x) + dim`. Default is 0.
+    Returns:
+        tuple(Tensor), The tuple of segmented Tensors.
+
+    Note:
+        This is a PyTorch-compatible API that follows the signature and behavior of torch.split.
+        To use the original split of paddle, please use `paddle.split` instead.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+
+            >>> # x is a Tensor of shape [3, 8, 5]
+            >>> x = paddle.rand([3, 8, 5])
+
+            >>> out0, out1, out2 = paddle.compat.split(x, split_size_or_sections=3, dim=1)
+            >>> print(out0.shape)
+            [3, 3, 5]
+            >>> print(out1.shape)
+            [3, 3, 5]
+            >>> print(out2.shape)
+            [3, 2, 5]
+
+            >>> out0, out1, out2 = paddle.compat.split(x, split_size_or_sections=[1, 2, 5], dim=1)
+            >>> print(out0.shape)
+            [3, 1, 5]
+            >>> print(out1.shape)
+            [3, 2, 5]
+            >>> print(out2.shape)
+            [3, 5, 5]
+
+            >>> # dim is negative, the real dim is (rank(x) + dim)=1
+            >>> out0, out1, out2 = paddle.compat.split(x, split_size_or_sections=3, dim=-2)
+            >>> print(out0.shape)
+            [3, 3, 5]
+            >>> print(out1.shape)
+            [3, 3, 5]
+            >>> print(out2.shape)
+            [3, 2, 5]
+    """
+
+    def GetSplitSize(split_size, shape_on_dim):
+        remaining_num = shape_on_dim % split_size
+        num_complete_section = shape_on_dim // split_size
+        if remaining_num == 0:
+            return num_complete_section
+        else:
+            sections = [split_size for _ in range(num_complete_section)]
+            sections.append(remaining_num)
+            return sections
+
+    def GetShapeOnDimInRange(shape, dim: int) -> int:
+        shape_range = len(shape)
+        if isinstance(dim, int):
+            if dim < -shape_range or dim >= shape_range:
+                raise ValueError(
+                    f"(InvalidArgument) The dim is expected to be in range of [-{shape_range}, {shape_range}), but got {dim}"
+                )
+        return shape[dim]
+
+    if isinstance(split_size_or_sections, (list, tuple)):
+        for i, section_size in enumerate(split_size_or_sections):
+            if isinstance(section_size, Variable):
+                section_size = int(section_size.item(0))
+            if section_size < 0:
+                raise ValueError(
+                    f"paddle.compat.split expects split_sizes have only non-negative entries, but got size = {section_size} on dim {i}"
+                )
+
+    if in_dynamic_mode():
+        if isinstance(dim, Variable):
+            dim = dim.item(0)
+        assert dim + len(tensor.shape) >= 0, "(rank(x) + dim) must be >= 0"
+        dim = (dim + len(tensor.shape)) if dim < 0 else dim
+
+        if isinstance(split_size_or_sections, (list, tuple)):
+            if paddle.utils._contain_var(split_size_or_sections):
+                for index, item in enumerate(split_size_or_sections):
+                    if isinstance(item, Variable):
+                        split_size_or_sections[index] = item.item()
+        elif not isinstance(split_size_or_sections, int):
+            raise TypeError(
+                "The type of 'split_size_or_sections' in split must be int, list or tuple in imperative mode, but "
+                f"received {type(split_size_or_sections)}."
+            )
+
+        if isinstance(split_size_or_sections, int):
+            # an integer split size must be strictly positive
+            assert (
+                split_size_or_sections > 0
+            ), 'split_size_or_sections must be greater than 0.'
+
+            split_size_or_sections = GetSplitSize(
+                split_size_or_sections, GetShapeOnDimInRange(tensor.shape, dim)
+            )
+
+            if isinstance(split_size_or_sections, list):
+                return tuple(_C_ops.split(tensor, split_size_or_sections, dim))
+            else:
+                return tuple(
+                    _C_ops.split_with_num(tensor, split_size_or_sections, dim)
+                )
+        else:
+            return tuple(_C_ops.split(tensor, split_size_or_sections, dim))
+    else:
+        if isinstance(dim, paddle.pir.Value):
+            raise TypeError(
+                "'dim' is not allowed to be a pir.Value in a static graph: "
+                "\npir.Value cannot be used for indexing python lists/tuples."
+            )
+        if isinstance(dim, int):
+            assert len(tensor.shape) + dim >= 0, "(rank(x) + dim) must be >= 0"
+            dim = (len(tensor.shape) + dim) if dim < 0 else dim
+
+        input_shape = tensor.shape
+
+        if not isinstance(split_size_or_sections, (int, list, tuple)):
+            raise TypeError(
+                "The type of 'split_size_or_sections' in split must be int, list or tuple in static graph mode."
+            )
+        if isinstance(split_size_or_sections, int):
+            assert (
+                split_size_or_sections > 0
+            ), 'split_size_or_sections must be greater than 0.'
+
+            split_size_or_sections = GetSplitSize(
+                split_size_or_sections, GetShapeOnDimInRange(tensor.shape, dim)
+            )
+            if isinstance(split_size_or_sections, list):
+                if paddle.utils._contain_var(split_size_or_sections):
+                    split_size_or_sections = paddle.utils.get_int_tensor_list(
+                        split_size_or_sections
+                    )
+                return tuple(_C_ops.split(tensor, split_size_or_sections, dim))
+            else:
+                return tuple(
+                    _C_ops.split_with_num(tensor, split_size_or_sections, dim)
+                )
+        else:
+            if isinstance(dim, int) and input_shape[dim] > 0:
+                assert (
+                    len(split_size_or_sections) <= input_shape[dim]
+                ), 'len(split_size_or_sections) must not be more than input.shape[dim].'
+            if paddle.utils._contain_var(split_size_or_sections):
+                split_size_or_sections = paddle.utils.get_int_tensor_list(
+                    split_size_or_sections
+                )
+            return tuple(_C_ops.split(tensor, split_size_or_sections, dim))
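`GetSplitSize` is where the PyTorch semantics live: an integer split size is either kept as a chunk *count* (when it divides the dimension evenly, so the cheaper `split_with_num` path can be taken) or expanded into explicit section sizes with the remainder last. A standalone pure-Python mirror of that logic, for illustration only:

    def split_size_to_sections(split_size: int, shape_on_dim: int):
        remaining = shape_on_dim % split_size
        full = shape_on_dim // split_size
        # evenly divisible: return the number of equal chunks
        if remaining == 0:
            return full
        # otherwise: explicit section sizes, smaller remainder last
        return [split_size] * full + [remaining]

    assert split_size_to_sections(2, 9) == [2, 2, 2, 2, 1]
    assert split_size_to_sections(4, 8) == 2  # an int means "this many equal chunks"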
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 857554b5dd1f2..2014603dff6ca 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -58,6 +58,8 @@
     TensorOrTensors,
 )
 
+from paddle.utils.decorator_utils import ForbidKeywordsDecorator
+
 __all__ = []
 
 
@@ -2723,6 +2725,11 @@ def row_stack(x: Sequence[Tensor], name: str | None = None) -> Tensor:
     return paddle.vstack(x, name=name)
 
 
+@ForbidKeywordsDecorator(
+    illegal_keys=["tensor", "split_size_or_sections", "dim"],
+    func_name="paddle.split",
+    correct_name="paddle.compat.split",
+)
 def split(
     x: Tensor,
     num_or_sections: int | Sequence[int],
diff --git a/python/paddle/utils/decorator_utils.py b/python/paddle/utils/decorator_utils.py
index 21c1621256094..97d1f4da60351 100644
--- a/python/paddle/utils/decorator_utils.py
+++ b/python/paddle/utils/decorator_utils.py
@@ -115,3 +115,30 @@ def process(
         args = ()
 
         return args, kwargs
+
+
+class ForbidKeywordsDecorator(DecoratorBase):
+    """A decorator that hints users toward the correct `compat` function when
+    erroneous keyword arguments are detected."""
+
+    def __init__(
+        self, illegal_keys: list[str], func_name: str, correct_name: str
+    ) -> None:
+        super().__init__()
+        self.illegal_keys = illegal_keys
+        self.func_name = func_name
+        self.correct_name = correct_name
+
+    def process(
+        self, args: tuple[Any, ...], kwargs: dict[str, Any]
+    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
+        found_keys = [key for key in self.illegal_keys if key in kwargs]
+
+        if found_keys:
+            keys_str = ", ".join(f"'{key}'" for key in found_keys)
+            plural = "s" if len(found_keys) > 1 else ""
+
+            raise TypeError(
+                f"{self.func_name}() received unexpected keyword argument{plural} {keys_str}. "
+                f"\nDid you mean to use {self.correct_name}() instead?"
+            )
+        return args, kwargs
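With the decorator applied to both entry points, passing one API's keywords to the other fails fast with a redirect hint instead of a generic signature error; for example (illustrative):

    import paddle

    x = paddle.rand([3, 9, 5])
    try:
        paddle.split(tensor=x, split_size_or_sections=3, dim=0)  # torch-style kwargs
    except TypeError as e:
        print(e)
        # paddle.split() received unexpected keyword arguments 'tensor', 'split_size_or_sections', 'dim'.
        # Did you mean to use paddle.compat.split() instead?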
".join(f"'{key}'" for key in found_keys) + plural = "s" if len(found_keys) > 1 else "" + + raise TypeError( + f"{self.func_name}() received unexpected keyword argument{plural} {keys_str}. " + f"\nDid you mean to use {self.correct_name}() instead?" + ) + return args, kwargs diff --git a/test/legacy_test/test_compat_split.py b/test/legacy_test/test_compat_split.py new file mode 100644 index 0000000000000..8410e10e1e1ca --- /dev/null +++ b/test/legacy_test/test_compat_split.py @@ -0,0 +1,177 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle.compat import split + + +class TestCompatSplit(unittest.TestCase): + def _compare_with_origin(self, input_tensor, size, axis=0): + pd_results = split(input_tensor, size, dim=axis) + + if isinstance(size, int): + shape_on_axis = input_tensor.shape[axis] + remaining_num = shape_on_axis % size + num_sections = shape_on_axis // size + if remaining_num == 0: + size = num_sections + else: + size = [size for _ in range(num_sections)] + size.append(remaining_num) + + origin_results = paddle.split( + input_tensor, num_or_sections=size, axis=axis + ) + + self.assertEqual(len(origin_results), len(pd_results)) + + # check shape and output section size of the output + for origin_ts, pd_ts in zip(origin_results, pd_results): + np.testing.assert_allclose(origin_ts.numpy(), pd_ts.numpy()) + + def test_basic_split(self): + """Test basic splitting with integer size""" + data = paddle.arange(12).reshape([3, 4]).astype('float32') + self._compare_with_origin(data, 1, 0) + self._compare_with_origin(data, 2, 1) + + def test_split_with_list_sections(self): + """Test splitting with list of section sizes""" + data = paddle.rand([10, 5]) + self._compare_with_origin(data, [3, 2, 5], 0) + self._compare_with_origin(data, [1, 4], -1) + + def test_chained_operations(self): + """Test split with complex operation chain""" + x = paddle.rand([8, 12]) + y = paddle.sin(x) * 2.0 + paddle.exp(x) / 3.0 + z = paddle.nn.functional.relu(y) + + z1, z2 = split(z, 7, dim=1) + + self.assertEqual(z1.shape, [8, 7]) + self.assertEqual(z2.shape, [8, 5]) + + z_np = z.numpy() + np.testing.assert_allclose(z_np[:, :7], z1.numpy()) + np.testing.assert_allclose(z_np[:, 7:], z2.numpy()) + + def test_split_grad(self): + """Test backprop for split, in1 and in2 are computed by + compat.split and original split""" + + def get_tensors(): + np.random.seed(114514) + np_arr = np.random.normal(0, 1, [2, 3, 4, 5]) + return paddle.to_tensor(np_arr), paddle.to_tensor(np_arr) + + in1, in2 = get_tensors() + in1.stop_gradient = False + in2.stop_gradient = False + + def computation_graph(in_tensor): + y = in_tensor * 2.3 + 3.0 + y = paddle.maximum(y, paddle.to_tensor([0], dtype=paddle.float32)) + return y.mean(axis=0) + + out1 = computation_graph(in1) + out2 = computation_graph(in2) + + packs1 = paddle.compat.split(out1, 2, dim=2) + packs2 = paddle.split(out2, [2, 2, 1], axis=2) + + res1 = packs1[0] 
diff --git a/test/legacy_test/test_compat_split_static.py b/test/legacy_test/test_compat_split_static.py
new file mode 100644
index 0000000000000..006e3ec30ea07
--- /dev/null
+++ b/test/legacy_test/test_compat_split_static.py
@@ -0,0 +1,184 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle.compat import split
+
+
+class TestCompatSplitStatic(unittest.TestCase):
+    def _compare_with_origin_static(
+        self, input_shape, size, axis=0, dim_rank=-1
+    ):
+        """dim_rank: -1 means `axis` is passed as a plain int, 0 means a 0-D
+        Tensor, and 1 means a Tensor with shape [1]."""
+        numel = 1
+        for v in input_shape:
+            numel *= v
+        input_axis = axis
+        if dim_rank == 0:
+            input_axis = paddle.to_tensor(axis)
+        elif dim_rank == 1:
+            input_axis = paddle.to_tensor([axis])
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            input_tensor = paddle.arange(numel, dtype=paddle.float32).reshape(
+                input_shape
+            )
+            pd_results = split(input_tensor, size, dim=input_axis)
+
+            if isinstance(size, int):
+                shape_on_axis = input_tensor.shape[axis]
+                remaining_num = shape_on_axis % size
+                num_sections = shape_on_axis // size
+                if remaining_num == 0:
+                    size = num_sections
+                else:
+                    size = [size for _ in range(num_sections)]
+                    size.append(remaining_num)
+
+            origin_results = paddle.split(
+                input_tensor, num_or_sections=size, axis=axis
+            )
+            assert len(pd_results) == len(origin_results), "length mismatch"
+            place = (
+                paddle.CUDAPlace(0)
+                if paddle.is_compiled_with_cuda()
+                else paddle.CPUPlace()
+            )
+            exe = paddle.static.Executor(place)
+            results = exe.run(fetch_list=[*origin_results, *pd_results])
+            length_needed = len(results) // 2
+            for i in range(length_needed):
+                np.testing.assert_allclose(
+                    results[i], results[i + length_needed]
+                )
+        paddle.disable_static()
+
+    def test_split_composite_static(self):
+        paddle.seed(114514)
+
+        def get_tensors():
+            np.random.seed(114514)
+            np_arr = np.random.normal(0, 1, [2, 3, 4, 5])
+            return paddle.to_tensor(np_arr), paddle.to_tensor(np_arr)
+
+        in1, in2 = get_tensors()
+        in1.stop_gradient = False
+        in2.stop_gradient = False
+
+        @paddle.jit.to_static
+        def computation_graph(in1: paddle.Tensor, in2: paddle.Tensor):
+            y1 = in1 * 1.5 + 1.0
+            y1 = paddle.minimum(
+                y1, paddle.to_tensor([0], dtype=paddle.float32)
+            )
+            out1 = y1.mean(axis=0)
+
+            y2 = in2 * 1.5 + 1.0
+            y2 = paddle.minimum(
+                y2, paddle.to_tensor([0], dtype=paddle.float32)
+            )
+            out2 = y2.mean(axis=0)
+
+            packs1 = paddle.compat.split(out1, 2, dim=2)
+            packs2 = paddle.split(out2, [2, 2, 1], axis=2)
+
+            res1 = packs1[0] + packs1[1] + packs1[2]
+            res2 = packs2[0] + packs2[1] + packs2[2]
+
+            return res1, res2
+
+        res1, res2 = computation_graph(in1, in2)
+        np.testing.assert_allclose(res1.numpy(), res2.numpy())
+    def test_static_graph(self):
+        """Test static graph execution"""
+        # fixed random seed for reproducibility
+        np.random.seed(114514)
+        # old static graph mode
+        paddle.enable_static()
+
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(name='x', shape=[None, 6], dtype='float32')
+            result0, result1 = split(x, split_size_or_sections=[3, 3], dim=1)
+            output = result0 * 2.0 + paddle.sin(result1)
+
+            place = (
+                paddle.CUDAPlace(0)
+                if paddle.is_compiled_with_cuda()
+                else paddle.CPUPlace()
+            )
+            exe = paddle.static.Executor(place)
+
+            input_data = np.random.rand(3, 6).astype('float32')
+            feed = {'x': input_data}
+
+            results = exe.run(feed=feed, fetch_list=[result0, result1, output])
+
+            pd_result0, pd_result1 = results[0], results[1]
+            np.testing.assert_allclose(input_data[:, :3], pd_result0)
+            np.testing.assert_allclose(input_data[:, 3:], pd_result1)
+
+            expected_output = input_data[:, :3] * 2.0 + np.sin(
+                input_data[:, 3:]
+            )
+            np.testing.assert_allclose(
+                expected_output, results[2], rtol=1e-4, atol=1e-4
+            )
+
+        paddle.disable_static()
+    def test_error_hint(self):
+        """Test that informative exceptions are raised for invalid sizes,
+        types, and dims in static graph mode."""
+
+        msg_gt_1 = "split_size_or_sections must be greater than 0."
+        msg_gt_2 = "len(split_size_or_sections) must not be more than input.shape[dim]."
+        msg_gt_3 = "The type of 'split_size_or_sections' in split must be int, list or tuple in static graph mode."
+        msg_gt_4 = (
+            "'dim' is not allowed to be a pir.Value in a static graph: "
+            "\npir.Value cannot be used for indexing python lists/tuples."
+        )
+
+        paddle.enable_static()
+        with self.assertRaises(AssertionError) as cm:
+            x = paddle.randn([3, 4, 5])
+            tensors = split(x, -2, dim=0)
+        self.assertEqual(str(cm.exception), msg_gt_1)
+
+        with self.assertRaises(AssertionError) as cm:
+            x = paddle.randn([3, 4, 5])
+            tensors = split(x, (1, 1, 1, 1, 2, 2), dim=-1)
+        self.assertEqual(str(cm.exception), msg_gt_2)
+
+        with self.assertRaises(TypeError) as cm:
+            x = paddle.randn([3, 4, 5])
+            tensors = split(x, paddle.to_tensor(2), dim=2)
+        self.assertEqual(str(cm.exception), msg_gt_3)
+
+        with self.assertRaises(TypeError) as cm:
+            x = paddle.randn([3, 4, 5])
+            tensors = split(x, 2, dim=paddle.to_tensor(2))
+        paddle.disable_static()
+        self.assertEqual(str(cm.exception), msg_gt_4)
+
+    def test_basic_split(self):
+        """Test basic splitting with integer size"""
+        input_shape = [3, 6]
+        self._compare_with_origin_static(input_shape, 1, 0)
+        self._compare_with_origin_static(input_shape, 3, -1)
+        self._compare_with_origin_static(input_shape, 4, dim_rank=0)
+        self._compare_with_origin_static(input_shape, 3, dim_rank=1)
+
+
+if __name__ == '__main__':
+    unittest.main()
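As covered by `test_split_composite_static`, the compat API also composes with dynamic-to-static conversion; a minimal usage sketch (the function name and shapes here are illustrative):

    import paddle

    @paddle.jit.to_static
    def halves(x):
        a, b = paddle.compat.split(x, 4, dim=1)  # size-4 chunks along dim 1
        return a.sum() + b.sum()

    out = halves(paddle.rand([2, 7]))  # chunks of width 4 and 3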