
Commit 9bcb2d2

Add python wrapper for matmul_op and dot_product_attention
1 parent 234013a commit 9bcb2d2

File tree

5 files changed: +121 additions, -67 deletions


doc/api/v2/fluid/layers.rst

Lines changed: 6 additions & 0 deletions
@@ -364,6 +364,12 @@ split
 .. autofunction:: paddle.v2.fluid.layers.split
     :noindex:
 
+
+matmul
+------
+.. autofunction:: paddle.v2.fluid.layers.matmul
+    :noindex:
+
 logsigmoid
 ----------
 .. autofunction:: paddle.v2.fluid.layers.logsigmoid

doc/api/v2/fluid/nets.rst

Lines changed: 6 additions & 0 deletions
@@ -25,3 +25,9 @@ glu
 .. autofunction:: paddle.v2.fluid.nets.glu
     :noindex:
 
+
+dot_product_attention
+---------------------
+.. autofunction:: paddle.v2.fluid.nets.dot_product_attention
+    :noindex:
+

python/paddle/v2/fluid/layers/nn.py

Lines changed: 52 additions & 63 deletions
@@ -37,6 +37,7 @@
     'sequence_last_step',
     'dropout',
     'split',
+    'matmul',
 ]
 
 
@@ -1586,83 +1587,71 @@ def split(input, num_or_sections, dim=-1):
     return outs
 
 
-def matmul(x, y):
+def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
     """
-    Applies matrix multipication to two tensors.
+    Applies matrix multiplication to two tensors. Currently only rank 1 to rank
+    3 input tensors are supported.
 
-    This operator is used to perform (batched) matrix multiplication
-    over the last two dimensions of the input tensors `X` and `Y`.
+    The actual behavior depends on the shapes of :math:`x`, :math:`y` and the
+    flag values of :attr:`transpose_x`, :attr:`transpose_y`. Specifically:
 
-    If a transpose flag is specified, the last two dimensions of the
-    tensor are transposed. If the tensor is rank-1 of shape [D], then
-    for `X` it is treated as [1, D] in nontransposed form and as [D, 1]
-    in transposed form, whereas for `Y` it is the opposite: It is treated
-    as [D, 1] in nontransposed form and as [1, D] in transposed form.
+    - If a transpose flag is specified, the last two dimensions of the tensor
+      are transposed. If the tensor is rank-1 of shape :math:`[D]`, then for
+      :math:`x` it is treated as :math:`[1, D]` in nontransposed form and as
+      :math:`[D, 1]` in transposed form, whereas for :math:`y` it is the
+      opposite: it is treated as :math:`[D, 1]` in nontransposed form and as
+      :math:`[1, D]` in transposed form.
 
-    Examples without transpose:
-    - X: [K], Y: [K] => Out: [1]
-    - X: [K], Y: [K, N] => Out: [N]
-    - X: [B, M, K], Y: [K] => Out: [B, M]
-    - X: [M, K], Y: [B, K, N] => Out: [B, M, N]
-    - X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
+    - After the transposes, the two tensors are 2-D or 3-D, and the matrix
+      multiplication is performed in the following way.
 
-    The behavior is designed to be similar to the `numpy.matmul` function.
-    The differences are:
-    - Currently only rank 1 to rank 3 input tensors are supported.
-    - We add `transpose_X` and `transpose_Y` flags.
+      - If both are 2-D, they are multiplied like conventional matrices.
+      - If either is 3-D, it is treated as a stack of matrices residing in the
+        last two dimensions, and a batched matrix multiply supporting broadcast
+        applies to the two tensors.
 
-    Both the input `X` and `Y` can carry the LoD (Level of Details) information,
-    or not. But the output only shares the LoD information with input `X`.
+    Also note that if the raw tensor :math:`x` or :math:`y` is rank-1 and
+    nontransposed, the prepended or appended dimension :math:`1` will be
+    removed after the matrix multiplication.
 
     Args:
         x (Variable): The input variable which is a Tensor or LoDTensor.
-        y (Variable): If :attr:`num_or_sections` is an integer,
-            then the integer indicates the number of equal sized sub-tensors
-            that the tensor will be divided into. If :attr:`num_or_sections`
-            is a list of integers, the length of list indicates the number of
-            sub-tensors and the integers indicate the sizes of sub-tensors'
-            :attr:`dim` dimension orderly.
-        dim (int): The dimension along which to split. If :math:`dim < 0`, the
-            dimension to split along is :math:`rank(input) + dim`.
+        y (Variable): The input variable which is a Tensor or LoDTensor.
+        transpose_x (bool): Whether to transpose :math:`x` before multiplication.
+        transpose_y (bool): Whether to transpose :math:`y` before multiplication.
+        name (str|None): A name for this layer (optional). If set to None, the
+            layer will be named automatically.
 
     Returns:
-        List: The list of segmented tensor variables.
+        Variable: The product Tensor variable.
 
     Examples:
         .. code-block:: python
 
-            # x is a Tensor variable with shape [3, 9, 5]:
-            x0, x1, x2 = fluid.layers.split(x, num_or_sections=3, dim=1)
-            x0.shape  # [3, 3, 5]
-            x1.shape  # [3, 3, 5]
-            x2.shape  # [3, 3, 5]
-            x0, x1, x2 = fluid.layers.split(x, num_or_sections=[2, 3, 4], dim=1)
-            x0.shape  # [3, 2, 5]
-            x1.shape  # [3, 3, 5]
-            x2.shape  # [3, 4, 5]
+            # Examples to clarify shapes of the inputs and output
+            # x: [B, M, K], y: [B, K, N]
+            fluid.layers.matmul(x, y)  # out: [B, M, N]
+            # x: [B, M, K], y: [K, N]
+            fluid.layers.matmul(x, y)  # out: [B, M, N]
+            # x: [B, M, K], y: [K]
+            fluid.layers.matmul(x, y)  # out: [B, M]
+            # x: [M, K], y: [K, N]
+            fluid.layers.matmul(x, y)  # out: [M, N]
+            # x: [K], y: [K]
+            fluid.layers.matmul(x, y)  # out: [1]
+            # x: [M], y: [N]
+            fluid.layers.matmul(x, y, True, True)  # out: [M, N]
     """
-    helper = LayerHelper('split', **locals())
-    input_shape = input.shape
-    dim = (len(input_shape) + dim) if dim < 0 else dim
-    if isinstance(num_or_sections, int):
-        assert num_or_sections > 1, 'num_or_sections must be more than 1.'
-        num = num_or_sections
-    else:
-        assert len(num_or_sections) < input_shape[
-            dim], 'len(num_or_sections) must not be more than input.shape[dim].'
-        num = len(num_or_sections)
-    outs = [
-        helper.create_tmp_variable(dtype=helper.input_dtype())
-        for i in range(num)
-    ]
+    helper = LayerHelper('matmul', **locals())
+    assert max(
+        len(x.shape), len(y.shape)
+    ) <= 3, 'Currently only rank 1 to rank 3 input tensors are supported.'
+    out = helper.create_tmp_variable(dtype=helper.input_dtype())
     helper.append_op(
-        type='split',
-        inputs={'X': input},
-        outputs={'Out': outs},
-        attrs={
-            'num': num_or_sections if isinstance(num_or_sections, int) else 0,
-            'sections': num_or_sections
-            if isinstance(num_or_sections, list) else [],
-            'axis': dim
-        })
-    return outs
+        type='matmul',
+        inputs={'X': x,
+                'Y': y},
+        outputs={'Out': out},
+        attrs={'transpose_X': transpose_x,
+               'transpose_Y': transpose_y})
+    return out
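
The shape rules in the new docstring mirror numpy.matmul for the 2-D and broadcasted batch cases; below is a minimal NumPy sketch (not part of the commit, with made-up dimension sizes) that sanity-checks those shapes.

import numpy as np

B, M, K, N = 2, 3, 4, 5

# x: [B, M, K], y: [B, K, N]  ->  out: [B, M, N]
assert np.matmul(np.ones((B, M, K)), np.ones((B, K, N))).shape == (B, M, N)
# x: [B, M, K], y: [K, N]  ->  out: [B, M, N]  (y is broadcast over the batch)
assert np.matmul(np.ones((B, M, K)), np.ones((K, N))).shape == (B, M, N)
# x: [B, M, K], y: [K]  ->  out: [B, M]  (the appended dimension of y is removed)
assert np.matmul(np.ones((B, M, K)), np.ones((K,))).shape == (B, M)
# x: [M, K], y: [K, N]  ->  out: [M, N]
assert np.matmul(np.ones((M, K)), np.ones((K, N))).shape == (M, N)
# x: [K], y: [K] is documented as out: [1]; plain numpy.matmul returns a 0-d
# scalar here, which is one place the op's convention differs from NumPy.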

python/paddle/v2/fluid/nets.py

Lines changed: 53 additions & 0 deletions
@@ -4,6 +4,7 @@
     "simple_img_conv_pool",
     "sequence_conv_pool",
     "glu",
+    "dot_product_attention",
 ]
 
 
@@ -135,3 +136,55 @@ def glu(input, dim=-1):
     a, b = layers.split(input, num_or_sections=2, dim=dim)
     out = layers.elementwise_mul(x=a, y=b)
     return out
+
+
+def dot_product_attention(querys, keys, values):
+    """
+    The dot-product attention.
+
+    The attention mechanism can be seen as mapping a query and a set of key-value
+    pairs to an output. The output is computed as a weighted sum of the values,
+    where the weight assigned to each value is computed by a compatibility
+    function (dot-product here) of the query with the corresponding key.
+
+    The dot-product attention can be implemented through (batch) matrix
+    multiplication as follows:
+
+    .. math::
+
+        Attention(Q, K, V) = softmax(QK^\mathrm{T})V
+
+    Refer to `Attention Is All You Need
+    <https://arxiv.org/pdf/1706.03762.pdf>`_.
+
+    Note that batch data containing sequences of different lengths is not
+    supported by this layer because of the (batch) matrix multiplication.
+
+    Args:
+        querys (Variable): The input variable which is a Tensor or LoDTensor.
+        keys (Variable): The input variable which is a Tensor or LoDTensor.
+        values (Variable): The input variable which is a Tensor or LoDTensor.
+
+    Returns:
+        tuple: The Tensor variables representing the output and attention scores.
+
+    Examples:
+        .. code-block:: python
+
+            # Suppose q, k, v are tensor variables with the following shape:
+            # q: [3, 5, 9], k: [3, 6, 9], v: [3, 6, 10]
+            out, attn_scores = fluid.nets.dot_product_attention(q, k, v)
+            out.shape  # [3, 5, 10]
+            attn_scores.shape  # [3, 5, 6]
+    """
+    assert keys.shape[-2] == values.shape[
+        -2], 'The shapes of keys and values mismatch.'
+    assert querys.shape[-1] == keys.shape[
+        -1], 'The shapes of querys and keys mismatch.'
+    product = layers.matmul(x=querys, y=keys, transpose_y=True)
+    attn_scores = layers.reshape(
+        x=layers.reshape(
+            x=product, shape=[-1, product.shape[-1]], act='softmax'),
+        shape=product.shape)
+    out = layers.matmul(attn_scores, values)
+    return out, attn_scores
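
To make the shape flow above concrete, here is a hedged NumPy reference of the same unscaled attention, Attention(Q, K, V) = softmax(QK^T)V, using the docstring's example shapes; the helper name np_dot_product_attention is illustrative and not part of the commit.

import numpy as np

def np_dot_product_attention(q, k, v):
    # Batched Q K^T over the last two dimensions: [batch, q_len, k_len]
    scores = np.matmul(q, np.transpose(k, (0, 2, 1)))
    # Numerically stable row-wise softmax over the key dimension.
    scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
    scores /= scores.sum(axis=-1, keepdims=True)
    # Weighted sum of the values: [batch, q_len, v_dim]
    return np.matmul(scores, v), scores

q = np.random.rand(3, 5, 9)
k = np.random.rand(3, 6, 9)
v = np.random.rand(3, 6, 10)
out, attn_scores = np_dot_product_attention(q, k, v)
assert out.shape == (3, 5, 10) and attn_scores.shape == (3, 5, 6)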

python/paddle/v2/fluid/tests/test_matmul_op.py

Lines changed: 4 additions & 4 deletions
@@ -83,18 +83,18 @@ def setUp(self):
         self.outputs = {'Out': Out}
 
     def test_check_output(self):
-        self.check_output(atol=1e-2)
+        self.check_output(atol=1e-3)
 
     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.5)
+        self.check_grad(['X', 'Y'], 'Out', max_relative_error=1e-3)
 
     def test_check_grad_ignore_x(self):
         self.check_grad(
-            ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X"))
+            ['Y'], 'Out', max_relative_error=1e-3, no_grad_set=set("X"))
 
     def test_check_grad_ignore_y(self):
         self.check_grad(
-            ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))
+            ['X'], 'Out', max_relative_error=1e-3, no_grad_set=set('Y'))
 
 
 # Generate test cases for all possibilities
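
For context on what these tightened checks compare against, here is a hedged NumPy sketch of a transpose-aware reference matmul. It is an assumption about the shape handling, not the repository's actual test helper; the rank-1 promotion follows the matmul docstring, and squeezing of the prepended/appended 1 in the output is omitted.

import numpy as np

def reference_matmul_sketch(x, y, transpose_x=False, transpose_y=False):
    # Rank-1 x is treated as [1, D] (or [D, 1] when transposed).
    if x.ndim == 1:
        x = x.reshape((x.size, 1)) if transpose_x else x.reshape((1, x.size))
    elif transpose_x:
        x = np.swapaxes(x, -1, -2)
    # Rank-1 y is treated as [D, 1] (or [1, D] when transposed).
    if y.ndim == 1:
        y = y.reshape((1, y.size)) if transpose_y else y.reshape((y.size, 1))
    elif transpose_y:
        y = np.swapaxes(y, -1, -2)
    return np.matmul(x, y)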

0 commit comments
