Skip to content

Commit 7461b35

Browse files
author
ranqiu
committed
Refine multi-head attention
1 parent 947c528 commit 7461b35

File tree

1 file changed: +4 additions, −8 deletions

python/paddle/trainer_config_helpers/networks.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1557,15 +1557,15 @@ def multi_head_attention(query,
     for i in range(head_num):
         with mixed_layer(size=key_proj_size) as sub_query_proj:
             sub_query_proj += identity_projection(
-                query_proj, offset=key_proj_size * i)
+                query_proj, offset=key_proj_size * i, size=key_proj_size)

         with mixed_layer(size=key_proj_size) as sub_key_proj:
             sub_key_proj += identity_projection(
-                key_proj, offset=key_proj_size * i)
+                key_proj, offset=key_proj_size * i, size=key_proj_size)

         with mixed_layer(size=value_proj_size) as sub_value_proj:
             sub_value_proj += identity_projection(
-                value_proj, offset=value_proj_size * i)
+                value_proj, offset=value_proj_size * i, size=value_proj_size)

         if attention_type == 'dot-product attention':
             m = linear_comb_layer(
@@ -1603,11 +1603,7 @@ def multi_head_attention(query,

         head_list.append(head)

-    multi_head = concat_layer(head_list)
-
-    with mixed_layer(
-            size=value_proj_size * head_num, name='%s_proj' % name) as attended:
-        attended += full_matrix_projection(multi_head)
+    attended = concat_layer(head_list)

     return attended

0 commit comments

Comments
 (0)