Skip to content

Commit f4350f3

Browse files
committed
[API-Compat] Fixed type annotation and removed legacy graph branch
1 parent a82b8e3 commit f4350f3

File tree

1 file changed

+3
-116
lines changed

1 file changed

+3
-116
lines changed

python/paddle/tensor/compat.py

Lines changed: 3 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,10 @@
1818

1919
import paddle
2020
from paddle import _C_ops
21-
from paddle.tensor import fill_constant
2221

23-
from ..base.data_feeder import (
24-
check_dtype,
25-
check_type,
26-
check_variable_and_dtype,
27-
)
2822
from ..base.framework import Variable
2923
from ..framework import (
30-
LayerHelper,
3124
in_dynamic_mode,
32-
in_pir_mode,
3325
)
3426

3527
if TYPE_CHECKING:
@@ -45,7 +37,7 @@
4537
@forbid_keywords(["x", "num_or_sections", "axis", "name"], "paddle.split")
4638
def split(
4739
tensor: Tensor, split_size_or_sections: int | Sequence[int], dim: int = 0
48-
) -> tuple[Tensor]:
40+
) -> tuple[Tensor, ...]:
4941
"""
5042
(PyTorch Compatible API) Split the input tensor into multiple sub-Tensors.
5143
@@ -72,7 +64,7 @@ def split(
7264
7365
>>> import paddle
7466
75-
>>> # x is a Tensor of shape [3, 9, 5]
67+
>>> # x is a Tensor of shape [3, 8, 5]
7668
>>> x = paddle.rand([3, 8, 5])
7769
7870
>>> out0, out1, out2 = paddle.compat.split(x, split_size_or_sections=3, dim=1)
@@ -170,7 +162,7 @@ def SaveGetShapeOnDim(shape, dim: int) -> int:
170162
)
171163
else:
172164
return tuple(_C_ops.split(tensor, split_size_or_sections, dim))
173-
elif in_pir_mode():
165+
else:
174166
if isinstance(dim, paddle.pir.Value):
175167
dim.stop_gradient = True
176168
if isinstance(dim, int):
@@ -212,108 +204,3 @@ def SaveGetShapeOnDim(shape, dim: int) -> int:
212204
split_size_or_sections
213205
)
214206
return tuple(_C_ops.split(tensor, split_size_or_sections, dim))
215-
216-
else:
217-
check_variable_and_dtype(
218-
tensor,
219-
'input',
220-
[
221-
'bool',
222-
'bfloat16',
223-
'float16',
224-
'uint16',
225-
'float32',
226-
'float64',
227-
'int32',
228-
'int64',
229-
'uint8',
230-
'int8',
231-
],
232-
'split',
233-
)
234-
check_type(
235-
split_size_or_sections,
236-
'split_size_or_sections',
237-
(list, int, tuple),
238-
'split',
239-
)
240-
check_type(dim, 'dim', (int, Variable), 'split')
241-
if isinstance(dim, Variable):
242-
check_dtype(dim.dtype, 'dim', ['int32', 'int64'], 'split')
243-
244-
helper = LayerHelper('split', **locals())
245-
246-
input_shape = tensor.shape
247-
inputs = {'X': tensor}
248-
attrs = {'num': 0}
249-
250-
def _get_SectionsTensorList(one_list):
251-
tensor_list = []
252-
unk_dim_idx = -1
253-
for idx, dim_size in enumerate(one_list):
254-
if isinstance(dim_size, Variable):
255-
dim_size.stop_gradient = True
256-
tensor_list.append(dim_size)
257-
else:
258-
assert isinstance(dim_size, int)
259-
if dim_size == -1:
260-
assert unk_dim_idx == -1, (
261-
"Only one value of 'num_or_section' in split can "
262-
f"be -1. But received num_or_section[{idx}] is also -1."
263-
)
264-
unk_dim_idx = idx
265-
temp_out = helper.create_variable_for_type_inference(
266-
'int32'
267-
)
268-
fill_constant(
269-
[1], 'int32', dim_size, force_cpu=True, out=temp_out
270-
)
271-
tensor_list.append(temp_out)
272-
return tuple(tensor_list)
273-
274-
if isinstance(dim, Variable):
275-
dim.stop_gradient = True
276-
inputs['AxisTensor'] = dim
277-
else:
278-
assert len(tensor.shape) + dim >= 0, "(rank(x) + dim) must >= 0"
279-
dim = (len(input_shape) + dim) if dim < 0 else dim
280-
attrs['axis'] = dim
281-
282-
if isinstance(split_size_or_sections, int):
283-
shape_on_dim = SaveGetShapeOnDim(tensor.shape, dim)
284-
split_size_or_sections = GetSplitSize(
285-
split_size_or_sections, shape_on_dim
286-
)
287-
288-
if isinstance(split_size_or_sections, int):
289-
# after GetSplitSize, if the result is int, split_size_or_sections is actually equivalent to the original num_or_sections (num)
290-
attrs['num'] = split_size_or_sections
291-
assert (
292-
split_size_or_sections > 0
293-
), 'split_size_or_sections must be than 0.'
294-
num = split_size_or_sections
295-
else:
296-
if isinstance(dim, int) and input_shape[dim] > 0:
297-
assert (
298-
len(split_size_or_sections) <= input_shape[dim]
299-
), 'len(split_size_or_sections) must not be more than input.shape[dim].'
300-
num = len(split_size_or_sections)
301-
attrs['sections'] = [
302-
-1 if isinstance(ele, Variable) else ele
303-
for ele in split_size_or_sections
304-
]
305-
if paddle.utils._contain_var(split_size_or_sections):
306-
inputs['SectionsTensorList'] = _get_SectionsTensorList(
307-
split_size_or_sections
308-
)
309-
310-
outs = [
311-
helper.create_variable_for_type_inference(
312-
dtype=helper.input_dtype()
313-
)
314-
for i in range(num)
315-
]
316-
helper.append_op(
317-
type='split', inputs=inputs, outputs={'Out': outs}, attrs=attrs
318-
)
319-
return tuple(outs)

0 commit comments

Comments
 (0)