Skip to content

Commit 5c3d16f

Browse files
author
Yibing Liu
authored
Add complex layer for matmul & transpose, test=release/2.0 (#24218)
1 parent 05335a2 commit 5c3d16f

File tree

7 files changed

+226
-6
lines changed

7 files changed

+226
-6
lines changed

python/paddle/complex/tensor/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@
1414

1515
from . import math
1616
from . import manipulation
17+
from . import linalg
1718
from .math import *
1819
from .manipulation import *
20+
from .linalg import *
1921

20-
__all__ = math.__all__ + []
22+
__all__ = math.__all__
2123
__all__ += manipulation.__all__
24+
__all__ += linalg.__all__
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from ..helper import is_complex, is_real, complex_variable_exists
16+
from ...fluid.framework import ComplexVariable
17+
from ...fluid import layers
18+
19+
__all__ = ['matmul', ]
20+
21+
22+
def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
23+
"""
24+
Applies matrix multiplication to two complex number tensors. See the
25+
detailed description in :ref:`api_fluid_layers_matmul`.
26+
27+
Args:
28+
x (ComplexVariable|Variable): The first input, can be a ComplexVariable
29+
with data type complex32 or complex64, or a Variable with data type
30+
float32 or float64.
31+
y (ComplexVariable|Variable): The second input, can be a ComplexVariable
32+
with data type complex32 or complex64, or a Variable with data type
33+
float32 or float64.
34+
transpose_x (bool): Whether to transpose :math:`x` before multiplication.
35+
transpose_y (bool): Whether to transpose :math:`y` before multiplication.
36+
alpha (float): The scale of output. Default 1.0.
37+
name(str|None): A name for this layer(optional). If set None, the layer
38+
will be named automatically.
39+
40+
Returns:
41+
ComplexVariable: The product result, with the same data type as inputs.
42+
43+
Examples:
44+
.. code-block:: python
45+
46+
import numpy as np
47+
import paddle
48+
import paddle.fluid.dygraph as dg
49+
with dg.guard():
50+
x = np.array([[1.0 + 1j, 2.0 + 1j], [3.0+1j, 4.0+1j]])
51+
y = np.array([1.0 + 1j, 1.0 + 1j])
52+
x_var = dg.to_variable(x)
53+
y_var = dg.to_variable(y)
54+
result = paddle.complex.matmul(x_var, y_var)
55+
print(result.numpy())
56+
# [1.+5.j 5.+9.j]
57+
"""
58+
# x = a + bi, y = c + di
59+
# mm(x, y) = mm(a, c) - mm(b, d) + (mm(a, d) + mm(b, c))i
60+
complex_variable_exists([x, y], "matmul")
61+
a, b = (x.real, x.imag) if is_complex(x) else (x, None)
62+
c, d = (y.real, y.imag) if is_complex(y) else (y, None)
63+
ac = layers.matmul(a, c, transpose_x, transpose_y, alpha, name)
64+
if is_real(b) and is_real(d):
65+
bd = layers.matmul(b, d, transpose_x, transpose_y, alpha, name)
66+
real = ac - bd
67+
imag = layers.matmul(a, d, transpose_x, transpose_y, alpha, name) + \
68+
layers.matmul(b, c, transpose_x, transpose_y, alpha, name)
69+
elif is_real(b):
70+
real = ac
71+
imag = layers.matmul(b, c, transpose_x, transpose_y, alpha, name)
72+
else:
73+
real = ac
74+
imag = layers.matmul(a, d, transpose_x, transpose_y, alpha, name)
75+
return ComplexVariable(real, imag)

python/paddle/complex/tensor/manipulation.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
from ...fluid.framework import ComplexVariable
1818
from ...fluid import layers
1919

20-
__all__ = ['reshape', ]
20+
__all__ = [
21+
'reshape',
22+
'transpose',
23+
]
2124

2225

2326
def reshape(x, shape, inplace=False, name=None):
@@ -104,3 +107,39 @@ def reshape(x, shape, inplace=False, name=None):
104107
out_real = fluid.layers.reshape(x.real, shape, inplace=inplace, name=name)
105108
out_imag = fluid.layers.reshape(x.imag, shape, inplace=inplace, name=name)
106109
return ComplexVariable(out_real, out_imag)
110+
111+
112+
def transpose(x, perm, name=None):
113+
"""
114+
Permute the data dimensions for complex number :attr:`input` according to `perm`.
115+
116+
See :ref:`api_fluid_layers_transpose` for the real number API.
117+
118+
Args:
119+
x (ComplexVariable): The input n-D ComplexVariable with data type
120+
complex64 or complex128.
121+
perm (list): Permute the input according to the value of perm.
122+
name (str): The name of this layer. It is optional.
123+
124+
Returns:
125+
ComplexVariable: A transposed n-D ComplexVariable, with the same data type as :attr:`input`.
126+
127+
Examples:
128+
.. code-block:: python
129+
130+
import paddle
131+
import numpy as np
132+
import paddle.fluid.dygraph as dg
133+
134+
with dg.guard():
135+
a = np.array([[1.0 + 1.0j, 2.0 + 1.0j], [3.0+1.0j, 4.0+1.0j]])
136+
x = dg.to_variable(a)
137+
y = paddle.complex.transpose(x, [1, 0])
138+
print(y.numpy())
139+
# [[1.+1.j 3.+1.j]
140+
# [2.+1.j 4.+1.j]]
141+
"""
142+
complex_variable_exists([x], "transpose")
143+
real = layers.transpose(x.real, perm, name)
144+
imag = layers.transpose(x.imag, perm, name)
145+
return ComplexVariable(real, imag)

python/paddle/fluid/framework.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1644,10 +1644,9 @@ class ComplexVariable(object):
16441644
holding the real part and imaginary part of complex numbers respectively.
16451645
16461646
**Notes**:
1647-
**The constructor of Variable should not be invoked directly.**
1647+
**The constructor of ComplexVariable should not be invoked directly.**
16481648
1649-
**Only support dygraph mode at present. Please use** :ref:`api_fluid_dygraph_to_variable` **
1650-
to create a dygraph ComplexVariable with complex number data.**
1649+
**Only support dygraph mode at present. Please use** :ref:`api_fluid_dygraph_to_variable` **to create a dygraph ComplexVariable with complex number data.**
16511650
16521651
Args:
16531652
real (Variable): The Variable holding real-part data.
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import paddle
17+
import numpy as np
18+
import paddle.fluid as fluid
19+
import paddle.fluid.dygraph as dg
20+
21+
22+
class TestComplexMatMulLayer(unittest.TestCase):
23+
def setUp(self):
24+
self._places = [fluid.CPUPlace()]
25+
if fluid.core.is_compiled_with_cuda():
26+
self._places.append(fluid.CUDAPlace(0))
27+
28+
def compare(self, x, y):
29+
for place in self._places:
30+
with dg.guard(place):
31+
x_var = dg.to_variable(x)
32+
y_var = dg.to_variable(y)
33+
result = paddle.complex.matmul(x_var, y_var)
34+
np_result = np.matmul(x, y)
35+
self.assertTrue(np.allclose(result.numpy(), np_result))
36+
37+
def test_complex_xy(self):
38+
x = np.random.random(
39+
(2, 3, 4, 5)).astype("float32") + 1J * np.random.random(
40+
(2, 3, 4, 5)).astype("float32")
41+
y = np.random.random(
42+
(2, 3, 5, 4)).astype("float32") + 1J * np.random.random(
43+
(2, 3, 5, 4)).astype("float32")
44+
self.compare(x, y)
45+
46+
def test_complex_x(self):
47+
x = np.random.random(
48+
(2, 3, 4, 5)).astype("float32") + 1J * np.random.random(
49+
(2, 3, 4, 5)).astype("float32")
50+
y = np.random.random((2, 3, 5, 4)).astype("float32")
51+
self.compare(x, y)
52+
53+
def test_complex_y(self):
54+
x = np.random.random((2, 3, 4, 5)).astype("float32")
55+
y = np.random.random(
56+
(2, 3, 5, 4)).astype("float32") + 1J * np.random.random(
57+
(2, 3, 5, 4)).astype("float32")
58+
self.compare(x, y)
59+
60+
61+
if __name__ == '__main__':
62+
unittest.main()
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import paddle
17+
import numpy as np
18+
import paddle.fluid as fluid
19+
import paddle.fluid.dygraph as dg
20+
21+
22+
class TestComplexTransposeLayer(unittest.TestCase):
23+
def setUp(self):
24+
self._places = [fluid.CPUPlace()]
25+
if fluid.core.is_compiled_with_cuda():
26+
self._places.append(fluid.CUDAPlace(0))
27+
28+
def test_identity(self):
29+
data = np.random.random(
30+
(2, 3, 4, 5)).astype("float32") + 1J * np.random.random(
31+
(2, 3, 4, 5)).astype("float32")
32+
perm = [3, 2, 0, 1]
33+
np_trans = np.transpose(data, perm)
34+
for place in self._places:
35+
with dg.guard(place):
36+
var = dg.to_variable(data)
37+
trans = paddle.complex.transpose(var, perm=perm)
38+
self.assertTrue(np.allclose(trans.numpy(), np_trans))
39+
40+
41+
if __name__ == '__main__':
42+
unittest.main()

python/paddle/fluid/tests/unittests/test_complex_variable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

0 commit comments

Comments
 (0)