
Commit 8268726

Update distill_utils to support PP-MINILM (#1426)
* update distill utils to support PP-MiniLM; update function name; update; remove unused import
* update doc for loss_fct
* add description for loss_fct
1 parent 41ab265 commit 8268726

File tree

1 file changed

+155 -15 lines changed

paddlenlp/transformers/distill_utils.py

Lines changed: 155 additions & 15 deletions
@@ -19,27 +19,159 @@
 import paddle.nn.functional as F
 from paddle.nn import MultiHeadAttention, TransformerEncoderLayer, TransformerEncoder
 from paddle.fluid.data_feeder import convert_dtype
+
 from paddlenlp.utils.log import logger
-from paddlenlp.transformers import TinyBertForPretraining, TinyBertForSequenceClassification, BertForSequenceClassification
+from paddlenlp.transformers import ErnieForSequenceClassification
+from paddlenlp.transformers import TinyBertForPretraining
+from paddlenlp.transformers import BertForSequenceClassification
+
+__all__ = ['to_distill', 'calc_minilm_loss', 'calc_multi_relation_loss']
+
+
+def calc_multi_relation_loss(loss_fct,
+                             s,
+                             t,
+                             attn_mask,
+                             num_relation_heads=0,
+                             alpha=0.0,
+                             beta=0.0):
+    """
+    Calculates loss for multiple Q-Q, K-K and V-V relation. It supports
+    head-head relation, sample-sample relation and origin token-token relation.
+    The final loss value could be balanced by weight `alpha` and `beta`.
+
+    Args:
+        loss_fct (callable):
+            Loss function for distillation. It only supports kl_div loss now.
+        s (Tensor):
+            Q, K, V of Student.
+        t (Tensor):
+            Q, K, V of teacher.
+        attn_mask (Tensor):
+            Attention mask for relation.
+        num_relation_heads (int):
+            The number of relation heads. 0 means `num_relation_heads` equals
+            to origin head num.
+            Defaults to 0.
+        alpha (float):
+            The weight for head-head relation.
+            Defaults to 0.0.
+        beta (float):
+            The weight for sample-sample relation.
+            Defaults to 0.0.
+
+    Returns:
+        Tensor: Weighted loss of token-token loss, head-head loss and
+        sample-sample loss.
+
+    """
+    # Initialize head_num
+    if num_relation_heads > 0 and num_relation_heads != s.shape[1]:
+        # s'shape: [bs, seq_len, head_num, head_dim]
+        s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
+        # s'shape: [bs, seq_len, num_relation_heads, head_dim_new]
+        s = tensor.reshape(x=s, shape=[0, 0, num_relation_heads, -1])
+        s1 = tensor.transpose(x=s, perm=[0, 2, 1, 3])
+    if num_relation_heads > 0 and num_relation_heads != t.shape[1]:
+        t = tensor.transpose(x=t, perm=[0, 2, 1, 3])
+        t = tensor.reshape(x=t, shape=[0, 0, num_relation_heads, -1])
+        t1 = tensor.transpose(x=t, perm=[0, 2, 1, 3])
+
+    s_head_dim, t_head_dim = s.shape[3], t.shape[3]
+
+    if alpha + beta == 1.0:
+        loss_token_token = 0.0
+    else:
+        scaled_dot_product_s1 = tensor.matmul(
+            x=s1, y=s1, transpose_y=True) / math.sqrt(s_head_dim)
+        del s1
+        scaled_dot_product_s1 += attn_mask
+        scaled_dot_product_t1 = tensor.matmul(
+            x=t1, y=t1, transpose_y=True) / math.sqrt(t_head_dim)
+        del t1
+        scaled_dot_product_t1 += attn_mask
+        loss_token_token = loss_fct(
+            F.log_softmax(scaled_dot_product_s1),
+            F.softmax(scaled_dot_product_t1))
+
+    if alpha == 0.0:
+        loss_head_head = 0.0
+    else:
+        scaled_dot_product_s = tensor.matmul(
+            x=s, y=s, transpose_y=True) / math.sqrt(s_head_dim)
+        attn_mask_head_head = tensor.transpose(x=attn_mask, perm=[0, 3, 1, 2])
+
+        scaled_dot_product_s += attn_mask_head_head
+        scaled_dot_product_t = tensor.matmul(
+            x=t, y=t, transpose_y=True) / math.sqrt(t_head_dim)
+        scaled_dot_product_t += attn_mask_head_head
+        loss_head_head = loss_fct(
+            F.log_softmax(scaled_dot_product_s),
+            F.softmax(scaled_dot_product_t))
+    if beta == 0.0:
+        loss_sample_sample = 0.0
+    else:
+        s2 = tensor.transpose(x=s, perm=[1, 2, 0, 3])
+        scaled_dot_product_s2 = tensor.matmul(
+            x=s2, y=s2, transpose_y=True) / math.sqrt(s_head_dim)
+
+        del s, s2
+        # Shape: [seq_len, 1, batch_size, 1]
+        attn_mask_sample_sample = tensor.transpose(
+            x=attn_mask, perm=[3, 1, 0, 2])
 
-__all__ = ['to_distill', 'calc_minilm_loss']
+        # Shape: [seq_len, head_num, batch_size, batch_size]
+        scaled_dot_product_s2 += attn_mask_sample_sample
+        t2 = tensor.transpose(x=t, perm=[1, 2, 0, 3])
+        scaled_dot_product_t2 = tensor.matmul(
+            x=t2, y=t2, transpose_y=True) / math.sqrt(t_head_dim)
+
+        del t, t2
+        scaled_dot_product_t2 += attn_mask_sample_sample
+        loss_sample_sample = loss_fct(
+            F.log_softmax(scaled_dot_product_s2),
+            F.softmax(scaled_dot_product_t2))
+
+    return (
+        1 - alpha - beta
+    ) * loss_token_token + alpha * loss_head_head + beta * loss_sample_sample
 
 
 def calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=0):
+    """
+    Calculates loss for Q-Q, K-K, V-V relation from MiniLMv2.
+
+    Args:
+        loss_fct (callable):
+            Loss function for distillation. It only supports kl_div loss now.
+        s (Tensor):
+            Q, K, V of Student.
+        t (Tensor):
+            Q, K, V of teacher.
+        attn_mask (Tensor):
+            Attention mask for relation.
+        num_relation_heads (int):
+            The number of relation heads. 0 means `num_relation_heads` equals
+            to origin head num.
+            Defaults to 0.
+
+    Returns:
+        Tensor: MiniLM loss value.
+
+    """
     # Initialize head_num
     if num_relation_heads > 0 and num_relation_heads != s.shape[1]:
         # s'shape: [bs, seq_len, head_num, head_dim]
         s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
         # s'shape: [bs, seq_len, num_relation_heads, head_dim_new]
         s = tensor.reshape(x=s, shape=[0, 0, num_relation_heads, -1])
-        #s's shape: [bs, num_relation_heads, seq_len,, head_dim_new]
+        # s' shape: [bs, num_relation_heads, seq_len, head_dim_new]
         s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
     if num_relation_heads > 0 and num_relation_heads != t.shape[1]:
         t = tensor.transpose(x=t, perm=[0, 2, 1, 3])
         t = tensor.reshape(x=t, shape=[0, 0, num_relation_heads, -1])
         t = tensor.transpose(x=t, perm=[0, 2, 1, 3])
 
-    pad_seq_len = s.shape[2]
     s_head_dim, t_head_dim = s.shape[3], t.shape[3]
     scaled_dot_product_s = tensor.matmul(
         x=s, y=s, transpose_y=True) / math.sqrt(s_head_dim)
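
Both loss helpers expect `loss_fct` to behave like a KL-divergence loss (the docstrings above note that only kl_div is supported). Below is a minimal sketch of calling `calc_minilm_loss` on one relation; the import path follows the file shown in this commit, but the tensor shapes, the mask, and the relation-head count are made up for illustration and are not part of the patch:

import paddle
from paddlenlp.transformers.distill_utils import calc_minilm_loss

# Hypothetical Q tensors shaped [batch_size, head_num, seq_len, head_dim]:
# student uses 12 heads x 32 dims, teacher 12 heads x 64 dims.
q_student = paddle.randn([2, 12, 128, 32])
q_teacher = paddle.randn([2, 12, 128, 64])
# Additive attention mask, broadcastable over [batch_size, heads, seq_len, seq_len]
attn_mask = paddle.zeros([2, 1, 1, 128])

# kl_div-style loss, as required by the docstring
loss_fct = paddle.nn.KLDivLoss(reduction='sum')
# Both models are remapped onto the same number of relation heads (48 here)
# before the Q-Q relation matrices are compared.
q_loss = calc_minilm_loss(loss_fct, q_student, q_teacher, attn_mask,
                          num_relation_heads=48)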
@@ -62,24 +194,31 @@ def to_distill(self,
                layer_index=-1):
     """
     Can be bound to object with transformer encoder layers, and make model
-    expose attributes `outputs.qs`, `outputs.ks`, `outputs.vs`,
+    expose attributes `outputs.q`, `outputs.k`, `outputs.v`,
     `outputs.scaled_qks`, `outputs.hidden_states`and `outputs.attentions` of
     the object for distillation.
+
+    It could be returned intermediate tensor using in MiniLM and TinyBERT
+    strategy.
     """
     logger.warning("`to_distill` is an experimental API and subject to change.")
     MultiHeadAttention._forward = attention_forward
     TransformerEncoderLayer._forward = transformer_encoder_layer_forward
     TransformerEncoder._forward = transformer_encoder_forward
     BertForSequenceClassification._forward = bert_forward
+
     if return_qkv:
+        # forward function of student class should be replaced for distributed training.
         TinyBertForPretraining._forward = minilm_pretraining_forward
+        ErnieForSequenceClassification._forward = minilm_pretraining_forward
     else:
         TinyBertForPretraining._forward = tinybert_forward
 
     def init_func(layer):
         if isinstance(layer, (MultiHeadAttention, TransformerEncoderLayer,
                               TransformerEncoder, TinyBertForPretraining,
-                              BertForSequenceClassification)):
+                              BertForSequenceClassification,
+                              ErnieForSequenceClassification)):
             layer.forward = layer._forward
             if isinstance(layer, TransformerEncoder):
                 layer.return_layer_outputs = return_layer_outputs
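
Based only on the parameters visible in this hunk (`return_qkv`, `return_layer_outputs`, `layer_index`), a hedged sketch of how a student might be patched for the MiniLM strategy; the model name and the in-place call are assumptions for illustration, not taken from this diff:

from paddlenlp.transformers import TinyBertForPretraining
from paddlenlp.transformers.distill_utils import to_distill

# Hypothetical student; this patch routes TinyBertForPretraining and
# ErnieForSequenceClassification students to minilm_pretraining_forward.
student = TinyBertForPretraining.from_pretrained('tinybert-6l-768d')

# Replace the forward functions so that q, k, v of the layer selected by
# `layer_index` are exposed for distillation (experimental API, see the
# warning in the docstring above).
to_distill(student, return_qkv=True, layer_index=-1)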
@@ -125,17 +264,17 @@ def attention_forward(self,
                       attn_mask=None,
                       cache=None):
     """
-    Redefines the `forward` function of `paddle.nn.MultiHeadAttention`
+    Redefines the `forward` function of `paddle.nn.MultiHeadAttention`.
     """
     key = query if key is None else key
     value = query if value is None else value
-    # compute q ,k ,v
+    # Computes q ,k ,v
     if cache is None:
         q, k, v = self._prepare_qkv(query, key, value, cache)
     else:
         q, k, v, cache = self._prepare_qkv(query, key, value, cache)
 
-    # scale dot product attention
+    # Scale dot product attention
     product = tensor.matmul(x=q, y=k, transpose_y=True)
     product /= math.sqrt(self.head_dim)
 
@@ -159,11 +298,11 @@ def attention_forward(self,
     self.k = k
     self.v = v
 
-    # combine heads
+    # Combine heads
     out = tensor.transpose(out, perm=[0, 2, 1, 3])
     out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
 
-    # project to output
+    # Project to output
     out = self.out_proj(out)
 
     outs = [out]
@@ -176,7 +315,7 @@ def attention_forward(self,
 
 def transformer_encoder_layer_forward(self, src, src_mask=None, cache=None):
     """
-    Redefines the `forward` function of `paddle.nn.TransformerEncoderLayer`
+    Redefines the `forward` function of `paddle.nn.TransformerEncoderLayer`.
     """
     src_mask = _convert_attention_mask(src_mask, src.dtype)
 
@@ -210,7 +349,7 @@ def transformer_encoder_layer_forward(self, src, src_mask=None, cache=None):
 
 def transformer_encoder_forward(self, src, src_mask=None, cache=None):
     """
-    Redefines the `forward` function of `paddle.nn.TransformerEncoder`
+    Redefines the `forward` function of `paddle.nn.TransformerEncoder`.
     """
     src_mask = _convert_attention_mask(src_mask, src.dtype)
 
@@ -251,7 +390,7 @@ def minilm_pretraining_forward(self,
     single GPU, this `forward` could not be replaced.
     The type of `self` should inherit from base class of pretrained LMs, such as
     `TinyBertForPretraining`.
-    Strategy MINILM only need q, k and v of transformers.
+    Strategy MINILM only needs q, k and v of transformers.
     """
     assert hasattr(self, self.base_model_prefix), \
         "Student class should inherit from %s" % (self.base_model_class)
@@ -275,7 +414,8 @@ def tinybert_forward(self, input_ids, token_type_ids=None, attention_mask=None):
     sequence_output, pooled_output = model(input_ids, token_type_ids,
                                            attention_mask)
     for i in range(len(encoder.hidden_states)):
-        # While using tinybert-4l-312d, tinybert-6l-768d, tinybert-4l-312d-zh, tinybert-6l-768d-zh
+        # While using tinybert-4l-312d, tinybert-6l-768d, tinybert-4l-312d-zh,
+        # tinybert-6l-768d-zh
         # While using tinybert-4l-312d-v2, tinybert-6l-768d-v2
         # encoder.hidden_states[i] = self.tinybert.fit_dense(encoder.hidden_states[i])
         encoder.hidden_states[i] = self.tinybert.fit_denses[i](
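
Putting the pieces of this patch together, a PP-MiniLM-style distillation step might sum `calc_multi_relation_loss` over the Q, K and V relations of the selected layer. Everything below except the function signature is assumed for illustration: the relation-head count, the `alpha`/`beta` weights, and how the (q, k, v) pairs are collected from the patched models.

import paddle
from paddlenlp.transformers.distill_utils import calc_multi_relation_loss

kl_loss_fct = paddle.nn.KLDivLoss(reduction='sum')

def multi_relation_distill_loss(student_qkv, teacher_qkv, attn_mask,
                                num_relation_heads=48, alpha=0.1, beta=0.1):
    """Sum the weighted relation losses over the (q, k, v) pairs of one layer."""
    total = 0.0
    for s, t in zip(student_qkv, teacher_qkv):
        total += calc_multi_relation_loss(
            kl_loss_fct, s, t, attn_mask,
            num_relation_heads=num_relation_heads,
            alpha=alpha, beta=beta)
    return total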

0 commit comments
