@@ -791,29 +791,29 @@ def gather(x, index, axis=None, name=None):
791
791
792
792
def unbind (input , axis = 0 ):
793
793
"""
794
- :alias_main: paddle.tensor.unbind
795
- :alias: paddle.tensor.unbind,paddle.tensor.manipulation.unbind
796
794
797
795
Removes a tensor dimension, then split the input tensor into multiple sub-Tensors.
796
+
798
797
Args:
799
- input (Variable): The input variable which is an N-D Tensor, data type being float32, float64, int32 or int64.
800
-
801
- axis (int32|int64, optional): A scalar with type ``int32|int64`` shape [1]. The dimension along which to unbind. If :math:`axis < 0`, the
802
- dimension to unbind along is :math:`rank(input) + axis`. Default is 0.
798
+ input (Tensor): The input variable which is an N-D Tensor, data type being float32, float64, int32 or int64.
799
+ axis (int32|int64, optional): A scalar with type ``int32|int64`` shape [1]. The dimension along which to unbind.
800
+ If :math:`axis < 0`, the dimension to unbind along is :math:`rank(input) + axis`. Default is 0.
803
801
Returns:
804
- list(Variable ): The list of segmented Tensor variables.
802
+ list(Tensor ): The list of segmented Tensor variables.
805
803
806
804
Example:
807
805
.. code-block:: python
806
+
808
807
import paddle
808
+ import numpy as np
809
809
# input is a variable which shape is [3, 4, 5]
810
- input = paddle.fluid.data(
811
- name=" input", shape=[3, 4, 5], dtype="float32" )
812
- [x0, x1, x2] = paddle.tensor. unbind(input, axis=0)
810
+ np_input = np.random.rand(3, 4, 5).astype('float32')
811
+ input = paddle.to_tensor(np_input )
812
+ [x0, x1, x2] = paddle.unbind(input, axis=0)
813
813
# x0.shape [4, 5]
814
814
# x1.shape [4, 5]
815
815
# x2.shape [4, 5]
816
- [x0, x1, x2, x3] = paddle.tensor. unbind(input, axis=1)
816
+ [x0, x1, x2, x3] = paddle.unbind(input, axis=1)
817
817
# x0.shape [3, 5]
818
818
# x1.shape [3, 5]
819
819
# x2.shape [3, 5]
@@ -837,6 +837,8 @@ def unbind(input, axis=0):
837
837
helper .create_variable_for_type_inference (dtype = helper .input_dtype ())
838
838
for i in range (num )
839
839
]
840
+ if in_dygraph_mode ():
841
+ return core .ops .unbind (input , num , 'axis' , axis )
840
842
841
843
helper .append_op (
842
844
type = "unbind" ,
0 commit comments