Skip to content

Commit ce97c37

Browse files
author
Feiyu Chan
authored
Cherry pick: add reshape in paddle.complex (#24210)
* add reshape in paddle.complex, test=develop * fix typos in paddle.complex.kron's comment, fix unittest, test=release-2.0
1 parent f3ffd75 commit ce97c37

File tree

3 files changed

+160
-0
lines changed

3 files changed

+160
-0
lines changed

python/paddle/complex/tensor/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# limitations under the License.
1414

1515
from . import math
16+
from . import manipulation
1617
from .math import *
18+
from .manipulation import *
1719

1820
__all__ = math.__all__ + []
21+
__all__ += manipulation.__all__
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from paddle.common_ops_import import *
16+
from ..helper import is_complex, is_real, complex_variable_exists
17+
from ...fluid.framework import ComplexVariable
18+
from ...fluid import layers
19+
20+
__all__ = ['reshape', ]
21+
22+
23+
def reshape(x, shape, inplace=False, name=None):
24+
"""
25+
To change the shape of ``x`` without changing its data.
26+
27+
There are some tricks when specifying the target shape.
28+
29+
1. -1 means the value of this dimension is inferred from the total element
30+
number of x and remaining dimensions. Thus one and only one dimension can
31+
be set -1.
32+
33+
2. 0 means the actual dimension value is going to be copied from the
34+
corresponding dimension of x. The index of 0s in shape can not exceed
35+
the dimension of x.
36+
37+
Here are some examples to explain it.
38+
39+
1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
40+
is [6, 8], the reshape operator will transform x into a 2-D tensor with
41+
shape [6, 8] and leaving x's data unchanged.
42+
43+
2. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
44+
specified is [2, 3, -1, 2], the reshape operator will transform x into a
45+
4-D tensor with shape [2, 3, 4, 2] and leaving x's data unchanged. In this
46+
case, one dimension of the target shape is set to -1, the value of this
47+
dimension is inferred from the total element number of x and remaining
48+
dimensions.
49+
50+
3. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
51+
is [-1, 0, 3, 2], the reshape operator will transform x into a 4-D tensor
52+
with shape [2, 4, 3, 2] and leaving x's data unchanged. In this case,
53+
besides -1, 0 means the actual dimension value is going to be copied from
54+
the corresponding dimension of x.
55+
56+
Args:
57+
x(ComplexVariable): the input. A ``Tensor`` or ``LoDTensor`` , data
58+
type: ``complex64`` or ``complex128``.
59+
shape(list|tuple|Variable): target shape. At most one dimension of
60+
the target shape can be -1. If ``shape`` is a list or tuple, the
61+
elements of it should be integers or Tensors with shape [1] and
62+
data type ``int32``. If ``shape`` is an Variable, it should be
63+
an 1-D Tensor of data type ``int32``.
64+
inplace(bool, optional): If ``inplace`` is True, the output of
65+
``reshape`` is the same ComplexVariable as the input. Otherwise,
66+
the input and output of ``reshape`` are different
67+
ComplexVariables. Defaults to False. Note that if ``x``is more
68+
than one OPs' input, ``inplace`` must be False.
69+
name(str, optional): The default value is None. Normally there is no
70+
need for user to set this property. For more information, please
71+
refer to :ref:`api_guide_Name` .
72+
73+
Returns:
74+
ComplexVariable: A ``Tensor`` or ``LoDTensor``. The data type is same as ``x``. It is a new ComplexVariable if ``inplace`` is ``False``, otherwise it is ``x``.
75+
76+
Raises:
77+
ValueError: If more than one elements of ``shape`` is -1.
78+
ValueError: If the element of ``shape`` is 0, the corresponding dimension should be less than or equal to the dimension of ``x``.
79+
ValueError: If the elements in ``shape`` is negative except -1.
80+
81+
Examples:
82+
.. code-block:: python
83+
84+
import paddle.fluid as fluid
85+
import paddle.complex as cpx
86+
import paddle.fluid.dygraph as dg
87+
import numpy as np
88+
89+
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
90+
91+
place = fluid.CPUPlace()
92+
with dg.guard(place):
93+
x_var = dg.to_variable(x_np)
94+
y_var = cpx.reshape(x_var, (2, -1))
95+
y_np = y_var.numpy()
96+
print(y_np.shape)
97+
# (2, 12)
98+
"""
99+
complex_variable_exists([x], "reshape")
100+
if inplace:
101+
x.real = fluid.layers.reshape(x.real, shape, inplace=inplace, name=name)
102+
x.imag = fluid.layers.reshape(x.imag, shape, inplace=inplace, name=name)
103+
return x
104+
out_real = fluid.layers.reshape(x.real, shape, inplace=inplace, name=name)
105+
out_imag = fluid.layers.reshape(x.imag, shape, inplace=inplace, name=name)
106+
return ComplexVariable(out_real, out_imag)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle.fluid as fluid
16+
import paddle.complex as cpx
17+
import paddle.fluid.dygraph as dg
18+
import numpy as np
19+
import unittest
20+
21+
22+
class TestComplexReshape(unittest.TestCase):
23+
def test_case1(self):
24+
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
25+
shape = (2, -1)
26+
27+
place = fluid.CPUPlace()
28+
with dg.guard(place):
29+
x_var = dg.to_variable(x_np)
30+
y_var = cpx.reshape(x_var, shape)
31+
y_np = y_var.numpy()
32+
33+
np.testing.assert_allclose(np.reshape(x_np, shape), y_np)
34+
35+
def test_case2(self):
36+
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
37+
shape = (0, -1)
38+
shape_ = (2, 12)
39+
40+
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
41+
) else fluid.CPUPlace()
42+
with dg.guard(place):
43+
x_var = dg.to_variable(x_np)
44+
y_var = cpx.reshape(x_var, shape, inplace=True)
45+
y_np = y_var.numpy()
46+
47+
np.testing.assert_allclose(np.reshape(x_np, shape_), y_np)
48+
49+
50+
if __name__ == "__main__":
51+
unittest.main()

0 commit comments

Comments
 (0)