18
18
19
19
import paddle
20
20
from paddle import _C_ops
21
- from paddle .tensor import fill_constant
22
21
23
- from ..base .data_feeder import (
24
- check_dtype ,
25
- check_type ,
26
- check_variable_and_dtype ,
27
- )
28
22
from ..base .framework import Variable
29
23
from ..framework import (
30
- LayerHelper ,
31
24
in_dynamic_mode ,
32
- in_pir_mode ,
33
25
)
34
26
35
27
if TYPE_CHECKING :
45
37
@forbid_keywords (["x" , "num_or_sections" , "axis" , "name" ], "paddle.split" )
46
38
def split (
47
39
tensor : Tensor , split_size_or_sections : int | Sequence [int ], dim : int = 0
48
- ) -> tuple [Tensor ]:
40
+ ) -> tuple [Tensor , ... ]:
49
41
"""
50
42
(PyTorch Compatible API) Split the input tensor into multiple sub-Tensors.
51
43
@@ -72,7 +64,7 @@ def split(
72
64
73
65
>>> import paddle
74
66
75
- >>> # x is a Tensor of shape [3, 9 , 5]
67
+ >>> # x is a Tensor of shape [3, 8 , 5]
76
68
>>> x = paddle.rand([3, 8, 5])
77
69
78
70
>>> out0, out1, out2 = paddle.compat.split(x, split_size_or_sections=3, dim=1)
@@ -170,7 +162,7 @@ def SaveGetShapeOnDim(shape, dim: int) -> int:
170
162
)
171
163
else :
172
164
return tuple (_C_ops .split (tensor , split_size_or_sections , dim ))
173
- elif in_pir_mode () :
165
+ else :
174
166
if isinstance (dim , paddle .pir .Value ):
175
167
dim .stop_gradient = True
176
168
if isinstance (dim , int ):
@@ -212,108 +204,3 @@ def SaveGetShapeOnDim(shape, dim: int) -> int:
212
204
split_size_or_sections
213
205
)
214
206
return tuple (_C_ops .split (tensor , split_size_or_sections , dim ))
215
-
216
- else :
217
- check_variable_and_dtype (
218
- tensor ,
219
- 'input' ,
220
- [
221
- 'bool' ,
222
- 'bfloat16' ,
223
- 'float16' ,
224
- 'uint16' ,
225
- 'float32' ,
226
- 'float64' ,
227
- 'int32' ,
228
- 'int64' ,
229
- 'uint8' ,
230
- 'int8' ,
231
- ],
232
- 'split' ,
233
- )
234
- check_type (
235
- split_size_or_sections ,
236
- 'split_size_or_sections' ,
237
- (list , int , tuple ),
238
- 'split' ,
239
- )
240
- check_type (dim , 'dim' , (int , Variable ), 'split' )
241
- if isinstance (dim , Variable ):
242
- check_dtype (dim .dtype , 'dim' , ['int32' , 'int64' ], 'split' )
243
-
244
- helper = LayerHelper ('split' , ** locals ())
245
-
246
- input_shape = tensor .shape
247
- inputs = {'X' : tensor }
248
- attrs = {'num' : 0 }
249
-
250
- def _get_SectionsTensorList (one_list ):
251
- tensor_list = []
252
- unk_dim_idx = - 1
253
- for idx , dim_size in enumerate (one_list ):
254
- if isinstance (dim_size , Variable ):
255
- dim_size .stop_gradient = True
256
- tensor_list .append (dim_size )
257
- else :
258
- assert isinstance (dim_size , int )
259
- if dim_size == - 1 :
260
- assert unk_dim_idx == - 1 , (
261
- "Only one value of 'num_or_section' in split can "
262
- f"be -1. But received num_or_section[{ idx } ] is also -1."
263
- )
264
- unk_dim_idx = idx
265
- temp_out = helper .create_variable_for_type_inference (
266
- 'int32'
267
- )
268
- fill_constant (
269
- [1 ], 'int32' , dim_size , force_cpu = True , out = temp_out
270
- )
271
- tensor_list .append (temp_out )
272
- return tuple (tensor_list )
273
-
274
- if isinstance (dim , Variable ):
275
- dim .stop_gradient = True
276
- inputs ['AxisTensor' ] = dim
277
- else :
278
- assert len (tensor .shape ) + dim >= 0 , "(rank(x) + dim) must >= 0"
279
- dim = (len (input_shape ) + dim ) if dim < 0 else dim
280
- attrs ['axis' ] = dim
281
-
282
- if isinstance (split_size_or_sections , int ):
283
- shape_on_dim = SaveGetShapeOnDim (tensor .shape , dim )
284
- split_size_or_sections = GetSplitSize (
285
- split_size_or_sections , shape_on_dim
286
- )
287
-
288
- if isinstance (split_size_or_sections , int ):
289
- # after GetSplitSize, if the result is int, split_size_or_sections is actually equivalent to the original num_or_sections (num)
290
- attrs ['num' ] = split_size_or_sections
291
- assert (
292
- split_size_or_sections > 0
293
- ), 'split_size_or_sections must be than 0.'
294
- num = split_size_or_sections
295
- else :
296
- if isinstance (dim , int ) and input_shape [dim ] > 0 :
297
- assert (
298
- len (split_size_or_sections ) <= input_shape [dim ]
299
- ), 'len(split_size_or_sections) must not be more than input.shape[dim].'
300
- num = len (split_size_or_sections )
301
- attrs ['sections' ] = [
302
- - 1 if isinstance (ele , Variable ) else ele
303
- for ele in split_size_or_sections
304
- ]
305
- if paddle .utils ._contain_var (split_size_or_sections ):
306
- inputs ['SectionsTensorList' ] = _get_SectionsTensorList (
307
- split_size_or_sections
308
- )
309
-
310
- outs = [
311
- helper .create_variable_for_type_inference (
312
- dtype = helper .input_dtype ()
313
- )
314
- for i in range (num )
315
- ]
316
- helper .append_op (
317
- type = 'split' , inputs = inputs , outputs = {'Out' : outs }, attrs = attrs
318
- )
319
- return tuple (outs )
0 commit comments