# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+ import math

from activations import LinearActivation, ReluActivation, SoftmaxActivation, \
    IdentityActivation, TanhActivation, SequenceSoftmaxActivation
    'sequence_conv_pool', 'simple_lstm', "simple_img_conv_pool",
    "img_conv_bn_pool", 'lstmemory_group', 'lstmemory_unit', 'small_vgg',
    'img_conv_group', 'vgg_16_network', 'gru_unit', 'gru_group', 'simple_gru',
-     'simple_attention', 'dot_product_attention', 'simple_gru2',
-     'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm', 'inputs',
-     'outputs'
+     'simple_attention', 'dot_product_attention', 'multi_head_attention',
+     'simple_gru2', 'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm',
+     'inputs', 'outputs'
]

######################################################
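With multi_head_attention added to __all__ above, the new helper is exported next to the existing attention helpers. Assuming the usual re-export of this module's __all__ by the trainer_config_helpers package (an assumption, not shown in this diff), a trainer config could then import it as:

    # Hypothetical import; assumes trainer_config_helpers re-exports networks.__all__.
    from paddle.trainer_config_helpers import multi_head_attention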
@@ -1476,10 +1476,8 @@ def dot_product_attention(encoded_sequence,
        expand_as=encoded_sequence,
        name='%s_expand' % name)

-     m = linear_comb_layer(
-         weights=expanded,
-         vectors=encoded_sequence,
-         name='%s_dot-product' % name)
+     m = dot_prod_layer(
+         input1=expanded, input2=encoded_sequence, name='%s_dot-product' % name)

    attention_weight = fc_layer(
        input=m,
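For reference, the hunk above swaps linear_comb_layer for dot_prod_layer inside dot_product_attention; in both cases the expanded decoder state and the encoded sequence are reduced to one unnormalized attention score per time step by a row-wise dot product. A minimal NumPy sketch of that reduction (the array shapes are illustrative assumptions, not taken from the diff):

    import numpy as np

    # expanded: the transformed decoder state broadcast to every encoder step, shape (T, d)
    # encoded_sequence: the encoder outputs, shape (T, d)
    T, d = 5, 8
    expanded = np.random.rand(T, d)
    encoded_sequence = np.random.rand(T, d)

    # Row-wise dot product -> one score per step, shape (T, 1); this is the quantity
    # held by m before it is passed to fc_layer with SequenceSoftmaxActivation.
    m = np.sum(expanded * encoded_sequence, axis=1, keepdims=True)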
@@ -1498,6 +1496,134 @@ def dot_product_attention(encoded_sequence,
        input=scaled, pooling_type=SumPooling(), name="%s_pooling" % name)


+ @wrap_name_default()
+ def multi_head_attention(query,
+                          key,
+                          value,
+                          key_proj_size,
+                          value_proj_size,
+                          head_num,
+                          attention_type,
+                          softmax_param_attr=None,
+                          name=None):
+     """
+     Calculate and return a context vector with a multi-head attention mechanism.
+     The dimension of the context vector equals value_proj_size * head_num.
+
+     Please refer to **Attention Is All You Need** for more details. The link is
+     as follows:
+     https://arxiv.org/abs/1706.03762.
+
+     The example usage is:
+
+     .. code-block:: python
+
+         context = multi_head_attention(query=decoder_state,
+                                        key=enc_seq,
+                                        value=enc_seq,
+                                        key_proj_size=64,
+                                        value_proj_size=64,
+                                        head_num=8,
+                                        attention_type='dot-product attention')
+
+     :param name: A prefix attached to the name of each layer defined inside
+                  the multi_head_attention.
+     :type name: basestring
+     :param softmax_param_attr: The parameter attribute of the sequence softmax
+                                that is used to produce the attention weights.
+     :type softmax_param_attr: ParameterAttribute
+     :param query: The query is used to calculate the attention weights over the values at the current step.
+     :type query: LayerOutput
+     :param key: The key is used to calculate the attention weight of the corresponding value.
+     :type key: LayerOutput
+     :param value: The value is the sequence to be attended.
+     :type value: LayerOutput
+     :param key_proj_size: The dimension of the linear projection performed on the key and query.
+     :type key_proj_size: int
+     :param value_proj_size: The dimension of the linear projection performed on the value.
+     :type value_proj_size: int
+     :param head_num: The number of attention heads.
+     :type head_num: int
+     :param attention_type: The type of attention mechanism used in each attention
+                            head. Currently, only scaled dot-product attention and
+                            additive attention are supported.
+     :type attention_type: basestring
+     :return: The context vector.
+     :rtype: LayerOutput
+     """
+     assert attention_type in ['dot-product attention', 'additive attention']
+
+     with mixed_layer(
+             size=key_proj_size * head_num,
+             name='%s_query_proj' % name) as query_proj:
+         query_proj += full_matrix_projection(query)
+     query_proj = expand_layer(input=query_proj, expand_as=key)
+
+     with mixed_layer(
+             size=key_proj_size * head_num,
+             name='%s_key_proj' % name) as key_proj:
+         key_proj += full_matrix_projection(key)
+
+     with mixed_layer(
+             size=value_proj_size * head_num,
+             name='%s_value_proj' % name) as value_proj:
+         value_proj += full_matrix_projection(value)
+
+     head_list = []
+     for i in range(head_num):
+         with mixed_layer(size=key_proj_size) as sub_query_proj:
+             sub_query_proj += identity_projection(
+                 query_proj, offset=key_proj_size * i, size=key_proj_size)
+
+         with mixed_layer(size=key_proj_size) as sub_key_proj:
+             sub_key_proj += identity_projection(
+                 key_proj, offset=key_proj_size * i, size=key_proj_size)
+
+         with mixed_layer(size=value_proj_size) as sub_value_proj:
+             sub_value_proj += identity_projection(
+                 value_proj, offset=value_proj_size * i, size=value_proj_size)
+
+         if attention_type == 'dot-product attention':
+             m = dot_prod_layer(
+                 input1=sub_query_proj,
+                 input2=sub_key_proj,
+                 name='%s_dot-product_%d' % (name, i))
+             m = slope_intercept_layer(
+                 input=m,
+                 slope=math.sqrt(1.0 / key_proj_size),
+                 name='%s_dot-product_scaling_%d' % (name, i))
+         else:
+             with mixed_layer(
+                     size=key_proj_size,
+                     act=TanhActivation(),
+                     name='%s_combine_%d' % (name, i)) as m:
+                 m += identity_projection(sub_query_proj)
+                 m += identity_projection(sub_key_proj)
+
+         attention_weight = fc_layer(
+             input=m,
+             size=1,
+             act=SequenceSoftmaxActivation(),
+             param_attr=softmax_param_attr,
+             name="%s_softmax_%d" % (name, i),
+             bias_attr=False)
+
+         scaled = scaling_layer(
+             weight=attention_weight,
+             input=sub_value_proj,
+             name='%s_scaling_%d' % (name, i))
+         head = pooling_layer(
+             input=scaled,
+             pooling_type=SumPooling(),
+             name="%s_pooling_%d" % (name, i))
+
+         head_list.append(head)
+
+     attended = concat_layer(head_list)
+
+     return attended
+
+
def inputs(layers, *args):
    """
    Declare the inputs of the network. The order of input should be the same as
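To make the per-head arithmetic of multi_head_attention concrete, here is a minimal NumPy sketch of the 'dot-product attention' branch that the layer graph assembles: project query, key, and value, slice out each head with identity_projection-style offsets, score with a scaled dot product, normalize with a softmax over the sequence, sum-pool the weighted values, and concatenate the heads. The shapes, the random stand-in projection matrices, and the sequence length are illustrative assumptions; the learned scalar weight applied by fc_layer before the softmax is omitted for brevity.

    import numpy as np

    def softmax(x):
        e = np.exp(x - x.max())
        return e / e.sum()

    T, d_model = 6, 32                      # sequence length and input width (assumed)
    key_proj_size, value_proj_size, head_num = 8, 8, 4

    query = np.random.rand(d_model)         # decoder state at the current step
    key = np.random.rand(T, d_model)        # e.g. encoder outputs
    value = np.random.rand(T, d_model)

    # Stand-ins for the learned full_matrix_projection weights.
    Wq = np.random.rand(d_model, key_proj_size * head_num)
    Wk = np.random.rand(d_model, key_proj_size * head_num)
    Wv = np.random.rand(d_model, value_proj_size * head_num)

    q = query @ Wq                          # broadcast over the T steps below
    k = key @ Wk
    v = value @ Wv

    heads = []
    for i in range(head_num):
        q_i = q[i * key_proj_size:(i + 1) * key_proj_size]          # per-head slices, like
        k_i = k[:, i * key_proj_size:(i + 1) * key_proj_size]       # identity_projection with
        v_i = v[:, i * value_proj_size:(i + 1) * value_proj_size]   # offset and size

        # dot_prod_layer followed by slope_intercept_layer: scaled dot-product scores per step
        scores = (k_i @ q_i) * np.sqrt(1.0 / key_proj_size)
        weights = softmax(scores)                                   # sequence softmax
        heads.append((weights[:, None] * v_i).sum(axis=0))          # scaling_layer + SumPooling

    context = np.concatenate(heads)         # concat_layer; size = value_proj_size * head_num
    assert context.shape == (value_proj_size * head_num,)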