 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import math
 
 from activations import LinearActivation, ReluActivation, SoftmaxActivation, \
     IdentityActivation, TanhActivation, SequenceSoftmaxActivation
@@ -26,9 +26,9 @@
     'sequence_conv_pool', 'simple_lstm', "simple_img_conv_pool",
     "img_conv_bn_pool", 'lstmemory_group', 'lstmemory_unit', 'small_vgg',
     'img_conv_group', 'vgg_16_network', 'gru_unit', 'gru_group', 'simple_gru',
-    'simple_attention', 'dot_product_attention', 'simple_gru2',
-    'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm', 'inputs',
-    'outputs'
+    'simple_attention', 'dot_product_attention', 'multi_head_attention',
+    'simple_gru2', 'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm',
+    'inputs', 'outputs'
 ]
 
 ######################################################
@@ -1496,6 +1496,134 @@ def dot_product_attention(encoded_sequence,
         input=scaled, pooling_type=SumPooling(), name="%s_pooling" % name)
 
 
+@wrap_name_default()
+def multi_head_attention(query,
+                         key,
+                         value,
+                         key_proj_size,
+                         value_proj_size,
+                         head_num,
+                         attention_type,
+                         softmax_param_attr=None,
+                         name=None):
+    """
+    Calculate and return a context vector with a multi-head attention
+    mechanism. The dimension of the context vector is equal to
+    value_proj_size * head_num.
+
+    Please refer to **Attention Is All You Need** for more details. The link is
+    as follows:
+    https://arxiv.org/abs/1706.03762.
+
+    The example usage is:
+
+    .. code-block:: python
+
+        context = multi_head_attention(query=decoder_state,
+                                       key=enc_seq,
+                                       value=enc_seq,
+                                       key_proj_size=64,
+                                       value_proj_size=64,
+                                       head_num=8,
+                                       attention_type='dot-product attention')
+
+    :param name: A prefix attached to the name of each layer that is defined
+                 inside the multi_head_attention.
+    :type name: basestring
+    :param softmax_param_attr: The parameter attribute of the sequence softmax
+                               that is used to produce the attention weights.
+    :type softmax_param_attr: ParameterAttribute
+    :param query: query is used to calculate the attention weights over the
+                  values at the current step.
+    :type query: LayerOutput
+    :param key: key is used to calculate the attention weight of the
+                corresponding value.
+    :type key: LayerOutput
+    :param value: value is the sequence to be attended.
+    :type value: LayerOutput
+    :param key_proj_size: The dimension of the linear projection performed on
+                          key and query.
+    :type key_proj_size: int
+    :param value_proj_size: The dimension of the linear projection performed
+                            on value.
+    :type value_proj_size: int
+    :param head_num: The number of attention heads.
+    :type head_num: int
+    :param attention_type: The type of the attention mechanism used in each
+                           attention head. Currently, only scaled dot-product
+                           attention and additive attention are supported.
+    :type attention_type: basestring
+    :return: The context vector.
+    :rtype: LayerOutput
+    """
+    assert attention_type in ['dot-product attention', 'additive attention']
+
+    # Project query, key and value once for all heads; each projection holds
+    # head_num sub-projections laid out side by side.
+    with mixed_layer(
+            size=key_proj_size * head_num,
+            name='%s_query_proj' % name) as query_proj:
+        query_proj += full_matrix_projection(query)
+    # Expand the single query vector over the time steps of the key sequence.
+    query_proj = expand_layer(input=query_proj, expand_as=key)
+
+    with mixed_layer(
+            size=key_proj_size * head_num,
+            name='%s_key_proj' % name) as key_proj:
+        key_proj += full_matrix_projection(key)
+
+    with mixed_layer(
+            size=value_proj_size * head_num,
+            name='%s_value_proj' % name) as value_proj:
+        value_proj += full_matrix_projection(value)
+
+    head_list = []
+    for i in range(head_num):
+        # Slice out the i-th head's share of each projection with
+        # offset-based identity projections.
+        with mixed_layer(size=key_proj_size) as sub_query_proj:
+            sub_query_proj += identity_projection(
+                query_proj, offset=key_proj_size * i, size=key_proj_size)
+
+        with mixed_layer(size=key_proj_size) as sub_key_proj:
+            sub_key_proj += identity_projection(
+                key_proj, offset=key_proj_size * i, size=key_proj_size)
+
+        with mixed_layer(size=value_proj_size) as sub_value_proj:
+            sub_value_proj += identity_projection(
+                value_proj, offset=value_proj_size * i, size=value_proj_size)
+
+        if attention_type == 'dot-product attention':
+            # Scaled dot-product score: <q, k> * sqrt(1 / key_proj_size).
+            m = dot_prod_layer(
+                input1=sub_query_proj,
+                input2=sub_key_proj,
+                name='%s_dot-product_%d' % (name, i))
+            m = slope_intercept_layer(
+                input=m,
+                slope=math.sqrt(1.0 / key_proj_size),
+                name='%s_dot-product_scaling_%d' % (name, i))
+        else:
+            # Additive attention: tanh of the summed query and key projections.
+            with mixed_layer(
+                    size=key_proj_size,
+                    act=TanhActivation(),
+                    name='%s_combine_%d' % (name, i)) as m:
+                m += identity_projection(sub_query_proj)
+                m += identity_projection(sub_key_proj)
+
+        # Normalize the scores over the sequence into attention weights.
+        attention_weight = fc_layer(
+            input=m,
+            size=1,
+            act=SequenceSoftmaxActivation(),
+            param_attr=softmax_param_attr,
+            name="%s_softmax_%d" % (name, i),
+            bias_attr=False)
+
+        # The weighted sum of the value sub-projection is this head's context.
+        scaled = scaling_layer(
+            weight=attention_weight,
+            input=sub_value_proj,
+            name='%s_scaling_%d' % (name, i))
+        head = pooling_layer(
+            input=scaled,
+            pooling_type=SumPooling(),
+            name="%s_pooling_%d" % (name, i))
+
+        head_list.append(head)
+
+    # Concatenate the per-head contexts into the final context vector.
+    attended = concat_layer(head_list)
+
+    return attended
+
+
 def inputs(layers, *args):
     """
     Declare the inputs of network. The order of input should be as same as
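For readers who want to check the math, below is a minimal NumPy sketch of what the layer graph above assembles for the 'dot-product attention' branch: project query, key and value, slice out each head, scale the per-step dot products by sqrt(1 / key_proj_size), normalize them with a softmax over the sequence, take the weighted sum of the value slice, and concatenate the heads. The function name, shapes and weight matrices here are hypothetical and exist only for illustration; the sketch also folds the learned fc_layer that precedes the sequence softmax into a plain softmax, so it mirrors the structure of the layers rather than reproducing them exactly.

import numpy as np


def _softmax(x):
    # Numerically stable softmax over a 1-D score vector.
    e = np.exp(x - x.max())
    return e / e.sum()


def multi_head_dot_product_attention(query, keys, values, w_q, w_k, w_v,
                                     head_num):
    # query  : (query_dim,)          a single decoder state
    # keys   : (seq_len, key_dim)    sequence used to score the values
    # values : (seq_len, value_dim)  sequence to be attended
    # w_q    : (query_dim, head_num * key_proj_size)
    # w_k    : (key_dim,   head_num * key_proj_size)
    # w_v    : (value_dim, head_num * value_proj_size)
    q = query.dot(w_q)    # query_proj (implicitly expanded over the sequence)
    k = keys.dot(w_k)     # key_proj
    v = values.dot(w_v)   # value_proj
    key_proj_size = q.shape[0] // head_num
    value_proj_size = v.shape[1] // head_num

    heads = []
    for i in range(head_num):
        # Slice the i-th head, mirroring the offset identity projections.
        q_i = q[i * key_proj_size:(i + 1) * key_proj_size]
        k_i = k[:, i * key_proj_size:(i + 1) * key_proj_size]
        v_i = v[:, i * value_proj_size:(i + 1) * value_proj_size]

        # Scaled dot-product scores, one per time step, then sequence softmax.
        scores = k_i.dot(q_i) * np.sqrt(1.0 / key_proj_size)
        weights = _softmax(scores)

        # Weighted sum over time yields this head's context vector.
        heads.append(weights.dot(v_i))

    # Concatenated context of size head_num * value_proj_size.
    return np.concatenate(heads)

For example, with head_num=8 and key_proj_size=value_proj_size=64 as in the docstring example, the returned context has 512 elements, matching value_proj_size * head_num.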