 import paddle.nn.functional as F
 from paddle.nn import MultiHeadAttention, TransformerEncoderLayer, TransformerEncoder
 from paddle.fluid.data_feeder import convert_dtype
+
 from paddlenlp.utils.log import logger
-from paddlenlp.transformers import TinyBertForPretraining, TinyBertForSequenceClassification, BertForSequenceClassification
+from paddlenlp.transformers import ErnieForSequenceClassification
+from paddlenlp.transformers import TinyBertForPretraining
+from paddlenlp.transformers import BertForSequenceClassification
+
+__all__ = ['to_distill', 'calc_minilm_loss', 'calc_multi_relation_loss']
+
+
+def calc_multi_relation_loss(loss_fct,
+                             s,
+                             t,
+                             attn_mask,
+                             num_relation_heads=0,
+                             alpha=0.0,
+                             beta=0.0):
+    """
+    Calculates the loss for multiple Q-Q, K-K and V-V relations. It supports the
+    head-head relation, the sample-sample relation and the original token-token
+    relation. The final loss is a weighted sum balanced by `alpha` and `beta`.
+
+    Args:
+        loss_fct (callable):
+            Loss function for distillation. Only the kl_div loss is supported now.
+        s (Tensor):
+            Q, K or V of the student.
+        t (Tensor):
+            Q, K or V of the teacher.
+        attn_mask (Tensor):
+            Attention mask used for the relations.
+        num_relation_heads (int):
+            The number of relation heads. 0 means that `num_relation_heads`
+            equals the original number of heads.
+            Defaults to 0.
+        alpha (float):
+            The weight of the head-head relation loss.
+            Defaults to 0.0.
+        beta (float):
+            The weight of the sample-sample relation loss.
+            Defaults to 0.0.
+
+    Returns:
+        Tensor: The weighted sum of the token-token, head-head and
+            sample-sample losses.
+
+    """
+    # Initialize head_num
+    if num_relation_heads > 0 and num_relation_heads != s.shape[1]:
+        # s' shape: [bs, seq_len, head_num, head_dim]
+        s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
+        # s' shape: [bs, seq_len, num_relation_heads, head_dim_new]
+        s = tensor.reshape(x=s, shape=[0, 0, num_relation_heads, -1])
+        s1 = tensor.transpose(x=s, perm=[0, 2, 1, 3])
+    if num_relation_heads > 0 and num_relation_heads != t.shape[1]:
+        t = tensor.transpose(x=t, perm=[0, 2, 1, 3])
+        t = tensor.reshape(x=t, shape=[0, 0, num_relation_heads, -1])
+        t1 = tensor.transpose(x=t, perm=[0, 2, 1, 3])
+
+    s_head_dim, t_head_dim = s.shape[3], t.shape[3]
+
+    if alpha + beta == 1.0:
+        loss_token_token = 0.0
+    else:
+        scaled_dot_product_s1 = tensor.matmul(
+            x=s1, y=s1, transpose_y=True) / math.sqrt(s_head_dim)
+        del s1
+        scaled_dot_product_s1 += attn_mask
+        scaled_dot_product_t1 = tensor.matmul(
+            x=t1, y=t1, transpose_y=True) / math.sqrt(t_head_dim)
+        del t1
+        scaled_dot_product_t1 += attn_mask
+        loss_token_token = loss_fct(
+            F.log_softmax(scaled_dot_product_s1),
+            F.softmax(scaled_dot_product_t1))
+
+    if alpha == 0.0:
+        loss_head_head = 0.0
+    else:
+        scaled_dot_product_s = tensor.matmul(
+            x=s, y=s, transpose_y=True) / math.sqrt(s_head_dim)
+        attn_mask_head_head = tensor.transpose(x=attn_mask, perm=[0, 3, 1, 2])
+
+        scaled_dot_product_s += attn_mask_head_head
+        scaled_dot_product_t = tensor.matmul(
+            x=t, y=t, transpose_y=True) / math.sqrt(t_head_dim)
+        scaled_dot_product_t += attn_mask_head_head
+        loss_head_head = loss_fct(
+            F.log_softmax(scaled_dot_product_s),
+            F.softmax(scaled_dot_product_t))
+    if beta == 0.0:
+        loss_sample_sample = 0.0
+    else:
+        s2 = tensor.transpose(x=s, perm=[1, 2, 0, 3])
+        scaled_dot_product_s2 = tensor.matmul(
+            x=s2, y=s2, transpose_y=True) / math.sqrt(s_head_dim)
+
+        del s, s2
+        # Shape: [seq_len, 1, batch_size, 1]
+        attn_mask_sample_sample = tensor.transpose(
+            x=attn_mask, perm=[3, 1, 0, 2])

-__all__ = ['to_distill', 'calc_minilm_loss']
+        # Shape: [seq_len, head_num, batch_size, batch_size]
+        scaled_dot_product_s2 += attn_mask_sample_sample
+        t2 = tensor.transpose(x=t, perm=[1, 2, 0, 3])
+        scaled_dot_product_t2 = tensor.matmul(
+            x=t2, y=t2, transpose_y=True) / math.sqrt(t_head_dim)
+
+        del t, t2
+        scaled_dot_product_t2 += attn_mask_sample_sample
+        loss_sample_sample = loss_fct(
+            F.log_softmax(scaled_dot_product_s2),
+            F.softmax(scaled_dot_product_t2))
+
+    return (
+        1 - alpha - beta
+    ) * loss_token_token + alpha * loss_head_head + beta * loss_sample_sample


 def calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=0):
+    """
+    Calculates the loss for the Q-Q, K-K and V-V relations from MiniLMv2.
+
+    Args:
+        loss_fct (callable):
+            Loss function for distillation. Only the kl_div loss is supported now.
+        s (Tensor):
+            Q, K or V of the student.
+        t (Tensor):
+            Q, K or V of the teacher.
+        attn_mask (Tensor):
+            Attention mask used for the relation.
+        num_relation_heads (int):
+            The number of relation heads. 0 means that `num_relation_heads`
+            equals the original number of heads.
+            Defaults to 0.
+
+    Returns:
+        Tensor: The MiniLM loss value.
+
+    """
     # Initialize head_num
     if num_relation_heads > 0 and num_relation_heads != s.shape[1]:
         # s'shape: [bs, seq_len, head_num, head_dim]
         s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
         # s'shape: [bs, seq_len, num_relation_heads, head_dim_new]
         s = tensor.reshape(x=s, shape=[0, 0, num_relation_heads, -1])
-        #s's shape: [bs, num_relation_heads, seq_len, , head_dim_new]
+        # s' shape: [bs, num_relation_heads, seq_len, head_dim_new]
         s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
     if num_relation_heads > 0 and num_relation_heads != t.shape[1]:
         t = tensor.transpose(x=t, perm=[0, 2, 1, 3])
         t = tensor.reshape(x=t, shape=[0, 0, num_relation_heads, -1])
         t = tensor.transpose(x=t, perm=[0, 2, 1, 3])

-    pad_seq_len = s.shape[2]
     s_head_dim, t_head_dim = s.shape[3], t.shape[3]
     scaled_dot_product_s = tensor.matmul(
         x=s, y=s, transpose_y=True) / math.sqrt(s_head_dim)
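For orientation, here is a minimal usage sketch of the two loss helpers above. It assumes the helpers are importable from `paddlenlp.transformers.distill_utils`, that `s` and `t` are per-head query/key/value tensors of shape [batch_size, head_num, seq_len, head_dim], that `attn_mask` is the usual additive mask of shape [batch_size, 1, 1, seq_len], and that `paddle.nn.KLDivLoss(reduction='sum')` is an acceptable kl_div loss; none of these specifics are fixed by this diff.

import paddle
import paddle.nn as nn
# Assumed import path for this file's helpers.
from paddlenlp.transformers.distill_utils import (calc_minilm_loss,
                                                  calc_multi_relation_loss)

bs, head_num, seq_len, head_dim = 2, 12, 128, 64
q_s = paddle.randn([bs, head_num, seq_len, head_dim])  # student Q
q_t = paddle.randn([bs, head_num, seq_len, head_dim])  # teacher Q
# Additive mask: 0 for real tokens, large negative values for padding.
attn_mask = paddle.zeros([bs, 1, 1, seq_len])

loss_fct = nn.KLDivLoss(reduction='sum')  # kl_div is the only supported loss

# Q-Q relation; K-K and V-V are computed the same way and summed.
loss_qq = calc_minilm_loss(loss_fct, q_s, q_t, attn_mask, num_relation_heads=48)

# Multi-relation variant: alpha weights the head-head relation, beta the
# sample-sample relation, and 1 - alpha - beta the token-token relation.
loss_multi = calc_multi_relation_loss(
    loss_fct, q_s, q_t, attn_mask, num_relation_heads=48, alpha=0.1, beta=0.1)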
@@ -62,24 +194,31 @@ def to_distill(self,
                layer_index=-1):
     """
     Can be bound to object with transformer encoder layers, and make model
-    expose attributes `outputs.qs`, `outputs.ks`, `outputs.vs`,
+    expose attributes `outputs.q`, `outputs.k`, `outputs.v`,
     `outputs.scaled_qks`, `outputs.hidden_states` and `outputs.attentions` of
     the object for distillation.
+
+    It returns the intermediate tensors used by the MiniLM and TinyBERT
+    strategies.
     """
     logger.warning("`to_distill` is an experimental API and subject to change.")
     MultiHeadAttention._forward = attention_forward
     TransformerEncoderLayer._forward = transformer_encoder_layer_forward
     TransformerEncoder._forward = transformer_encoder_forward
     BertForSequenceClassification._forward = bert_forward
+
     if return_qkv:
+        # The forward function of the student class should be replaced for distributed training.
         TinyBertForPretraining._forward = minilm_pretraining_forward
+        ErnieForSequenceClassification._forward = minilm_pretraining_forward
     else:
         TinyBertForPretraining._forward = tinybert_forward

     def init_func(layer):
         if isinstance(layer, (MultiHeadAttention, TransformerEncoderLayer,
                               TransformerEncoder, TinyBertForPretraining,
-                              BertForSequenceClassification)):
+                              BertForSequenceClassification,
+                              ErnieForSequenceClassification)):
             layer.forward = layer._forward
         if isinstance(layer, TransformerEncoder):
             layer.return_layer_outputs = return_layer_outputs
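A rough sketch of how `to_distill` might be wired up, based only on what is visible in this hunk (the model is passed as the first positional argument, and `return_qkv` and `layer_index` are keyword parameters); the checkpoint name and task head below are placeholders, and the full signature of `to_distill` may differ from this sketch.

from paddlenlp.transformers import ErnieForSequenceClassification
from paddlenlp.transformers.distill_utils import to_distill  # assumed import path

# 'ernie-1.0' and num_classes=2 are placeholder choices for illustration.
teacher = ErnieForSequenceClassification.from_pretrained('ernie-1.0', num_classes=2)
student = ErnieForSequenceClassification.from_pretrained('ernie-1.0', num_classes=2)

# Monkey-patch the attention/encoder forwards so that, after a forward pass,
# both models expose the intermediate q/k/v of the layer picked by layer_index.
teacher = to_distill(teacher, return_qkv=True, layer_index=-1)
student = to_distill(student, return_qkv=True, layer_index=-1)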
@@ -125,17 +264,17 @@ def attention_forward(self,
                       attn_mask=None,
                       cache=None):
     """
-    Redefines the `forward` function of `paddle.nn.MultiHeadAttention`
+    Redefines the `forward` function of `paddle.nn.MultiHeadAttention`.
     """
     key = query if key is None else key
     value = query if value is None else value
-    # compute q, k, v
+    # Computes q, k, v
     if cache is None:
         q, k, v = self._prepare_qkv(query, key, value, cache)
     else:
         q, k, v, cache = self._prepare_qkv(query, key, value, cache)

-    # scale dot product attention
+    # Scale dot product attention
     product = tensor.matmul(x=q, y=k, transpose_y=True)
     product /= math.sqrt(self.head_dim)

@@ -159,11 +298,11 @@ def attention_forward(self,
     self.k = k
     self.v = v

-    # combine heads
+    # Combine heads
     out = tensor.transpose(out, perm=[0, 2, 1, 3])
     out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

-    # project to output
+    # Project to output
     out = self.out_proj(out)

     outs = [out]
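Because the patched forward stashes `self.q`, `self.k` and `self.v` on the `MultiHeadAttention` module itself, the per-layer tensors can also be read back directly after a forward pass; the attribute path below is an illustrative guess for an ERNIE student and is not taken from this diff.

# Hypothetical access path; the real layout depends on the wrapped model.
last_attn = student.ernie.encoder.layers[-1].self_attn
q, k, v = last_attn.q, last_attn.k, last_attn.v  # each: [bs, head_num, seq_len, head_dim]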
@@ -176,7 +315,7 @@ def attention_forward(self,

 def transformer_encoder_layer_forward(self, src, src_mask=None, cache=None):
     """
-    Redefines the `forward` function of `paddle.nn.TransformerEncoderLayer`
+    Redefines the `forward` function of `paddle.nn.TransformerEncoderLayer`.
     """
     src_mask = _convert_attention_mask(src_mask, src.dtype)

@@ -210,7 +349,7 @@ def transformer_encoder_layer_forward(self, src, src_mask=None, cache=None):

 def transformer_encoder_forward(self, src, src_mask=None, cache=None):
     """
-    Redefines the `forward` function of `paddle.nn.TransformerEncoder`
+    Redefines the `forward` function of `paddle.nn.TransformerEncoder`.
     """
     src_mask = _convert_attention_mask(src_mask, src.dtype)

@@ -251,7 +390,7 @@ def minilm_pretraining_forward(self,
     single GPU, this `forward` could not be replaced.
     The type of `self` should inherit from base class of pretrained LMs, such as
     `TinyBertForPretraining`.
-    Strategy MINILM only need q, k and v of transformers.
+    Strategy MINILM only needs q, k and v of the transformer layers.
     """
     assert hasattr(self, self.base_model_prefix), \
         "Student class should inherit from %s" % (self.base_model_class)
@@ -275,7 +414,8 @@ def tinybert_forward(self, input_ids, token_type_ids=None, attention_mask=None):
     sequence_output, pooled_output = model(input_ids, token_type_ids,
                                            attention_mask)
     for i in range(len(encoder.hidden_states)):
-        # While using tinybert-4l-312d, tinybert-6l-768d, tinybert-4l-312d-zh, tinybert-6l-768d-zh
+        # While using tinybert-4l-312d, tinybert-6l-768d, tinybert-4l-312d-zh,
+        # tinybert-6l-768d-zh
         # While using tinybert-4l-312d-v2, tinybert-6l-768d-v2
         # encoder.hidden_states[i] = self.tinybert.fit_dense(encoder.hidden_states[i])
         encoder.hidden_states[i] = self.tinybert.fit_denses[i](
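For context, the `fit_denses` used above are the TinyBERT per-layer linear projections that map the student's hidden size onto the teacher's before the layer-wise loss is computed. A minimal sketch of that idea follows; the sizes, the MSE objective and the layer count (hidden layers plus the embedding output) are illustrative assumptions, not taken from this commit.

import paddle
import paddle.nn as nn

student_hidden, teacher_hidden = 312, 768   # e.g. tinybert-4l-312d vs. bert-base
num_mapped_layers = 4 + 1                   # 4 hidden layers + embedding output (assumed)

# One projection per mapped layer, mirroring self.tinybert.fit_denses[i] above.
fit_denses = nn.LayerList(
    [nn.Linear(student_hidden, teacher_hidden) for _ in range(num_mapped_layers)])
mse = nn.MSELoss()

student_states = [paddle.randn([2, 128, student_hidden]) for _ in range(num_mapped_layers)]
teacher_states = [paddle.randn([2, 128, teacher_hidden]) for _ in range(num_mapped_layers)]

layer_loss = sum(
    mse(fit_denses[i](student_states[i]), teacher_states[i])
    for i in range(num_mapped_layers))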