from .. import PretrainedModel, register_base_model

__all__ = [
-    "ConvBertModel", "ConvBertPretrainedModel", "ConvBertForTotalPretraining",
-    "ConvBertDiscriminator", "ConvBertGenerator", "ConvBertClassificationHead",
-    "ConvBertForSequenceClassification", "ConvBertForTokenClassification",
-    "ConvBertPretrainingCriterion", "ConvBertForQuestionAnswering",
-    "ConvBertForMultipleChoice", "ConvBertForPretraining"
+    "ConvBertModel",
+    "ConvBertPretrainedModel",
+    "ConvBertForTotalPretraining",
+    "ConvBertDiscriminator",
+    "ConvBertGenerator",
+    "ConvBertClassificationHead",
+    "ConvBertForSequenceClassification",
+    "ConvBertForTokenClassification",
+    "ConvBertPretrainingCriterion",
+    "ConvBertForQuestionAnswering",
+    "ConvBertForMultipleChoice",
+    "ConvBertForPretraining",
]

dtype_float = paddle.get_default_dtype()
@@ -115,7 +122,8 @@ def __init__(
        self.need_weights = need_weights
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim**-0.5
-        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+        assert self.head_dim * \
+            num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        new_num_attention_heads = num_heads // head_ratio
        if num_heads // head_ratio < 1:
@@ -140,9 +148,7 @@ def __init__(
        self.conv_kernel_layer = nn.Linear(
            self.all_head_size, self.num_heads * self.conv_kernel_size)
        self.conv_out_layer = nn.Linear(embed_dim, self.all_head_size)
-        self.unfold = nn.Unfold(
-            kernel_sizes=[self.conv_kernel_size, 1],
-            paddings=[(self.conv_kernel_size - 1) // 2, 0], )
+        self.padding = (self.conv_kernel_size - 1) // 2

    def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
        key = query if key is None else key
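Note on the hunk above: the nn.Unfold module that extracted a conv_kernel_size-wide sliding window around every sequence position is dropped; only the half-window padding survives as state, and the windows are rebuilt by hand in forward (next hunk). Below is a minimal sketch, not part of the diff, of what the removed module computed, assuming a current Paddle release; sizes and variable names are illustrative.

    import paddle
    import paddle.nn as nn

    k = 9                                        # stands in for conv_kernel_size
    x = paddle.randn([2, 16, 64])                # [batch, seq_len, all_head_size]

    # Unfold needs a 4-D NCHW input, hence the reshape/transpose round-trip.
    unfold = nn.Unfold(kernel_sizes=[k, 1], paddings=[(k - 1) // 2, 0])
    x_nchw = paddle.transpose(x.unsqueeze(-1), perm=[0, 2, 1, 3])   # [batch, C, seq_len, 1]
    windows = unfold(x_nchw)                     # [batch, C * k, seq_len]
    windows = paddle.transpose(windows, perm=[0, 2, 1])             # [batch, seq_len, C * k]
    print(windows.shape)                         # expect [2, 16, 576]

The 4-D detour exists only to satisfy Unfold's input layout; the rewritten forward in the next hunk produces the same windows without leaving the [batch, seq_len, channels] layout.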
@@ -153,28 +159,34 @@ def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
        v = self.v_proj(value)

        if self.conv_type == "sdconv":
+            bs = paddle.shape(q)[0]
+            seqlen = paddle.shape(q)[1]
            mixed_key_conv_attn_layer = self.key_conv_attn_layer(query)
            conv_attn_layer = mixed_key_conv_attn_layer * q
-            batch_size = q.shape[0]
+
            # conv_kernel_layer
            conv_kernel_layer = self.conv_kernel_layer(conv_attn_layer)
            conv_kernel_layer = tensor.reshape(
                conv_kernel_layer, shape=[-1, self.conv_kernel_size, 1])
            conv_kernel_layer = F.softmax(conv_kernel_layer, axis=1)
-            # conv_out
            conv_out_layer = self.conv_out_layer(query)
-            conv_out_layer = tensor.reshape(
-                conv_out_layer, [batch_size, -1, self.all_head_size, 1])
-            conv_out_layer = tensor.transpose(conv_out_layer, perm=[0, 2, 1, 3])
-            conv_out_layer = self.unfold(conv_out_layer)
-            conv_out_layer = tensor.transpose(conv_out_layer, perm=[0, 2, 1])
+            conv_out_layer = F.pad(conv_out_layer,
+                                   pad=[self.padding, self.padding],
+                                   data_format="NLC")
+            conv_out_layer = paddle.stack(
+                [
+                    paddle.slice(
+                        conv_out_layer, axes=[1], starts=[i],
+                        ends=[i + seqlen]) for i in range(self.conv_kernel_size)
+                ],
+                axis=-1)
            conv_out_layer = tensor.reshape(
                conv_out_layer,
                shape=[-1, self.head_dim, self.conv_kernel_size])
            conv_out_layer = tensor.matmul(conv_out_layer, conv_kernel_layer)
            conv_out = tensor.reshape(
                conv_out_layer,
-                shape=[batch_size, -1, self.num_heads, self.head_dim])
+                shape=[bs, seqlen, self.num_heads, self.head_dim])

        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
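The rewritten sdconv path above builds the same windows directly in [batch, seq_len, channels] layout: pad the length axis by self.padding on each side, stack conv_kernel_size shifted slices, and mix each window with its softmaxed per-position kernel. Reading bs and seqlen through paddle.shape rather than Tensor.shape presumably keeps the op graph valid when those dimensions are only known at run time. A minimal, self-contained sketch of the same steps, not part of the diff; sizes and names are illustrative, and num_heads is taken as 1 so head_dim equals all_head_size.

    import paddle
    import paddle.nn.functional as F

    k = 9                                        # stands in for conv_kernel_size
    pad = (k - 1) // 2
    x = paddle.randn([2, 16, 64])                # [batch, seq_len, all_head_size]
    bs, seqlen = paddle.shape(x)[0], paddle.shape(x)[1]

    # Pad only the length axis, then stack k shifted views: one k-wide window per position.
    x_padded = F.pad(x, pad=[pad, pad], data_format="NLC")
    windows = paddle.stack(
        [paddle.slice(x_padded, axes=[1], starts=[i], ends=[i + seqlen])
         for i in range(k)],
        axis=-1)                                 # [batch, seq_len, all_head_size, k]

    # Random per-position kernels stand in for conv_kernel_layer + softmax.
    kernel = F.softmax(paddle.randn([2 * 16, k, 1]), axis=1)        # [batch * seq_len, k, 1]
    out = paddle.matmul(windows.reshape([-1, 64, k]), kernel)       # weighted sum over each window
    out = out.reshape([bs, seqlen, 1, 64])       # [batch, seq_len, num_heads, head_dim]
    print(out.shape)                             # expect [2, 16, 1, 64]

With more than one head only the reshape sizes change; the window construction is identical.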