Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit 3922d06

Browse files
eric-haibin-linszha
authored and committed
[model] Roberta converted weights (#870)
* +roberta * fix vocab * remove self attention * add model store * add test * add doc * fix doc * fix tset * fix lint * separate class for roberta * fix lint * fix doc
1 parent da936e0 commit 3922d06

File tree

10 files changed

+659
-26
lines changed

10 files changed

+659
-26
lines changed
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
# coding: utf-8
2+
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# 'License'); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
# pylint:disable=redefined-outer-name,logging-format-interpolation
20+
""" Script for converting Fairseq Roberta Model to Gluon. """
21+
import argparse
22+
import logging
23+
import os
24+
import sys
25+
import io
26+
import numpy as np
27+
28+
import torch
29+
from fairseq.models.roberta import RobertaModel
30+
31+
import mxnet as mx
32+
import gluonnlp as nlp
33+
from gluonnlp.model import BERTEncoder, BERTModel
34+
from gluonnlp.model.bert import bert_hparams
35+
from gluonnlp.data.utils import _load_pretrained_vocab
36+
37+
from utils import get_hash, load_text_vocab, tf_vocab_to_gluon_vocab
38+
39+
# Command-line interface for the conversion script.
parser = argparse.ArgumentParser(description='Conversion script for Fairseq RoBERTa model',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--ckpt_dir', type=str, help='Full path to the roberta folder',
                    default='/home/ubuntu/roberta/roberta.base')
parser.add_argument('--model', type=str, help='Model type. ',
                    choices=['roberta_12_768_12', 'roberta_24_1024_16'],
                    default='roberta_12_768_12')
parser.add_argument('--verbose', action='store_true', help='Verbose logging')

args = parser.parse_args()

ckpt_dir = os.path.expanduser(args.ckpt_dir)

# Load the raw PyTorch checkpoint; its 'model' entry is the parameter state dict.
ckpt = torch.load(os.path.join(ckpt_dir, 'model.pt'))
pytorch_params = ckpt['model']

if args.verbose:
    print(ckpt['args'])
    for k, v in pytorch_params.items():
        print(k, v.shape)

# Load the model in fairseq and switch to eval mode so dropout is disabled
# and the numerical comparison further below is deterministic.
roberta = RobertaModel.from_pretrained(ckpt_dir)
roberta.eval()
64+
def fairseq_vocab_to_gluon_vocab(torch_vocab):
    """Convert a fairseq dictionary into an equivalent GluonNLP ``Vocab``.

    Fairseq places the four special tokens (bos, pad, eos, unk) at the front
    of the vocabulary. The remaining entries of ``dict.txt`` are either
    literal tokens or numeric strings that index into the pretrained
    'openai_webtext' vocabulary; numeric entries are resolved against that
    vocabulary so the resulting GluonNLP vocab assigns every token the same
    index as fairseq does. Relies on the module-level ``ckpt_dir``.

    Parameters
    ----------
    torch_vocab : fairseq dictionary
        The dictionary attached to the loaded RoBERTa task
        (``roberta.task.dictionary``).

    Returns
    -------
    nlp.vocab.Vocab
        Vocab with matching indices and ``<mask>`` registered as mask token.
    """
    index_to_words = [None] * len(torch_vocab)

    bos_idx = torch_vocab.bos()
    pad_idx = torch_vocab.pad()
    eos_idx = torch_vocab.eos()
    unk_idx = torch_vocab.unk()

    index_to_words[bos_idx] = torch_vocab.symbols[bos_idx]
    index_to_words[pad_idx] = torch_vocab.symbols[pad_idx]
    index_to_words[eos_idx] = torch_vocab.symbols[eos_idx]
    index_to_words[unk_idx] = torch_vocab.symbols[unk_idx]

    specials = [bos_idx, pad_idx, eos_idx, unk_idx]

    openai_to_roberta = {}
    openai_vocab = _load_pretrained_vocab('openai_webtext', '.')

    with io.open(os.path.join(ckpt_dir, 'dict.txt'), encoding='utf-8') as f:
        for i, line in enumerate(f):
            # dict.txt lines are '<token> <count>'; only the token matters here.
            token = line.split(' ')[0]
            try:
                int(token)  # numeric entries index into the openai webtext vocab
                openai_to_roberta[token] = i + len(specials)
            except ValueError:
                # literal (non-numeric) token: keep it verbatim
                index_to_words[i + len(specials)] = token

    for idx, token in enumerate(openai_vocab.idx_to_token):
        if str(idx) in openai_to_roberta:
            index_to_words[openai_to_roberta[str(idx)]] = token
        else:
            # every openai token must be mapped, except <mask> which fairseq
            # appends separately
            assert token == u'<mask>', token

    mask_idx = torch_vocab.index(u'<mask>')
    index_to_words[mask_idx] = torch_vocab.string([mask_idx])
    assert None not in index_to_words
    word2idx = {token: idx for idx, token in enumerate(index_to_words)}

    vocab = nlp.vocab.Vocab(word2idx, token_to_idx=word2idx,
                            unknown_token=index_to_words[unk_idx],
                            padding_token=index_to_words[pad_idx],
                            bos_token=index_to_words[bos_idx],
                            eos_token=index_to_words[eos_idx],
                            mask_token=u'<mask>')
    return vocab
111+
112+
vocab = fairseq_vocab_to_gluon_vocab(roberta.task.dictionary)

predefined_args = bert_hparams[args.model]

# BERT encoder built from the predefined hyperparameters of the target model.
encoder = BERTEncoder(attention_cell=predefined_args['attention_cell'],
                      num_layers=predefined_args['num_layers'], units=predefined_args['units'],
                      hidden_size=predefined_args['hidden_size'],
                      max_length=predefined_args['max_length'],
                      num_heads=predefined_args['num_heads'], scaled=predefined_args['scaled'],
                      dropout=predefined_args['dropout'],
                      use_residual=predefined_args['use_residual'],
                      layer_norm_eps=predefined_args['layer_norm_eps'])

# BERT model; RoBERTa uses no pooler, no token-type embedding and no classifier.
bert = BERTModel(encoder, len(vocab),
                 units=predefined_args['units'], embed_size=predefined_args['embed_size'],
                 embed_dropout=predefined_args['embed_dropout'],
                 word_embed=predefined_args['word_embed'], use_pooler=False,
                 use_token_type_embed=False, use_classifier=False)

bert.initialize(init=mx.init.Normal(0.02))

# One dummy forward pass so deferred parameter shapes are materialized.
ones = mx.nd.ones((2, 8))
out = bert(ones, None, mx.nd.array([5, 6]), mx.nd.array([[1], [2]]))
params = bert._collect_params_with_prefix()


# Gluon parameter-name fragment -> fairseq parameter-name fragment.
mapping = {
    'decoder.2' : 'decoder.lm_head.layer_norm',
    'decoder.0' : 'decoder.lm_head.dense',
    'decoder.3' : 'decoder.lm_head',
    'encoder.layer_norm' : 'decoder.sentence_encoder.emb_layer_norm',
    'encoder.position_weight' : 'decoder.sentence_encoder.embed_positions.weight',
    'encoder.transformer_cells': 'decoder.sentence_encoder.layers',
    'attention_cell.proj_key.' : 'self_attn.in_proj_',
    'attention_cell.proj_value.' : 'self_attn.in_proj_',
    'attention_cell.proj_query.' : 'self_attn.in_proj_',
    'ffn.ffn_1' : 'fc1',
    'ffn.ffn_2' : 'fc2',
    'layer_norm.gamma' : 'layer_norm.weight',
    'layer_norm.beta' : 'layer_norm.bias',
    'ffn.layer_norm' : 'final_layer_norm',
    'word_embed.0.weight' : 'decoder.sentence_encoder.embed_tokens.weight',
}

# Per-layer renames; 24 covers the largest (roberta_24_1024_16) model.
for layer in range(24):
    mapping['{}.layer_norm'.format(layer)] = '{}.self_attn_layer_norm'.format(layer)
    mapping['{}.proj'.format(layer)] = '{}.self_attn.out_proj'.format(layer)

# set parameter data
loaded_params = {}
visited_pytorch_params = {}
for name in params:
    # Translate the Gluon parameter name into the fairseq one.
    pytorch_name = name
    for source, dest in mapping.items():
        pytorch_name = pytorch_name.replace(source, dest)

    assert pytorch_name in pytorch_params.keys(), 'Key ' + pytorch_name + ' for ' + name + ' not found.'
    torch_arr = pytorch_params[pytorch_name].cpu()
    # fairseq positional embedding starts with index 2
    if pytorch_name == 'decoder.sentence_encoder.embed_positions.weight':
        torch_arr = torch_arr[2:]

    arr = mx.nd.array(torch_arr)
    if 'attention_cell.proj' in name:
        # fairseq fuses Q/K/V into in_proj_*; take the matching third.
        projections = arr.split(num_outputs=3, axis=0)
        for part_idx, part in enumerate(['query', 'key', 'value']):
            if part in name:
                arr = projections[part_idx]
    else:
        assert arr.shape == params[name].shape, (arr.shape, params[name].shape, name, pytorch_name)
    params[name].set_data(arr)
    loaded_params[name] = True
    visited_pytorch_params[pytorch_name] = True

# Every Gluon parameter must have been set, and every fairseq parameter used.
assert len(params) == len(loaded_params)
assert len(visited_pytorch_params) == len(pytorch_params), "Gluon model does not match PyTorch model. " \
    "Please fix the BERTModel hyperparameters\n" + str(len(visited_pytorch_params)) + ' v.s. ' + str(len(pytorch_params))


# Sanity check: tokenize one sentence with both toolkits and compare token ids
# as well as the encoder outputs numerically.
texts = 'Hello world. abc, def and 中文!'
torch_tokens = roberta.encode(texts)

torch_features = roberta.extract_features(torch_tokens)
pytorch_out = torch_features.detach().numpy()

mx_tokenizer = nlp.data.GPT2BPETokenizer()
mx_tokens = [vocab.bos_token] + mx_tokenizer(texts) + [vocab.eos_token]
mx_data = vocab[mx_tokens]
print(mx_tokens)
print(vocab[mx_tokens])
print(torch_tokens)
assert mx_data == torch_tokens.tolist()

mx_out = bert(mx.nd.array([mx_data]))
print('stdev = ', np.std(mx_out.asnumpy() - pytorch_out))
mx.test_utils.assert_almost_equal(mx_out.asnumpy(), pytorch_out, atol=1e-3, rtol=1e-3)
mx.test_utils.assert_almost_equal(mx_out.asnumpy(), pytorch_out, atol=5e-6, rtol=5e-6)

# Persist the converted weights and the matching vocabulary.
bert.save_parameters(os.path.join(ckpt_dir, args.model + '.params'))
with io.open(os.path.join(ckpt_dir, args.model + '.vocab'), 'w', encoding='utf-8') as f:
    f.write(vocab.to_json())
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# coding: utf-8
2+
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
"""Utility functions for BERT."""
20+
21+
import logging
22+
import collections
23+
import hashlib
24+
import io
25+
26+
import mxnet as mx
27+
import gluonnlp as nlp
28+
29+
__all__ = ['tf_vocab_to_gluon_vocab', 'load_text_vocab']
30+
31+
32+
def tf_vocab_to_gluon_vocab(tf_vocab):
    """Build a GluonNLP BERTVocab from a TF-style token -> index mapping.

    Asserts that all BERT special tokens are present before constructing
    the vocabulary with the original index assignment preserved.
    """
    for required in ['[UNK]', '[PAD]', '[SEP]', '[MASK]', '[CLS]']:
        assert required in tf_vocab
    token_counter = nlp.data.count_tokens(tf_vocab.keys())
    return nlp.vocab.BERTVocab(token_counter, token_to_idx=tf_vocab)
38+
39+
40+
def get_hash(filename):
    """Compute the SHA-1 digest of a file.

    The file is read in 1 MiB chunks so arbitrarily large files can be
    hashed with constant memory.

    Parameters
    ----------
    filename : str
        Path of the file to hash.

    Returns
    -------
    tuple of (str, str)
        The full 40-character hex digest and its first 8 characters
        (the short hash used for model file naming).
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        # iter() with a b'' sentinel replaces the manual while/break loop.
        for chunk in iter(lambda: f.read(1048576), b''):
            sha1.update(chunk)
    # hexdigest() already returns str; compute it once instead of twice.
    digest = sha1.hexdigest()
    return digest, digest[:8]
49+
50+
51+
def read_tf_checkpoint(path):
    """read tensorflow checkpoint"""
    # Imported lazily so the module can be used without TensorFlow installed.
    from tensorflow.python import pywrap_tensorflow
    reader = pywrap_tensorflow.NewCheckpointReader(path)
    # Sorted keys give a deterministic insertion order in the result dict.
    return {key: reader.get_tensor(key)
            for key in sorted(reader.get_variable_to_shape_map())}
61+
62+
def load_text_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary.

    Parameters
    ----------
    vocab_file : str
        Path to a text file with one token per line.

    Returns
    -------
    collections.OrderedDict
        Mapping from token (whitespace-stripped) to its 0-based line index.
    """
    vocab = collections.OrderedDict()
    # BERT vocab files are UTF-8 encoded; pass the encoding explicitly
    # instead of relying on the platform default.
    with io.open(vocab_file, 'r', encoding='utf-8') as reader:
        # Iterating the file replaces the manual readline()/break loop;
        # readline() only returns '' at EOF, so behavior is identical.
        for index, token in enumerate(reader):
            vocab[token.strip()] = index
    return vocab

scripts/bert/index.rst

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,31 @@ The following pre-trained BERT models are available from the **gluonnlp.model.ge
4747

4848
where **bert_12_768_12** refers to the BERT BASE model, and **bert_24_1024_16** refers to the BERT LARGE model.
4949

50+
.. code-block:: python
51+
52+
import gluonnlp as nlp; import mxnet as mx;
53+
model, vocab = nlp.model.get_model('bert_12_768_12', dataset_name='book_corpus_wiki_en_uncased', use_classifier=False);
54+
tokenizer = nlp.data.BERTTokenizer(vocab, lower=True);
55+
transform = nlp.data.BERTSentenceTransform(tokenizer, max_seq_length=512, pair=False, pad=False);
56+
sample = transform(['Hello world!']);
57+
words, valid_len, segments = mx.nd.array([sample[0]]), mx.nd.array([sample[1]]), mx.nd.array([sample[2]]);
58+
seq_encoding, cls_encoding = model(words, segments, valid_len);
59+
60+
Additionally, GluonNLP supports the "`RoBERTa <https://arxiv.org/abs/1907.11692>`_" model:
61+
62+
+-----------------------------------------+-------------------+--------------------+
63+
| | roberta_12_768_12 | roberta_24_1024_16 |
64+
+=========================================+===================+====================+
65+
| openwebtext_ccnews_stories_books_cased |||
66+
+-----------------------------------------+-------------------+--------------------+
67+
68+
.. code-block:: python
69+
70+
import gluonnlp as nlp; import mxnet as mx;
71+
model, vocab = nlp.model.get_model('roberta_12_768_12', dataset_name='openwebtext_ccnews_stories_books_cased');
72+
tokenizer = nlp.data.GPT2BPETokenizer();
73+
text = [vocab.bos_token] + tokenizer('Hello world!') + [vocab.eos_token];
74+
seq_encoding = model(mx.nd.array([vocab[text]]))
5075
5176
.. hint::
5277

scripts/bert/pretraining_utils.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
__all__ = ['get_model_loss', 'get_pretrain_data_npz', 'get_dummy_dataloader',
4141
'save_parameters', 'save_states', 'evaluate', 'forward', 'split_and_load',
42-
'get_argparser', 'get_pretrain_data_text', 'generate_dev_set']
42+
'get_argparser', 'get_pretrain_data_text', 'generate_dev_set', 'profile']
4343

4444
def get_model_loss(ctx, model, pretrained, dataset_name, vocab, dtype,
4545
ckpt_dir=None, start_step=None):
@@ -505,3 +505,20 @@ def generate_dev_set(tokenizer, vocab, cache_file, args):
505505
1, args.num_data_workers,
506506
worker_pool, cache_file))
507507
logging.info('Done generating validation set on rank 0.')
508+
509+
def profile(curr_step, start_step, end_step, profile_name='profile.json',
            early_exit=True):
    """profile the program between [start_step, end_step)."""
    if curr_step != start_step and curr_step != end_step:
        return
    # Drain all pending async operations so the profiler window is clean.
    mx.nd.waitall()
    if curr_step == start_step:
        mx.profiler.set_config(profile_memory=False, profile_symbolic=True,
                               profile_imperative=True, filename=profile_name,
                               aggregate_stats=True)
        mx.profiler.set_state('run')
    else:  # curr_step == end_step
        mx.profiler.set_state('stop')
        logging.info(mx.profiler.dumps())
        mx.profiler.dump()
        if early_exit:
            exit()

scripts/bert/run_pretraining.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,10 @@
3737
import mxnet as mx
3838
import gluonnlp as nlp
3939

40-
from utils import profile
4140
from fp16_utils import FP16Trainer
4241
from pretraining_utils import get_model_loss, get_pretrain_data_npz, get_dummy_dataloader
4342
from pretraining_utils import log, evaluate, forward, split_and_load, get_argparser
44-
from pretraining_utils import save_parameters, save_states
43+
from pretraining_utils import save_parameters, save_states, profile
4544

4645
# arg parser
4746
parser = get_argparser()

scripts/bert/run_pretraining_hvd.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,10 @@
4040
import mxnet as mx
4141
import gluonnlp as nlp
4242

43-
from utils import profile
4443
from fp16_utils import FP16Trainer
4544
from pretraining_utils import get_model_loss, get_pretrain_data_npz, get_dummy_dataloader
4645
from pretraining_utils import split_and_load, log, evaluate, forward, get_argparser
47-
from pretraining_utils import save_parameters, save_states
46+
from pretraining_utils import save_parameters, save_states, profile
4847
from pretraining_utils import get_pretrain_data_text, generate_dev_set
4948

5049
# parser

src/gluonnlp/data/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,8 @@ def _slice_pad_length(num_items, length, overlap=0):
225225
'book_corpus_wiki_en_uncased': 'a66073971aa0b1a262453fe51342e57166a8abcf',
226226
'openwebtext_book_corpus_wiki_en_uncased':
227227
'a66073971aa0b1a262453fe51342e57166a8abcf',
228+
'openwebtext_ccnews_stories_books_cased':
229+
'2b804f8f90f9f93c07994b703ce508725061cf43',
228230
'wiki_multilingual_cased': '0247cb442074237c38c62021f36b7a4dbd2e55f7',
229231
'wiki_cn_cased': 'ddebd8f3867bca5a61023f73326fb125cf12b4f5',
230232
'wiki_multilingual_uncased': '2b2514cc539047b9179e9d98a4e68c36db05c97a',

src/gluonnlp/model/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ def get_model(name, **kwargs):
142142
'transformer_en_de_512': transformer_en_de_512,
143143
'bert_12_768_12' : bert_12_768_12,
144144
'bert_24_1024_16' : bert_24_1024_16,
145+
'roberta_12_768_12' : roberta_12_768_12,
146+
'roberta_24_1024_16' : roberta_24_1024_16,
145147
'ernie_12_768_12' : ernie_12_768_12}
146148
name = name.lower()
147149
if name not in models:

0 commit comments

Comments
 (0)