
Commit 8139863

JunnYu and yingyibiao authored
update MultiHeadAttentionWithConv (#1643)
Co-authored-by: yingyibiao <[email protected]>
1 parent 7b8dd49 commit 8139863

File tree: 1 file changed (+29 -17 lines)


paddlenlp/transformers/convbert/modeling.py

Lines changed: 29 additions & 17 deletions
@@ -21,11 +21,18 @@
 from .. import PretrainedModel, register_base_model

 __all__ = [
-    "ConvBertModel", "ConvBertPretrainedModel", "ConvBertForTotalPretraining",
-    "ConvBertDiscriminator", "ConvBertGenerator", "ConvBertClassificationHead",
-    "ConvBertForSequenceClassification", "ConvBertForTokenClassification",
-    "ConvBertPretrainingCriterion", "ConvBertForQuestionAnswering",
-    "ConvBertForMultipleChoice", "ConvBertForPretraining"
+    "ConvBertModel",
+    "ConvBertPretrainedModel",
+    "ConvBertForTotalPretraining",
+    "ConvBertDiscriminator",
+    "ConvBertGenerator",
+    "ConvBertClassificationHead",
+    "ConvBertForSequenceClassification",
+    "ConvBertForTokenClassification",
+    "ConvBertPretrainingCriterion",
+    "ConvBertForQuestionAnswering",
+    "ConvBertForMultipleChoice",
+    "ConvBertForPretraining",
 ]
 dtype_float = paddle.get_default_dtype()

@@ -115,7 +122,8 @@ def __init__(
         self.need_weights = need_weights
         self.head_dim = embed_dim // num_heads
         self.scale = self.head_dim**-0.5
-        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+        assert self.head_dim * \
+            num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

         new_num_attention_heads = num_heads // head_ratio
         if num_heads // head_ratio < 1:
@@ -140,9 +148,7 @@ def __init__(
         self.conv_kernel_layer = nn.Linear(
             self.all_head_size, self.num_heads * self.conv_kernel_size)
         self.conv_out_layer = nn.Linear(embed_dim, self.all_head_size)
-        self.unfold = nn.Unfold(
-            kernel_sizes=[self.conv_kernel_size, 1],
-            paddings=[(self.conv_kernel_size - 1) // 2, 0], )
+        self.padding = (self.conv_kernel_size - 1) // 2

     def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
         key = query if key is None else key
@@ -153,28 +159,34 @@ def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
         v = self.v_proj(value)

         if self.conv_type == "sdconv":
+            bs = paddle.shape(q)[0]
+            seqlen = paddle.shape(q)[1]
             mixed_key_conv_attn_layer = self.key_conv_attn_layer(query)
             conv_attn_layer = mixed_key_conv_attn_layer * q
-            batch_size = q.shape[0]
+
             # conv_kernel_layer
             conv_kernel_layer = self.conv_kernel_layer(conv_attn_layer)
             conv_kernel_layer = tensor.reshape(
                 conv_kernel_layer, shape=[-1, self.conv_kernel_size, 1])
             conv_kernel_layer = F.softmax(conv_kernel_layer, axis=1)
-            # conv_out
             conv_out_layer = self.conv_out_layer(query)
-            conv_out_layer = tensor.reshape(
-                conv_out_layer, [batch_size, -1, self.all_head_size, 1])
-            conv_out_layer = tensor.transpose(conv_out_layer, perm=[0, 2, 1, 3])
-            conv_out_layer = self.unfold(conv_out_layer)
-            conv_out_layer = tensor.transpose(conv_out_layer, perm=[0, 2, 1])
+            conv_out_layer = F.pad(conv_out_layer,
+                                   pad=[self.padding, self.padding],
+                                   data_format="NLC")
+            conv_out_layer = paddle.stack(
+                [
+                    paddle.slice(
+                        conv_out_layer, axes=[1], starts=[i],
+                        ends=[i + seqlen]) for i in range(self.conv_kernel_size)
+                ],
+                axis=-1)
             conv_out_layer = tensor.reshape(
                 conv_out_layer,
                 shape=[-1, self.head_dim, self.conv_kernel_size])
             conv_out_layer = tensor.matmul(conv_out_layer, conv_kernel_layer)
             conv_out = tensor.reshape(
                 conv_out_layer,
-                shape=[batch_size, -1, self.num_heads, self.head_dim])
+                shape=[bs, seqlen, self.num_heads, self.head_dim])

             q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
             q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
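The functional change in this last hunk is how the span-based dynamic convolution gathers each token's neighborhood: the removed nn.Unfold call is replaced by zero-padding the sequence dimension with F.pad and stacking conv_kernel_size shifted slices, with paddle.shape keeping the batch and sequence sizes dynamic. A minimal NumPy sketch of that pad-then-stack sliding-window gather (the function name and shapes below are illustrative only, not taken from modeling.py):

import numpy as np

def sliding_windows(x, kernel_size):
    """Gather a centered window of kernel_size rows for every position.

    x: [seq_len, hidden] array; returns [seq_len, hidden, kernel_size],
    mirroring the F.pad + paddle.slice + paddle.stack pattern in the diff.
    """
    padding = (kernel_size - 1) // 2
    seq_len = x.shape[0]
    # Zero-pad along the sequence axis, as F.pad(..., data_format="NLC") does.
    padded = np.pad(x, ((padding, padding), (0, 0)))
    # Slice i is the sequence shifted by i; stacking the slices on a new last
    # axis leaves position t holding rows t - padding .. t + padding.
    slices = [padded[i:i + seq_len] for i in range(kernel_size)]
    return np.stack(slices, axis=-1)

x = np.arange(12, dtype=np.float32).reshape(6, 2)   # seq_len=6, hidden=2
windows = sliding_windows(x, kernel_size=3)
print(windows.shape)       # (6, 2, 3)
print(windows[0, :, 0])    # [0. 0.] -- zeros from the left padding

Each position's window is then multiplied by its per-position softmax kernel from conv_kernel_layer in the tensor.matmul context line above; the stacked-slice gather simply stands in for the windows the removed unfold used to produce.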
