
Commit 4df95ed

Merge pull request #7602 from guoshengCS/add-dot_product_attention
Add Python wrapper for dot-product-attention.
2 parents: 939e1b1 + db959d6 · commit 4df95ed

File tree

5 files changed: +140 −4 lines changed

doc/api/v2/fluid/layers.rst

Lines changed: 6 additions & 0 deletions
@@ -364,6 +364,12 @@ split
 .. autofunction:: paddle.v2.fluid.layers.split
     :noindex:

+
+matmul
+------
+.. autofunction:: paddle.v2.fluid.layers.matmul
+    :noindex:
+
 logsigmoid
 ----------
 .. autofunction:: paddle.v2.fluid.layers.logsigmoid

doc/api/v2/fluid/nets.rst

Lines changed: 6 additions & 0 deletions
@@ -25,3 +25,9 @@ glu
 .. autofunction:: paddle.v2.fluid.nets.glu
     :noindex:

+
+dot_product_attention
+---------------------
+.. autofunction:: paddle.v2.fluid.nets.dot_product_attention
+    :noindex:
+

python/paddle/v2/fluid/layers/nn.py

Lines changed: 71 additions & 0 deletions
@@ -50,6 +50,7 @@
     'sequence_last_step',
     'dropout',
     'split',
+    'matmul',
 ]


@@ -1597,3 +1598,73 @@ def split(input, num_or_sections, dim=-1):
             'axis': dim
         })
     return outs
+
+
+def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
+    """
+    Applies matrix multiplication to two tensors. Currently only rank 1 to rank
+    3 input tensors are supported.
+
+    The actual behavior depends on the shapes of :math:`x`, :math:`y` and the
+    flag values of :attr:`transpose_x`, :attr:`transpose_y`. Specifically:
+
+    - If a transpose flag is specified, the last two dimensions of the tensor
+      are transposed. If the tensor is rank-1 of shape :math:`[D]`, then for
+      :math:`x` it is treated as :math:`[1, D]` in nontransposed form and as
+      :math:`[D, 1]` in transposed form, whereas for :math:`y` it is the
+      opposite: it is treated as :math:`[D, 1]` in nontransposed form and as
+      :math:`[1, D]` in transposed form.
+
+    - After the transpose, the two tensors are 2-D or 3-D and matrix
+      multiplication is performed in the following way.
+
+      - If both are 2-D, they are multiplied like conventional matrices.
+      - If either is 3-D, it is treated as a stack of matrices residing in the
+        last two dimensions, and a batched matrix multiply supporting broadcast
+        is applied to the two tensors.
+
+    Also note that if the raw tensor :math:`x` or :math:`y` is rank-1 and
+    nontransposed, the prepended or appended dimension :math:`1` will be
+    removed after matrix multiplication.
+
+    Args:
+        x (Variable): The input variable which is a Tensor or LoDTensor.
+        y (Variable): The input variable which is a Tensor or LoDTensor.
+        transpose_x (bool): Whether to transpose :math:`x` before multiplication.
+        transpose_y (bool): Whether to transpose :math:`y` before multiplication.
+        name (str|None): A name for this layer (optional). If set to None, the
+            layer will be named automatically.
+
+    Returns:
+        Variable: The product Tensor variable.
+
+    Examples:
+        .. code-block:: python
+
+            # Examples to clarify shapes of the inputs and output
+            # x: [B, M, K], y: [B, K, N]
+            fluid.layers.matmul(x, y)  # out: [B, M, N]
+            # x: [B, M, K], y: [K, N]
+            fluid.layers.matmul(x, y)  # out: [B, M, N]
+            # x: [B, M, K], y: [K]
+            fluid.layers.matmul(x, y)  # out: [B, M]
+            # x: [M, K], y: [K, N]
+            fluid.layers.matmul(x, y)  # out: [M, N]
+            # x: [K], y: [K]
+            fluid.layers.matmul(x, y)  # out: [1]
+            # x: [M], y: [N]
+            fluid.layers.matmul(x, y, True, True)  # out: [M, N]
+    """
+    helper = LayerHelper('matmul', **locals())
+    assert max(
+        len(x.shape), len(y.shape)
+    ) <= 3, 'Currently only rank 1 to rank 3 input tensors are supported.'
+    out = helper.create_tmp_variable(dtype=helper.input_dtype())
+    helper.append_op(
+        type='matmul',
+        inputs={'X': x,
+                'Y': y},
+        outputs={'Out': out},
+        attrs={'transpose_X': transpose_x,
+               'transpose_Y': transpose_y})
+    return out
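
For reference, the shape rules documented in the docstring above mirror NumPy's matmul broadcasting semantics. The following standalone NumPy sketch (illustrative only, not part of this commit) reproduces the documented input/output shapes; note that for two rank-1 inputs plain np.matmul yields a 0-d scalar, whereas the docstring above reports an output of shape [1].

import numpy as np

B, M, K, N = 2, 3, 4, 5

# x: [B, M, K], y: [B, K, N]  ->  out: [B, M, N]
out = np.matmul(np.ones((B, M, K)), np.ones((B, K, N)))
assert out.shape == (B, M, N)

# x: [B, M, K], y: [K, N]  ->  out: [B, M, N]  (y is broadcast over the batch)
out = np.matmul(np.ones((B, M, K)), np.ones((K, N)))
assert out.shape == (B, M, N)

# x: [B, M, K], y: [K]  ->  out: [B, M]  (the appended size-1 dim is removed)
out = np.matmul(np.ones((B, M, K)), np.ones(K))
assert out.shape == (B, M)

# x: [M, K], y: [K, N]  ->  out: [M, N]  (plain 2-D matrix multiplication)
out = np.matmul(np.ones((M, K)), np.ones((K, N)))
assert out.shape == (M, N)

# x: [M], y: [N] with both transpose flags set  ->  out: [M, N] (outer product);
# the transposes are emulated here by reshaping the vectors explicitly.
out = np.matmul(np.ones(M).reshape(M, 1), np.ones(N).reshape(1, N))
assert out.shape == (M, N)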

python/paddle/v2/fluid/nets.py

Lines changed: 53 additions & 0 deletions
@@ -17,6 +17,7 @@
     "simple_img_conv_pool",
     "sequence_conv_pool",
     "glu",
+    "dot_product_attention",
 ]


@@ -150,3 +151,55 @@ def glu(input, dim=-1):
     act_b = layers.sigmoid(x=b)
     out = layers.elementwise_mul(x=a, y=act_b)
     return out
+
+
+def dot_product_attention(querys, keys, values):
+    """
+    The dot-product attention.
+
+    The attention mechanism can be seen as mapping a query and a set of
+    key-value pairs to an output. The output is computed as a weighted sum of
+    the values, where the weight assigned to each value is computed by a
+    compatibility function (the dot-product here) of the query with the
+    corresponding key.
+
+    The dot-product attention can be implemented through (batch) matrix
+    multiplication as follows:
+
+    .. math::
+
+        Attention(Q, K, V) = softmax(QK^\mathrm{T})V
+
+    Refer to `Attention Is All You Need
+    <https://arxiv.org/pdf/1706.03762.pdf>`_.
+
+    Note that batch data containing sequences with different lengths is not
+    supported by this function because of the (batch) matrix multiplication.
+
+    Args:
+        querys (Variable): The input variable which is a Tensor or LoDTensor.
+        keys (Variable): The input variable which is a Tensor or LoDTensor.
+        values (Variable): The input variable which is a Tensor or LoDTensor.
+
+    Returns:
+        tuple: The Tensor variables representing the output and attention scores.
+
+    Examples:
+        .. code-block:: python
+
+            # Suppose q, k, v are tensor variables with the following shape:
+            # q: [3, 5, 9], k: [3, 6, 9], v: [3, 6, 10]
+            out, attn_scores = fluid.nets.dot_product_attention(q, k, v)
+            out.shape  # [3, 5, 10]
+            attn_scores.shape  # [3, 5, 6]
+    """
+    assert keys.shape[-2] == values.shape[
+        -2], 'The shapes of keys and values mismatch.'
+    assert querys.shape[-1] == keys.shape[
+        -1], 'The shapes of querys and keys mismatch.'
+    product = layers.matmul(x=querys, y=keys, transpose_y=True)
+    attn_scores = layers.reshape(
+        x=layers.reshape(
+            x=product, shape=[-1, product.shape[-1]], act='softmax'),
+        shape=product.shape)
+    out = layers.matmul(attn_scores, values)
+    return out, attn_scores
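
For reference, the formula softmax(QK^T)V and the shape example in the docstring above can be reproduced with a standalone NumPy sketch (illustrative only, not part of this commit); numpy_dot_product_attention is a hypothetical helper name used here, not part of the fluid API.

import numpy as np

def numpy_dot_product_attention(q, k, v):
    # product[b, i, j] is the dot-product of query i and key j in batch b.
    product = np.matmul(q, np.swapaxes(k, -1, -2))      # [B, Lq, Lk]
    # Row-wise softmax over the key dimension turns the products into weights.
    e = np.exp(product - product.max(axis=-1, keepdims=True))
    attn_scores = e / e.sum(axis=-1, keepdims=True)     # [B, Lq, Lk]
    # Each output row is the weighted sum of the value vectors.
    out = np.matmul(attn_scores, v)                     # [B, Lq, Dv]
    return out, attn_scores

# Shapes from the docstring example above.
q = np.random.rand(3, 5, 9)
k = np.random.rand(3, 6, 9)
v = np.random.rand(3, 6, 10)
out, attn_scores = numpy_dot_product_attention(q, k, v)
assert out.shape == (3, 5, 10)
assert attn_scores.shape == (3, 5, 6)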

python/paddle/v2/fluid/tests/test_matmul_op.py

Lines changed: 4 additions & 4 deletions
@@ -96,18 +96,18 @@ def setUp(self):
         self.outputs = {'Out': Out}

     def test_check_output(self):
-        self.check_output(atol=1e-2)
+        self.check_output(atol=1e-3)

     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.5)
+        self.check_grad(['X', 'Y'], 'Out', max_relative_error=1e-3)

     def test_check_grad_ignore_x(self):
         self.check_grad(
-            ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X"))
+            ['Y'], 'Out', max_relative_error=1e-3, no_grad_set=set("X"))

     def test_check_grad_ignore_y(self):
         self.check_grad(
-            ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))
+            ['X'], 'Out', max_relative_error=1e-3, no_grad_set=set('Y'))


 # Generate test cases for all possibilities
