Commit 56ec40a

Merge pull request #4924 from ranqiu92/attention
Add the configuration helper for multi-head attention.
2 parents: 01d6ccb + f224029

python/paddle/trainer_config_helpers/networks.py

Lines changed: 132 additions & 4 deletions
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import math
 
 from activations import LinearActivation, ReluActivation, SoftmaxActivation, \
     IdentityActivation, TanhActivation, SequenceSoftmaxActivation
@@ -26,9 +26,9 @@
     'sequence_conv_pool', 'simple_lstm', "simple_img_conv_pool",
     "img_conv_bn_pool", 'lstmemory_group', 'lstmemory_unit', 'small_vgg',
     'img_conv_group', 'vgg_16_network', 'gru_unit', 'gru_group', 'simple_gru',
-    'simple_attention', 'dot_product_attention', 'simple_gru2',
-    'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm', 'inputs',
-    'outputs'
+    'simple_attention', 'dot_product_attention', 'multi_head_attention',
+    'simple_gru2', 'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm',
+    'inputs', 'outputs'
 ]
 
 ######################################################
@@ -1496,6 +1496,134 @@ def dot_product_attention(encoded_sequence,
         input=scaled, pooling_type=SumPooling(), name="%s_pooling" % name)
 
 
+@wrap_name_default()
+def multi_head_attention(query,
+                         key,
+                         value,
+                         key_proj_size,
+                         value_proj_size,
+                         head_num,
+                         attention_type,
+                         softmax_param_attr=None,
+                         name=None):
+    """
+    Calculate and return a context vector with the multi-head attention mechanism.
+    The dimension of the context vector equals value_proj_size * head_num.
+
+    Please refer to **Attention Is All You Need** for more details. The link is
+    as follows:
+    https://arxiv.org/abs/1706.03762.
+
+    The example usage is:
+
+    .. code-block:: python
+
+        context = multi_head_attention(query=decoder_state,
+                                       key=enc_seq,
+                                       value=enc_seq,
+                                       key_proj_size=64,
+                                       value_proj_size=64,
+                                       head_num=8,
+                                       attention_type='dot-product attention')
+
+    :param name: A prefix attached to the name of each layer defined inside
+                 multi_head_attention.
+    :type name: basestring
+    :param softmax_param_attr: The parameter attribute of the sequence softmax
+                               that is used to produce the attention weights.
+    :type softmax_param_attr: ParameterAttribute
+    :param query: The query is used to calculate attention weights over the values at the current step.
+    :type query: LayerOutput
+    :param key: The key is used to calculate the attention weight of the corresponding value.
+    :type key: LayerOutput
+    :param value: The value is the sequence to be attended over.
+    :type value: LayerOutput
+    :param key_proj_size: The dimension of the linear projection performed on the key and the query.
+    :type key_proj_size: int
+    :param value_proj_size: The dimension of the linear projection performed on the value.
+    :type value_proj_size: int
+    :param head_num: The number of attention heads.
+    :type head_num: int
+    :param attention_type: The type of attention mechanism used in each attention
+                           head. Currently, only scaled dot-product attention and
+                           additive attention are supported.
+    :type attention_type: basestring
+    :return: The context vector.
+    :rtype: LayerOutput
+    """
+    assert attention_type in ['dot-product attention', 'additive attention']
+
+    with mixed_layer(
+            size=key_proj_size * head_num,
+            name='%s_query_proj' % name) as query_proj:
+        query_proj += full_matrix_projection(query)
+    query_proj = expand_layer(input=query_proj, expand_as=key)
+
+    with mixed_layer(
+            size=key_proj_size * head_num,
+            name='%s_key_proj' % name) as key_proj:
+        key_proj += full_matrix_projection(key)
+
+    with mixed_layer(
+            size=value_proj_size * head_num,
+            name='%s_value_proj' % name) as value_proj:
+        value_proj += full_matrix_projection(value)
+
+    head_list = []
+    for i in range(head_num):
+        with mixed_layer(size=key_proj_size) as sub_query_proj:
+            sub_query_proj += identity_projection(
+                query_proj, offset=key_proj_size * i, size=key_proj_size)
+
+        with mixed_layer(size=key_proj_size) as sub_key_proj:
+            sub_key_proj += identity_projection(
+                key_proj, offset=key_proj_size * i, size=key_proj_size)
+
+        with mixed_layer(size=value_proj_size) as sub_value_proj:
+            sub_value_proj += identity_projection(
+                value_proj, offset=value_proj_size * i, size=value_proj_size)
+
+        if attention_type == 'dot-product attention':
+            m = dot_prod_layer(
+                input1=sub_query_proj,
+                input2=sub_key_proj,
+                name='%s_dot-product_%d' % (name, i))
+            m = slope_intercept_layer(
+                input=m,
+                slope=math.sqrt(1.0 / key_proj_size),
+                name='%s_dot-product_scaling_%d' % (name, i))
+        else:
+            with mixed_layer(
+                    size=key_proj_size,
+                    act=TanhActivation(),
+                    name='%s_combine_%d' % (name, i)) as m:
+                m += identity_projection(sub_query_proj)
+                m += identity_projection(sub_key_proj)
+
+        attention_weight = fc_layer(
+            input=m,
+            size=1,
+            act=SequenceSoftmaxActivation(),
+            param_attr=softmax_param_attr,
+            name="%s_softmax_%d" % (name, i),
+            bias_attr=False)
+
+        scaled = scaling_layer(
+            weight=attention_weight,
+            input=sub_value_proj,
+            name='%s_scaling_%d' % (name, i))
+        head = pooling_layer(
+            input=scaled,
+            pooling_type=SumPooling(),
+            name="%s_pooling_%d" % (name, i))
+
+        head_list.append(head)
+
+    attended = concat_layer(head_list)
+
+    return attended
+
+
 def inputs(layers, *args):
     """
     Declare the inputs of network. The order of input should be as same as
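
For readers who want to try the new helper outside the diff, here is a minimal, hypothetical configuration sketch (not part of commit 56ec40a): it assumes an encoder output layer named encoded_sequence and a decoder state layer named decoder_state, and wires them through multi_head_attention with 8 heads of 64 dimensions each, so the returned context vector has size 64 * 8 = 512. In a real seq2seq config the query would typically come from a memory inside a recurrent_group.

# Hypothetical usage sketch -- layer names and sizes are illustrative assumptions.
from paddle.trainer_config_helpers import *

# Assumed inputs: an encoded source sequence and the current decoder state.
encoded_sequence = data_layer(name='encoded_sequence', size=512)
decoder_state = data_layer(name='decoder_state', size=512)

# Each of the 8 heads attends in a 64-dim subspace; the returned context
# has dimension value_proj_size * head_num = 64 * 8 = 512.
context = multi_head_attention(
    query=decoder_state,
    key=encoded_sequence,
    value=encoded_sequence,
    key_proj_size=64,
    value_proj_size=64,
    head_num=8,
    attention_type='dot-product attention')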
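
As a cross-check on what the layer graph above configures, the following plain-numpy sketch (an illustration, not PaddlePaddle API) spells out the computation of a single 'dot-product attention' head: the dot product and 1/sqrt(d_k) scaling correspond to dot_prod_layer plus slope_intercept_layer, the normalization to the sequence-softmax fc_layer (the extra learned 1-D projection applied before the softmax in the diff is omitted here), and the weighted sum to scaling_layer followed by the sum pooling_layer.

import numpy as np

def one_head_dot_product_attention(q, K, V):
    # q: projected query vector, shape (key_proj_size,)
    # K: projected keys,   shape (seq_len, key_proj_size)
    # V: projected values, shape (seq_len, value_proj_size)
    d_k = K.shape[1]
    scores = K.dot(q) / np.sqrt(d_k)    # dot product + 1/sqrt(d_k) scaling
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()            # softmax over the sequence positions
    return weights.dot(V)               # attention-weighted sum of values

# Concatenating head_num such outputs reproduces the context size
# value_proj_size * head_num returned by concat_layer in the diff.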
