Commit 25014fa

yzchen authored and Jianfeng Wang committed
refactor(nlp): fix some pylint problems of bert (#80)
1 parent a3c3baf commit 25014fa

File tree

8 files changed: +104, -87 lines changed


.github/workflows/ci.yml
Lines changed: 3 additions & 2 deletions

@@ -38,9 +38,10 @@ jobs:
       run: |
         export PYTHONPATH=$PWD:$PYTHONPATH

-        CHECK_DIR=official/vision/
+        CHECK_VISION=official/vision/
+        CHECK_NLP=official/nlp/
         pip install pylint==2.5.2
-        pylint $CHECK_DIR --rcfile=.pylintrc || pylint_ret=$?
+        pylint $CHECK_VISION $CHECK_NLP --rcfile=.pylintrc || pylint_ret=$?
         echo test, and deploy your project.
         if [ "$pylint_ret" ]; then
           exit $pylint_ret

official/nlp/bert/config_args.py
Lines changed: 3 additions & 2 deletions

@@ -14,13 +14,14 @@
 def get_args():
     parser = argparse.ArgumentParser()

-    ## parameters
+    # parameters
     parser.add_argument(
         "--data_dir",
         default=None,
         type=str,
         required=True,
-        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
+        help="The input data dir. Should contain the .tsv files (or other data files)"
+        " for the task.",
     )

     parser.add_argument(

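Note: the wrapped help string works because Python concatenates adjacent string literals at compile time, so this line-length fix does not change the argument's help text. A minimal standalone sketch (not part of the diff):

    wrapped = (
        "The input data dir. Should contain the .tsv files (or other data files)"
        " for the task."
    )
    single = "The input data dir. Should contain the .tsv files (or other data files) for the task."
    assert wrapped == single  # adjacent literals join into a single string
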
official/nlp/bert/model.py
Lines changed: 68 additions & 49 deletions

@@ -16,9 +16,6 @@

 """Megengine BERT model."""

-from __future__ import (absolute_import, division, print_function,
-                        unicode_literals)
-
 import copy
 import json
 import math
@@ -27,10 +24,11 @@
 import urllib.request
 from io import open

+import numpy as np
+
 import megengine as mge
 import megengine.functional as F
 import megengine.hub as hub
-import numpy as np
 from megengine import Parameter
 from megengine.functional.loss import cross_entropy
 from megengine.module import Dropout, Embedding, Linear, Module, Sequential
@@ -45,7 +43,8 @@ def transpose(inp, a, b):

 def gelu(x):
     """Implementation of the gelu activation function.
-    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
+    For information: OpenAI GPT's gelu is slightly different
+    (and gives slightly different results):
     x * 0.5 * (1.0 + F.tanh((F.sqrt(2 / math.pi) * (x + 0.044715 * (x ** 3)))))
     Also see https://arxiv.org/abs/1606.08415
     """
@@ -98,7 +97,7 @@ def __init__(
                 initializing all weight matrices.
         """
         if isinstance(vocab_size_or_config_json_file, str):
-            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
+            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
                 json_config = json.loads(reader.read())
                 for key, value in json_config.items():
                     self.__dict__[key] = value
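For reference, the tanh approximation quoted in the gelu docstring above can be written as a standalone sketch; the helper name and the use of NumPy (rather than megengine.functional, which the module itself uses) are assumptions for illustration only:

    import math
    import numpy as np

    def gelu_tanh_approx(x):
        # OpenAI GPT-style gelu, as quoted in the docstring (illustrative sketch only)
        return x * 0.5 * (1.0 + np.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))
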
@@ -158,7 +157,7 @@ class BertLayerNorm(Module):
     """

     def __init__(self, hidden_size, eps=1e-12):
-        super(BertLayerNorm, self).__init__()
+        super().__init__()
         self.weight = Parameter(np.ones(hidden_size).astype(np.float32))
         self.bias = Parameter(np.zeros(hidden_size).astype(np.float32))
         self.variance_epsilon = eps
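The super() change in the hunk above, repeated in the hunks that follow, swaps the two-argument Python 2-style call for the zero-argument form preferred by pylint's super-with-arguments check; in Python 3 the two are equivalent, as this small standalone sketch with hypothetical classes shows:

    class Base:
        def __init__(self):
            self.initialized = True

    class OldStyle(Base):
        def __init__(self):
            super(OldStyle, self).__init__()  # explicit form flagged by pylint

    class NewStyle(Base):
        def __init__(self):
            super().__init__()  # zero-argument form, same behavior in Python 3

    assert OldStyle().initialized and NewStyle().initialized
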
@@ -175,7 +174,7 @@ class BertEmbeddings(Module):
     """

     def __init__(self, config):
-        super(BertEmbeddings, self).__init__()
+        super().__init__()
         self.word_embeddings = Embedding(config.vocab_size, config.hidden_size)
         self.position_embeddings = Embedding(
             config.max_position_embeddings, config.hidden_size
@@ -184,8 +183,8 @@ def __init__(self, config):
             config.type_vocab_size, config.hidden_size
         )

-        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
-        # any TensorFlow checkpoint file
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name
+        # and be able to load any TensorFlow checkpoint file
         self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
         self.dropout = Dropout(config.hidden_dropout_prob)

@@ -210,7 +209,7 @@ def forward(self, input_ids, token_type_ids=None):

 class BertSelfAttention(Module):
     def __init__(self, config):
-        super(BertSelfAttention, self).__init__()
+        super().__init__()
         if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
                 "The hidden size (%d) is not a multiple of the number of attention "
@@ -229,7 +228,9 @@ def __init__(self, config):
     def transpose_for_scores(self, x):
         # using symbolic shapes to make trace happy
         x_shape = mge.tensor(x.shape)
-        new_x_shape = F.concat([x_shape[:-1], (self.num_attention_heads, self.attention_head_size)])
+        new_x_shape = F.concat(
+            [x_shape[:-1], (self.num_attention_heads, self.attention_head_size)]
+        )
         x = x.reshape(new_x_shape)
         return x.transpose(0, 2, 1, 3)

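The reshape/transpose being re-wrapped in this hunk splits the hidden dimension into per-head slices before attention; a minimal NumPy sketch of the same shape manipulation, with example sizes that are assumptions rather than values taken from the diff:

    import numpy as np

    batch, seq_len, num_heads, head_size = 2, 4, 12, 64  # illustrative sizes only
    x = np.zeros((batch, seq_len, num_heads * head_size))

    # [batch, seq, hidden] -> [batch, seq, heads, head_size] -> [batch, heads, seq, head_size]
    x = x.reshape(batch, seq_len, num_heads, head_size).transpose(0, 2, 1, 3)
    assert x.shape == (batch, num_heads, seq_len, head_size)
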
@@ -266,7 +267,7 @@ def forward(self, hidden_states, attention_mask):

 class BertSelfOutput(Module):
     def __init__(self, config):
-        super(BertSelfOutput, self).__init__()
+        super().__init__()
         self.dense = Linear(config.hidden_size, config.hidden_size)
         self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
         self.dropout = Dropout(config.hidden_dropout_prob)
@@ -280,7 +281,7 @@ def forward(self, hidden_states, input_tensor):

 class BertAttention(Module):
     def __init__(self, config):
-        super(BertAttention, self).__init__()
+        super().__init__()
         self.self = BertSelfAttention(config)
         self.output = BertSelfOutput(config)

@@ -292,7 +293,7 @@ def forward(self, input_tensor, attention_mask):

 class BertIntermediate(Module):
     def __init__(self, config):
-        super(BertIntermediate, self).__init__()
+        super().__init__()
         self.dense = Linear(config.hidden_size, config.intermediate_size)
         if isinstance(config.hidden_act, str):
             self.intermediate_act_fn = ACT2FN[config.hidden_act]
@@ -307,7 +308,7 @@ def forward(self, hidden_states):

 class BertOutput(Module):
     def __init__(self, config):
-        super(BertOutput, self).__init__()
+        super().__init__()
         self.dense = Linear(config.intermediate_size, config.hidden_size)
         self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
         self.dropout = Dropout(config.hidden_dropout_prob)
@@ -321,7 +322,7 @@ def forward(self, hidden_states, input_tensor):

 class BertLayer(Module):
     def __init__(self, config):
-        super(BertLayer, self).__init__()
+        super().__init__()
         self.attention = BertAttention(config)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
@@ -335,7 +336,7 @@ def forward(self, hidden_states, attention_mask):

 class BertEncoder(Module):
     def __init__(self, config):
-        super(BertEncoder, self).__init__()
+        super().__init__()
         self.layer = Sequential(
             *[BertLayer(config) for _ in range(config.num_hidden_layers)]
         )
@@ -354,7 +355,7 @@ def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True)

 class BertPooler(Module):
     def __init__(self, config):
-        super(BertPooler, self).__init__()
+        super().__init__()
         self.dense = Linear(config.hidden_size, config.hidden_size)
         self.activation = F.tanh

@@ -375,26 +376,34 @@ class BertModel(Module):

     Inputs:
         `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
-            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
+            with the word token indices in the vocabulary
+            (see the tokens preprocessing logic in the scripts
             `extract_features.py`, `run_classifier.py` and `run_squad.py`)
-        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
-            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
+        `token_type_ids`: an optional torch.LongTensor of shape
+            [batch_size, sequence_length] with the token types indices selected in [0, 1].
+            Type 0 corresponds to a `sentence A` and type 1 corresponds to
             a `sentence B` token (see BERT paper for more details).
-        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
-            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
-            input sequence length in the current batch. It's the mask that we typically use for attention when
+        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length]
+            with indices selected in [0, 1]. It's a mask to be used if the input sequence length
+            is smaller than the max input sequence length in the current batch.
+            It's the mask that we typically use for attention when
             a batch has varying length sentences.
-        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
+        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers`
+            output as described below. Default: `True`.

     Outputs: Tuple of (encoded_layers, pooled_output)
         `encoded_layers`: controled by `output_all_encoded_layers` argument:
-            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
-                of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
-                encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
-            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
-                to the last attention block of shape [batch_size, sequence_length, hidden_size],
-        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
-            classifier pretrained on top of the hidden state associated to the first character of the
+            - `output_all_encoded_layers=True`: outputs a list of the full sequences of
+                encoded-hidden-states at the end of each attention block
+                (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
+                encoded-hidden-state is a torch.FloatTensor of size
+                [batch_size, sequence_length, hidden_size],
+            - `output_all_encoded_layers=False`: outputs only the full sequence of
+                hidden-states corresponding to the last attention block of shape
+                [batch_size, sequence_length, hidden_size],
+        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size]
+            which is the output of classifier pretrained on top of the hidden state
+            associated to the first character of the
             input (`CLS`) to train on the Next-Sentence task (see BERT's paper).

     Example usage:
474483
475484
Inputs:
476485
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
477-
with the word token indices in the vocabulary. Items in the batch should begin with the special "CLS" token. (see the tokens preprocessing logic in the scripts
486+
with the word token indices in the vocabulary.
487+
Items in the batch should begin with the special "CLS" token.
488+
(see the tokens preprocessing logic in the scripts
478489
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
479-
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
480-
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
481-
a `sentence B` token (see BERT paper for more details).
482-
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
483-
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
484-
input sequence length in the current batch. It's the mask that we typically use for attention when
485-
a batch has varying length sentences.
490+
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length]
491+
with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
492+
and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
493+
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length]
494+
with indices selected in [0, 1]. It's a mask to be used if the input sequence length
495+
is smaller than the max input sequence length in the current batch. It's the mask
496+
that we typically use for attention when a batch has varying length sentences.
486497
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
487498
with indices selected in [0, ..., num_labels].
488499
@@ -580,7 +591,8 @@ def create_hub_bert(model_name, pretrained):


 @hub.pretrained(
-    "https://data.megengine.org.cn/models/weights/bert/uncased_L-12_H-768_A-12/bert_4f2157f7_uncased_L-12_H-768_A-12.pkl"
+    "https://data.megengine.org.cn/models/weights/bert/"
+    "uncased_L-12_H-768_A-12/bert_4f2157f7_uncased_L-12_H-768_A-12.pkl"
 )
 def uncased_L_12_H_768_A_12():
     config_dict = {
@@ -601,7 +613,8 @@ def uncased_L_12_H_768_A_12():


 @hub.pretrained(
-    "https://data.megengine.org.cn/models/weights/bert/cased_L-12_H-768_A-12/bert_b9727c2f_cased_L-12_H-768_A-12.pkl"
+    "https://data.megengine.org.cn/models/weights/bert/"
+    "cased_L-12_H-768_A-12/bert_b9727c2f_cased_L-12_H-768_A-12.pkl"
 )
 def cased_L_12_H_768_A_12():
     config_dict = {
@@ -622,7 +635,8 @@ def cased_L_12_H_768_A_12():


 @hub.pretrained(
-    "https://data.megengine.org.cn/models/weights/bert/uncased_L-24_H-1024_A-16/bert_222f5012_uncased_L-24_H-1024_A-16.pkl"
+    "https://data.megengine.org.cn/models/weights/bert/"
+    "uncased_L-24_H-1024_A-16/bert_222f5012_uncased_L-24_H-1024_A-16.pkl"
 )
 def uncased_L_24_H_1024_A_16():
     config_dict = {
@@ -644,7 +658,8 @@ def uncased_L_24_H_1024_A_16():


 @hub.pretrained(
-    "https://data.megengine.org.cn/models/weights/bert/cased_L-24_H-1024_A-16/bert_01f2a65f_cased_L-24_H-1024_A-16.pkl"
+    "https://data.megengine.org.cn/models/weights/bert/"
+    "cased_L-24_H-1024_A-16/bert_01f2a65f_cased_L-24_H-1024_A-16.pkl"
 )
 def cased_L_24_H_1024_A_16():
     config_dict = {
@@ -672,7 +687,8 @@ def cased_L_24_H_1024_A_16():


 @hub.pretrained(
-    "https://data.megengine.org.cn/models/weights/bert/chinese_L-12_H-768_A-12/bert_ee91be1a_chinese_L-12_H-768_A-12.pkl"
+    "https://data.megengine.org.cn/models/weights/bert/"
+    "chinese_L-12_H-768_A-12/bert_ee91be1a_chinese_L-12_H-768_A-12.pkl"
 )
 def chinese_L_12_H_768_A_12():
     config_dict = {
@@ -699,7 +715,8 @@ def chinese_L_12_H_768_A_12():


 @hub.pretrained(
-    "https://data.megengine.org.cn/models/weights/bert/multi_cased_L-12_H-768_A-12/bert_283ceec5_multi_cased_L-12_H-768_A-12.pkl"
+    "https://data.megengine.org.cn/models/weights/bert/"
+    "multi_cased_L-12_H-768_A-12/bert_283ceec5_multi_cased_L-12_H-768_A-12.pkl"
 )
 def multi_cased_L_12_H_768_A_12():
     config_dict = {
@@ -727,7 +744,8 @@ def multi_cased_L_12_H_768_A_12():


 @hub.pretrained(
-    "https://data.megengine.org.cn/models/weights/bert/wwm_uncased_L-24_H-1024_A-16/bert_e2780a6a_wwm_uncased_L-24_H-1024_A-16.pkl"
+    "https://data.megengine.org.cn/models/weights/bert/"
+    "wwm_uncased_L-24_H-1024_A-16/bert_e2780a6a_wwm_uncased_L-24_H-1024_A-16.pkl"
 )
 def wwm_uncased_L_24_H_1024_A_16():
     config_dict = {
@@ -748,7 +766,8 @@ def wwm_uncased_L_24_H_1024_A_16():


 @hub.pretrained(
-    "https://data.megengine.org.cn/models/weights/bert/wwm_cased_L-24_H-1024_A-16/bert_0a8f1389_wwm_cased_L-24_H-1024_A-16.pkl"
+    "https://data.megengine.org.cn/models/weights/bert/"
+    "wwm_cased_L-24_H-1024_A-16/bert_0a8f1389_wwm_cased_L-24_H-1024_A-16.pkl"
 )
 def wwm_cased_L_24_H_1024_A_16():
     config_dict = {

official/nlp/bert/mrpc_dataset.py
Lines changed: 7 additions & 4 deletions

@@ -9,14 +9,15 @@
 import csv
 import os

-import megengine as mge
+from tokenization import BertTokenizer
+
 import numpy as np
+
+import megengine as mge
 from megengine.data import DataLoader
 from megengine.data.dataset import ArrayDataset
 from megengine.data.sampler import RandomSampler, SequentialSampler

-from tokenization import BertTokenizer
-
 logger = mge.get_logger(__name__)


@@ -199,7 +200,9 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
             logger.info("tokens: {}".format(" ".join([str(x) for x in tokens])))
             logger.info("input_ids: {}".format(" ".join([str(x) for x in input_ids])))
             logger.info("input_mask: {}".format(" ".join([str(x) for x in input_mask])))
-            logger.info("segment_ids: {}".format(" ".join([str(x) for x in segment_ids])))
+            logger.info(
+                "segment_ids: {}".format(" ".join([str(x) for x in segment_ids]))
+            )
             logger.info("label: {} (id = {})".format(example.label, label_id))

             features.append(

official/nlp/bert/test.py
Lines changed: 7 additions & 6 deletions

@@ -7,20 +7,21 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

-import megengine as mge
-import megengine.functional as F
-from megengine.jit import trace
 from tqdm import tqdm

-from model import BertForSequenceClassification, create_hub_bert
-from mrpc_dataset import MRPCDataset
 # pylint: disable=import-outside-toplevel
 import config_args
+from mrpc_dataset import MRPCDataset
+
+import megengine as mge
+import megengine.functional as F
+
+from official.nlp.bert.model import BertForSequenceClassification, create_hub_bert
+
 args = config_args.get_args()
 logger = mge.get_logger(__name__)


-# @trace(symbolic=True)
 def net_eval(input_ids, segment_ids, input_mask, label_ids, net=None):
     net.eval()
     results = net(input_ids, segment_ids, input_mask, label_ids)
