Skip to content

Commit 7f0baf6

Browse files
authored
[API-Compat] paddle.compat.split is added and tested (#74506)
* [API-Compat] paddle.compat.split is added and tested * [API-Compat] paddle.compat.split is rigorously tested * [API-Compat] Fixed erroneous func help doc * [API-Compat] Make the forbid_keywords decorator transparent * [API-Compat] Fixed decorator str input * [API-Compat] Fixed type annotation and removed legacy graph branch * [API-Compat] More unittest & static graph check & updated decorator * [API-Compat] Force update (local and not reproduce the bug) * [API-Compat] Removed unittest that paddle.split will also fail * [API-Compat] More efficient forbid-keyword decorator * [API-Compat] Resolved merge conflicts. * Update compat.py * Update compat.py
1 parent d77dd90 commit 7f0baf6

File tree

7 files changed

+631
-0
lines changed

7 files changed

+631
-0
lines changed

python/paddle/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@
122122
_pir_ops as _pir_ops,
123123
_typing as _typing,
124124
callbacks as callbacks,
125+
compat as compat,
125126
fft as fft,
126127
hub as hub,
127128
linalg as linalg,

python/paddle/compat.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .tensor.compat import (
16+
split,
17+
)
18+
19+
__all__ = [
20+
'split',
21+
]

python/paddle/tensor/compat.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING
18+
19+
import paddle
20+
from paddle import _C_ops
21+
22+
from ..base.framework import Variable
23+
from ..framework import (
24+
in_dynamic_mode,
25+
)
26+
27+
if TYPE_CHECKING:
28+
from collections.abc import Sequence
29+
30+
from paddle import Tensor
31+
32+
from paddle.utils.decorator_utils import ForbidKeywordsDecorator
33+
34+
__all__ = []
35+
36+
37+
@ForbidKeywordsDecorator(
38+
illegal_keys={"x", "num_or_sections", "axis", "name"},
39+
func_name="paddle.compat.split",
40+
correct_name="paddle.split",
41+
)
42+
def split(
43+
tensor: Tensor, split_size_or_sections: int | Sequence[int], dim: int = 0
44+
) -> tuple[Tensor, ...]:
45+
"""
46+
(PyTorch Compatible API) Split the input tensor into multiple sub-Tensors.
47+
48+
Args:
49+
tensor (Tensor): A N-D Tensor. The data type is bool, bfloat16, float16, float32, float64, uint8, int8, int32 or int64.
50+
split_size_or_sections (int|list|tuple):
51+
If split_size_or_sections is an integer type, then tensor will be split into equally sized chunks (if possible).
52+
Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size.
53+
If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes
54+
in dim according to split_size_or_sections. Negative inputs are not allowed. For example: for a dim with 9 channels,
55+
[2, 3, -1] will not be interpreted as [2, 3, 4], but will be rejected and an exception will be thrown.
56+
dim (int|Tensor, optional): The dim along which to split, it can be a integer or a ``0-D Tensor``
57+
with shape [] and data type ``int32`` or ``int64``.
58+
If :math::`dim < 0`, the dim to split along is :math:`rank(x) + dim`. Default is 0.
59+
Returns:
60+
tuple(Tensor), The tuple of segmented Tensors.
61+
62+
Note:
63+
This is a pytorch compatible API that follows the function signature and behavior of torch.split.
64+
To use the original split of paddle, please consider `paddle.split`
65+
66+
Examples:
67+
.. code-block:: python
68+
69+
>>> import paddle
70+
71+
>>> # x is a Tensor of shape [3, 8, 5]
72+
>>> x = paddle.rand([3, 8, 5])
73+
74+
>>> out0, out1, out2 = paddle.compat.split(x, split_size_or_sections=3, dim=1)
75+
>>> print(out0.shape)
76+
[3, 3, 5]
77+
>>> print(out1.shape)
78+
[3, 3, 5]
79+
>>> print(out2.shape)
80+
[3, 2, 5]
81+
82+
>>> out0, out1, out2 = paddle.compat.split(x, split_size_or_sections=[1, 2, 5], dim=1)
83+
>>> print(out0.shape)
84+
[3, 1, 5]
85+
>>> print(out1.shape)
86+
[3, 2, 5]
87+
>>> print(out2.shape)
88+
[3, 5, 5]
89+
90+
>>> # dim is negative, the real dim is (rank(x) + dim)=1
91+
>>> out0, out1, out2 = paddle.compat.split(x, split_size_or_sections=3, dim=-2)
92+
>>> print(out0.shape)
93+
[3, 3, 5]
94+
>>> print(out1.shape)
95+
[3, 3, 5]
96+
>>> print(out2.shape)
97+
[3, 2, 5]
98+
"""
99+
100+
def GetSplitSize(split_size, shape_on_dim):
101+
remaining_num = shape_on_dim % split_size_or_sections
102+
num_complete_section = shape_on_dim // split_size_or_sections
103+
if remaining_num == 0:
104+
return num_complete_section
105+
else:
106+
sections = [
107+
split_size_or_sections for _ in range(num_complete_section)
108+
]
109+
sections.append(remaining_num)
110+
return sections
111+
112+
def GetShapeOnDimInRange(shape, dim: int) -> int:
113+
shape_range = len(shape)
114+
if isinstance(dim, int):
115+
if dim < -shape_range or dim >= shape_range:
116+
raise ValueError(
117+
f"(InvalidArgument) The dim is expected to be in range of [-{shape_range}, {shape_range}), but got {dim}"
118+
)
119+
return shape[dim]
120+
121+
if isinstance(split_size_or_sections, (list, tuple)):
122+
for i, section_size in enumerate(split_size_or_sections):
123+
shape_val = 0
124+
if isinstance(section_size, Variable):
125+
shape_val = int(section_size.item(0))
126+
else:
127+
shape_val = section_size
128+
if section_size < 0:
129+
raise ValueError(
130+
f"paddle.compat.split expects split_sizes have only non-negative entries, but got size = {section_size} on dim {i}"
131+
)
132+
133+
if in_dynamic_mode():
134+
if isinstance(dim, Variable):
135+
dim = dim.item(0)
136+
assert dim + len(tensor.shape) >= 0, "(rank(x) + dim) must >= 0"
137+
dim = (dim + len(tensor.shape)) if dim < 0 else dim
138+
139+
if isinstance(split_size_or_sections, (list, tuple)):
140+
if paddle.utils._contain_var(split_size_or_sections):
141+
for index, item in enumerate(split_size_or_sections):
142+
if isinstance(item, Variable):
143+
split_size_or_sections[index] = split_size_or_sections[
144+
index
145+
].item()
146+
elif not isinstance(split_size_or_sections, int):
147+
raise TypeError(
148+
"The type of 'split_size_or_sections' in split must be int, list or tuple in imperative mode, but "
149+
f"received {type(split_size_or_sections)}."
150+
)
151+
152+
if isinstance(split_size_or_sections, int):
153+
# check whether shape is divisible
154+
assert (
155+
split_size_or_sections > 0
156+
), 'split_size_or_sections must be greater than 0.'
157+
158+
split_size_or_sections = GetSplitSize(
159+
split_size_or_sections, GetShapeOnDimInRange(tensor.shape, dim)
160+
)
161+
162+
if isinstance(split_size_or_sections, list):
163+
return tuple(_C_ops.split(tensor, split_size_or_sections, dim))
164+
else:
165+
return tuple(
166+
_C_ops.split_with_num(tensor, split_size_or_sections, dim)
167+
)
168+
else:
169+
return tuple(_C_ops.split(tensor, split_size_or_sections, dim))
170+
else:
171+
if isinstance(dim, paddle.pir.Value):
172+
raise TypeError(
173+
"'dim' is not allowed to be a pir.Value in a static graph: "
174+
"\npir.Value can not be used for indexing python lists/tuples."
175+
)
176+
if isinstance(dim, int):
177+
assert len(tensor.shape) + dim >= 0, "(rank(x) + dim) must >= 0"
178+
dim = (len(tensor.shape) + dim) if dim < 0 else dim
179+
180+
input_shape = tensor.shape
181+
182+
if not isinstance(split_size_or_sections, (int, list, tuple)):
183+
raise TypeError(
184+
"The type of 'split_size_or_sections' in split must be int, list or tuple in imperative mode."
185+
)
186+
if isinstance(split_size_or_sections, int):
187+
assert (
188+
split_size_or_sections > 0
189+
), 'split_size_or_sections must be greater than 0.'
190+
191+
split_size_or_sections = GetSplitSize(
192+
split_size_or_sections, GetShapeOnDimInRange(tensor.shape, dim)
193+
)
194+
if isinstance(split_size_or_sections, list):
195+
if paddle.utils._contain_var(split_size_or_sections):
196+
split_size_or_sections = paddle.utils.get_int_tensor_list(
197+
split_size_or_sections
198+
)
199+
return tuple(_C_ops.split(tensor, split_size_or_sections, dim))
200+
else:
201+
return tuple(
202+
_C_ops.split_with_num(tensor, split_size_or_sections, dim)
203+
)
204+
else:
205+
if isinstance(dim, int) and input_shape[dim] > 0:
206+
assert (
207+
len(split_size_or_sections) <= input_shape[dim]
208+
), 'len(split_size_or_sections) must not be more than input.shape[dim].'
209+
if paddle.utils._contain_var(split_size_or_sections):
210+
split_size_or_sections = paddle.utils.get_int_tensor_list(
211+
split_size_or_sections
212+
)
213+
return tuple(_C_ops.split(tensor, split_size_or_sections, dim))

python/paddle/tensor/manipulation.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363
TensorOrTensors,
6464
)
6565

66+
from paddle.utils.decorator_utils import ForbidKeywordsDecorator
67+
6668
__all__ = []
6769

6870

@@ -2735,6 +2737,11 @@ def row_stack(x: Sequence[Tensor], name: str | None = None) -> Tensor:
27352737
return paddle.vstack(x, name=name)
27362738

27372739

2740+
@ForbidKeywordsDecorator(
2741+
illegal_keys={"tensor", "split_size_or_sections", "dim"},
2742+
func_name="paddle.split",
2743+
correct_name="paddle.compat.split",
2744+
)
27382745
def split(
27392746
x: Tensor,
27402747
num_or_sections: int | Sequence[int],

python/paddle/utils/decorator_utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,34 @@ def wrapper(*args, **kwargs):
273273
return decorator
274274

275275

276+
class ForbidKeywordsDecorator(DecoratorBase):
277+
"""A decorator that hints users to use the correct `compat` functions, when erroneous keyword arguments are detected"""
278+
279+
def __init__(
280+
self, illegal_keys: set[str], func_name: str, correct_name: str
281+
) -> None:
282+
super().__init__()
283+
self.illegal_keys = illegal_keys
284+
self.func_name = func_name
285+
self.correct_name = correct_name
286+
287+
def process(
288+
self, args: tuple[Any, ...], kwargs: dict[str, Any]
289+
) -> tuple[tuple[Any, ...], dict[str, Any]]:
290+
found_keys = [key for key in self.illegal_keys if key in kwargs]
291+
292+
if found_keys:
293+
found_keys.sort()
294+
keys_str = ", ".join(f"'{key}'" for key in found_keys)
295+
plural = "s" if len(found_keys) > 1 else ""
296+
297+
raise TypeError(
298+
f"{self.func_name}() received unexpected keyword argument{plural} {keys_str}. "
299+
f"\nDid you mean to use {self.correct_name}() instead?"
300+
)
301+
return args, kwargs
302+
303+
276304
def reshape_decorator():
277305
"""
278306
Usage Example:

0 commit comments

Comments
 (0)