
Commit 234013a

Add python wrapper for matmul_op
1 parent e7acf32 commit 234013a

File tree

  • python/paddle/v2/fluid/layers/nn.py

1 file changed: +82 -0 lines changed

python/paddle/v2/fluid/layers/nn.py

Lines changed: 82 additions & 0 deletions
@@ -1584,3 +1584,85 @@ def split(input, num_or_sections, dim=-1):
             'axis': dim
         })
     return outs
+
+
+def matmul(x, y, transpose_x=False, transpose_y=False):
+    """
+    Applies matrix multiplication to two tensors.
+
+    This operator is used to perform (batched) matrix multiplication
+    over the last two dimensions of the input tensors `X` and `Y`.
+
+    If a transpose flag is specified, the last two dimensions of the
+    tensor are transposed. If the tensor is rank-1 of shape [D], then
+    for `X` it is treated as [1, D] in nontransposed form and as [D, 1]
+    in transposed form, whereas for `Y` it is the opposite: It is treated
+    as [D, 1] in nontransposed form and as [1, D] in transposed form.
+
+    Examples without transpose:
+    - X: [K], Y: [K] => Out: [1]
+    - X: [K], Y: [K, N] => Out: [N]
+    - X: [B, M, K], Y: [K] => Out: [B, M]
+    - X: [M, K], Y: [B, K, N] => Out: [B, M, N]
+    - X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
+
+    The behavior is designed to be similar to the `numpy.matmul` function.
+    The differences are:
+    - Currently only rank 1 to rank 3 input tensors are supported.
+    - We add `transpose_X` and `transpose_Y` flags.
+
+    Both the input `X` and `Y` can carry the LoD (Level of Details) information,
+    or not. But the output only shares the LoD information with input `X`.
+
+    Args:
+        x (Variable): The input variable which is a Tensor or LoDTensor.
+        y (Variable): The input variable which is a Tensor or LoDTensor.
+        transpose_x (bool): Whether to transpose :attr:`x` before multiplication.
+        transpose_y (bool): Whether to transpose :attr:`y` before multiplication.
+
+    Returns:
+        Variable: The product Tensor variable.
+
+    Examples:
+        .. code-block:: python
+
+            # x is a Tensor variable with shape [B, M, K],
+            # y is a Tensor variable with shape [B, K, N].
+            out = fluid.layers.matmul(x, y)  # out has shape [B, M, N]
+    """
+    helper = LayerHelper('matmul', **locals())
+    out = helper.create_tmp_variable(dtype=x.dtype)
+    helper.append_op(
+        type='matmul',
+        inputs={'X': x,
+                'Y': y},
+        outputs={'Out': out},
+        attrs={'transpose_X': transpose_x,
+               'transpose_Y': transpose_y})
+    return out
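
For a quick sanity check of the new layer, here is a minimal usage sketch (not part of the commit). It assumes the `transpose_x`/`transpose_y` keyword arguments from the wrapper signature above, illustrative shapes, and the `append_batch_size=False` option of `fluid.layers.data` to keep the declared shapes unchanged:

    import paddle.v2.fluid as fluid

    # Two 2-D inputs; append_batch_size=False keeps the declared shapes as-is.
    x = fluid.layers.data(
        name='x', shape=[3, 4], dtype='float32', append_batch_size=False)
    y = fluid.layers.data(
        name='y', shape=[4, 5], dtype='float32', append_batch_size=False)

    # [3, 4] x [4, 5] -> expected output shape [3, 5]
    out = fluid.layers.matmul(x, y)

    # transpose_y transposes the last two dimensions of the second operand,
    # so a [5, 4] input multiplies as if it were [4, 5].
    y_t = fluid.layers.data(
        name='y_t', shape=[5, 4], dtype='float32', append_batch_size=False)
    out_t = fluid.layers.matmul(x, y_t, transpose_y=True)  # expected [3, 5]

This only builds the program; running it still requires feeding the inputs through an `Executor`, as with any other fluid layer.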
