
Commit 5d22c2c

yzchen authored and Jianfeng Wang committed
feat(nlp): Update bert for MegEngine v1.0 (#73)
1 parent 72632d7 commit 5d22c2c

File tree

3 files changed: +41 -75 lines

official/nlp/bert/model.py

Lines changed: 23 additions & 50 deletions
@@ -32,52 +32,24 @@
 import megengine.hub as hub
 import numpy as np
 from megengine import Parameter
-from megengine.functional import cross_entropy_with_softmax
+from megengine.functional.loss import cross_entropy
 from megengine.module import Dropout, Embedding, Linear, Module, Sequential
 from megengine.module.activation import Softmax
 
 
 def transpose(inp, a, b):
-    cur_shape = list(range(0, len(inp.shape)))
+    cur_shape = list(range(0, inp.ndim))
     cur_shape[a], cur_shape[b] = cur_shape[b], cur_shape[a]
-    return inp.dimshuffle(*cur_shape)
+    return inp.transpose(cur_shape)
 
 
-def matmul(a, b, transpose_b=None):
-    dim = len(b.shape)
-
-    if transpose_b:
-        b = transpose(b, dim - 1, dim - 2)
-
-    if dim > 3:
-        a_shape = list(a.shape)
-        b_shape = list(b.shape)
-        reshape_batch_size = 1
-        for i in a_shape[0 : dim - 2]:
-            reshape_batch_size *= i
-        a = a.reshape(*([reshape_batch_size] + a_shape[dim - 2 : dim]))
-        b = b.reshape(*([reshape_batch_size] + b_shape[dim - 2 : dim]))
-        c = F.batched_matrix_mul(a, b)
-        c = c.reshape(*(a_shape[0 : dim - 1] + b_shape[dim - 1 : dim]))
-        return c
-    elif dim == 3:
-        return F.batched_matrix_mul(a, b)
-    else:
-        return F.matrix_mul(a, b)
-
-def zeros_like(inp):
-    return mge.zeros(inp.shape, dtype=inp.dtype)
-
-def ones_like(inp):
-    return mge.ones(inp.shape, dtype=inp.dtype)
-
 def gelu(x):
     """Implementation of the gelu activation function.
     For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
-    x * 0.5 * (1.0 + F.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * (x ** 3)))))
+    x * 0.5 * (1.0 + F.tanh((F.sqrt(2 / math.pi) * (x + 0.044715 * (x ** 3)))))
    Also see https://arxiv.org/abs/1606.08415
    """
-    return x * 0.5 * (1.0 + F.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * (x ** 3)))))
+    return x * 0.5 * (1.0 + F.tanh(F.sqrt(2 / math.pi) * (x + 0.044715 * (x ** 3))))
 
 
 ACT2FN = {"gelu": gelu, "relu": F.relu}
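
The deleted matmul, zeros_like, and ones_like helpers are exactly what MegEngine v1.0 now ships out of the box. A minimal sketch (not part of the commit, assuming megengine >= 1.0 is installed) of the built-ins that replace them:

import numpy as np
import megengine as mge
import megengine.functional as F

a = mge.tensor(np.random.rand(2, 4, 8, 16).astype("float32"))
b = mge.tensor(np.random.rand(2, 4, 16, 8).astype("float32"))

# F.matmul broadcasts over leading batch dimensions, so the old
# reshape-then-batched_matrix_mul round trip is no longer needed.
c = F.matmul(a, b)      # shape (2, 4, 8, 8)

# zeros_like / ones_like now ship in megengine.functional.
mask = F.ones_like(c)
pad = F.zeros_like(c)
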
@@ -221,10 +193,10 @@ def forward(self, input_ids, token_type_ids=None):
         seq_length = input_ids.shape[1]
 
         if token_type_ids is None:
-            token_type_ids = zeros_like(input_ids)
+            token_type_ids = F.zeros_like(input_ids)
 
         position_ids = F.linspace(0, seq_length - 1, seq_length).astype(np.int32)
-        position_ids = F.add_axis(position_ids, 0).broadcast(*input_ids.shape)
+        position_ids = F.broadcast_to(F.expand_dims(position_ids, 0), input_ids.shape)
         words_embeddings = self.word_embeddings(input_ids)
 
         position_embeddings = self.position_embeddings(position_ids)
@@ -255,12 +227,11 @@ def __init__(self, config):
         self.dropout = Dropout(config.attention_probs_dropout_prob)
 
     def transpose_for_scores(self, x):
-        new_x_shape = x.shape[:-1] + (
-            self.num_attention_heads,
-            self.attention_head_size,
-        )
-        x = x.reshape(*new_x_shape)
-        return x.dimshuffle(0, 2, 1, 3)
+        # using symbolic shapes to make trace happy
+        x_shape = mge.tensor(x.shape)
+        new_x_shape = F.concat([x_shape[:-1], (self.num_attention_heads, self.attention_head_size)])
+        x = x.reshape(new_x_shape)
+        return x.transpose(0, 2, 1, 3)
 
     def forward(self, hidden_states, attention_mask):
         mixed_query_layer = self.query(hidden_states)
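
The "symbolic shapes" comment is the point of this rewrite: under @trace, x.shape is a lazily evaluated tensor rather than a Python tuple, so the target shape is assembled with mge.tensor and F.concat instead of tuple arithmetic. A standalone sketch of the same pattern (illustrative only; assumes a hidden size of 12 split into 4 heads of 3):

import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.random.rand(2, 8, 12).astype("float32"))  # (batch, seq, hidden)
num_heads, head_size = 4, 3

x_shape = mge.tensor(x.shape)  # the shape as a tensor, trace-friendly
new_shape = F.concat([x_shape[:-1], mge.tensor([num_heads, head_size])])
y = x.reshape(new_shape).transpose(0, 2, 1, 3)  # (batch, heads, seq, head_size)
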
@@ -272,7 +243,7 @@ def forward(self, hidden_states, attention_mask):
         value_layer = self.transpose_for_scores(mixed_value_layer)
 
         # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = matmul(query_layer, transpose(key_layer, -1, -2))
+        attention_scores = F.matmul(query_layer, transpose(key_layer, -1, -2))
         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
         # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
         attention_scores = attention_scores + attention_mask
@@ -284,10 +255,12 @@
         # seem a bit unusual, but is taken from the original Transformer paper.
         attention_probs = self.dropout(attention_probs)
 
-        context_layer = matmul(attention_probs, value_layer)
-        context_layer = context_layer.dimshuffle(0, 2, 1, 3)
-        new_context_layer_shape = context_layer.shape[:-2] + (self.all_head_size,)
-        context_layer = context_layer.reshape(*new_context_layer_shape)
+        context_layer = F.matmul(attention_probs, value_layer)
+        context_layer = context_layer.transpose(0, 2, 1, 3)
+        # using symbolic shapes to make trace happy
+        context_shape = mge.tensor(context_layer.shape)
+        new_context_layer_shape = F.concat([context_shape[:-2], self.all_head_size])
+        context_layer = context_layer.reshape(new_context_layer_shape)
         return context_layer
@@ -453,17 +426,17 @@ def forward(
         output_all_encoded_layers=True,
     ):
         if attention_mask is None:
-            attention_mask = ones_like(input_ids)
+            attention_mask = F.ones_like(input_ids)
         if token_type_ids is None:
-            token_type_ids = zeros_like(input_ids)
+            token_type_ids = F.zeros_like(input_ids)
         # print('input_ids', input_ids.sum())
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
         # this attention mask is more simple than the triangular masking of causal attention
         # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
         # print('attention_mask', attention_mask.sum())
-        extended_attention_mask = F.add_axis(attention_mask, (1, 2))
+        extended_attention_mask = F.expand_dims(attention_mask, (1, 2))
 
         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
@@ -554,7 +527,7 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
         logits = self.classifier(pooled_output)
 
         if labels is not None:
-            loss = cross_entropy_with_softmax(
+            loss = cross_entropy(
                 logits.reshape(-1, self.num_labels), labels.reshape(-1)
             )
             return logits, loss
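
On the loss side, v1.0's cross_entropy applies softmax internally by default (with_logits=True, to the best of my reading of the v1.0 API), which is why cross_entropy_with_softmax can be swapped out with no other call-site change. A small illustrative sketch:

import numpy as np
import megengine as mge
from megengine.functional.loss import cross_entropy

logits = mge.tensor(np.random.randn(4, 2).astype("float32"))  # (batch, num_labels)
labels = mge.tensor(np.array([0, 1, 1, 0], dtype="int32"))

# softmax + negative log-likelihood in a single call
loss = cross_entropy(logits.reshape(-1, 2), labels.reshape(-1))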

official/nlp/bert/test.py

Lines changed: 2 additions & 7 deletions
@@ -20,19 +20,14 @@
 logger = mge.get_logger(__name__)
 
 
-@trace(symbolic=True)
+# @trace(symbolic=True)
 def net_eval(input_ids, segment_ids, input_mask, label_ids, net=None):
     net.eval()
     results = net(input_ids, segment_ids, input_mask, label_ids)
     logits, loss = results
     return loss, logits, label_ids
 
 
-def accuracy(out, labels):
-    outputs = F.argmax(out, axis=1)
-    return F.sum(outputs == labels)
-
-
 def eval(dataloader, net):
     logger.info("***** Running evaluation *****")
     logger.info("batch size = %d", args.eval_batch_size)
@@ -48,7 +43,7 @@ def eval(dataloader, net):
             input_ids, segment_ids, input_mask, label_ids, net=net
         )
         sum_loss += loss.mean().item()
-        sum_accuracy += accuracy(logits, label_ids)
+        sum_accuracy += F.topk_accuracy(logits, label_ids) * batch_size
         total_examples += batch_size
         total_steps += 1
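
Note the * batch_size factor: the removed accuracy helper returned a count of correct predictions, while F.topk_accuracy (assuming the v1.0 metric, which returns a fraction in [0, 1]) must be scaled back to a count before accumulating. A minimal sketch:

import numpy as np
import megengine as mge
import megengine.functional as F

logits = mge.tensor(np.array([[0.1, 0.9], [0.8, 0.2]], dtype="float32"))
labels = mge.tensor(np.array([1, 0], dtype="int32"))

acc = F.topk_accuracy(logits, labels, topk=1)  # fraction correct: 1.0 here
correct = acc * logits.shape[0]                # back to a count: 2.0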

official/nlp/bert/train.py

Lines changed: 16 additions & 18 deletions
@@ -10,6 +10,7 @@
 import megengine as mge
 import megengine.functional as F
 import megengine.optimizer as optim
+from megengine.autodiff import GradManager
 from megengine.jit import trace
 from tqdm import tqdm

@@ -21,28 +22,24 @@
 logger = mge.get_logger(__name__)
 
 
-@trace(symbolic=True)
+# @trace(symbolic=True)
 def net_eval(input_ids, segment_ids, input_mask, label_ids, net=None):
     net.eval()
     results = net(input_ids, segment_ids, input_mask, label_ids)
     logits, loss = results
-    return loss, logits, label_ids
+    return loss, logits
 
 
-@trace(symbolic=True)
-def net_train(input_ids, segment_ids, input_mask, label_ids, opt=None, net=None):
+# @trace(symbolic=True)
+def net_train(input_ids, segment_ids, input_mask, label_ids, gm=None, net=None):
     net.train()
-    results = net(input_ids, segment_ids, input_mask, label_ids)
-    logits, loss = results
-    opt.backward(loss)
+    with gm:
+        results = net(input_ids, segment_ids, input_mask, label_ids)
+        logits, loss = results
+        gm.backward(loss)
     return loss, logits, label_ids
 
 
-def accuracy(out, labels):
-    outputs = F.argmax(out, axis=1)
-    return F.sum(outputs == labels)
-
-
 def eval(dataloader, net):
     logger.info("***** Running evaluation *****")
     logger.info("batch size = %d", args.eval_batch_size)
@@ -56,11 +53,11 @@ def eval(dataloader, net):
         batch_size = input_ids.shape[0]
         if batch_size != args.eval_batch_size:
             break
-        loss, logits, label_ids = net_eval(
+        loss, logits = net_eval(
             input_ids, segment_ids, input_mask, label_ids, net=net
         )
         sum_loss += loss.mean().item()
-        sum_accuracy += accuracy(logits, label_ids)
+        sum_accuracy += F.topk_accuracy(logits, label_ids) * batch_size
         total_examples += batch_size
         total_steps += 1

@@ -79,18 +76,19 @@ def train(dataloader, net, opt):
     logger.info("batch size = %d", args.train_batch_size)
     sum_loss, sum_accuracy, total_steps, total_examples = 0, 0, 0, 0
 
+    gm = GradManager().attach(net.parameters())
+
     for _, batch in enumerate(tqdm(dataloader, desc="Iteration")):
         input_ids, input_mask, segment_ids, label_ids = tuple(
             mge.tensor(t) for t in batch
         )
         batch_size = input_ids.shape[0]
-        opt.zero_grad()
         loss, logits, label_ids = net_train(
-            input_ids, segment_ids, input_mask, label_ids, opt=opt, net=net
+            input_ids, segment_ids, input_mask, label_ids, gm=gm, net=net
         )
-        optimizer.step()
+        opt.step().clear_grad()
         sum_loss += loss.mean().item()
-        sum_accuracy += accuracy(logits, label_ids)
+        sum_accuracy += F.topk_accuracy(logits, label_ids) * batch_size
         total_examples += batch_size
         total_steps += 1
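
The training-side change follows the v1.0 pattern seen in this hunk: the old Optimizer.backward call gives way to an explicit GradManager that records the forward pass and populates parameter gradients, after which the optimizer steps and clears. A condensed sketch of the loop as it now reads (illustrative; train_step is a hypothetical name, and net/opt are whatever model and optimizer train() receives):

import megengine.optimizer as optim
from megengine.autodiff import GradManager

def train_step(net, opt, gm, input_ids, segment_ids, input_mask, label_ids):
    net.train()
    with gm:                          # record the forward pass for autodiff
        logits, loss = net(input_ids, segment_ids, input_mask, label_ids)
        gm.backward(loss)             # fills .grad on the attached parameters
    # v1.0: step() returns the optimizer, so clear_grad() chains onto it
    opt.step().clear_grad()
    return loss, logits

# one-time setup, as in train():
# gm = GradManager().attach(net.parameters())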