@@ -1557,15 +1557,15 @@ def multi_head_attention(query,
1557
1557
for i in range (head_num ):
1558
1558
with mixed_layer (size = key_proj_size ) as sub_query_proj :
1559
1559
sub_query_proj += identity_projection (
1560
- query_proj , offset = key_proj_size * i )
1560
+ query_proj , offset = key_proj_size * i , size = key_proj_size )
1561
1561
1562
1562
with mixed_layer (size = key_proj_size ) as sub_key_proj :
1563
1563
sub_key_proj += identity_projection (
1564
- key_proj , offset = key_proj_size * i )
1564
+ key_proj , offset = key_proj_size * i , size = key_proj_size )
1565
1565
1566
1566
with mixed_layer (size = value_proj_size ) as sub_value_proj :
1567
1567
sub_value_proj += identity_projection (
1568
- value_proj , offset = value_proj_size * i )
1568
+ value_proj , offset = value_proj_size * i , size = value_proj_size )
1569
1569
1570
1570
if attention_type == 'dot-product attention' :
1571
1571
m = linear_comb_layer (
@@ -1603,11 +1603,7 @@ def multi_head_attention(query,
1603
1603
1604
1604
head_list .append (head )
1605
1605
1606
- multi_head = concat_layer (head_list )
1607
-
1608
- with mixed_layer (
1609
- size = value_proj_size * head_num , name = '%s_proj' % name ) as attended :
1610
- attended += full_matrix_projection (multi_head )
1606
+ attended = concat_layer (head_list )
1611
1607
1612
1608
return attended
1613
1609
0 commit comments