|
| 1 | +# coding: utf-8 |
| 2 | + |
| 3 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 4 | +# or more contributor license agreements. See the NOTICE file |
| 5 | +# distributed with this work for additional information |
| 6 | +# regarding copyright ownership. The ASF licenses this file |
| 7 | +# to you under the Apache License, Version 2.0 (the |
| 8 | +# 'License'); you may not use this file except in compliance |
| 9 | +# with the License. You may obtain a copy of the License at |
| 10 | +# |
| 11 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 12 | +# |
| 13 | +# Unless required by applicable law or agreed to in writing, |
| 14 | +# software distributed under the License is distributed on an |
| 15 | +# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 16 | +# KIND, either express or implied. See the License for the |
| 17 | +# specific language governing permissions and limitations |
| 18 | +# under the License. |
| 19 | +# pylint:disable=redefined-outer-name,logging-format-interpolation |
| 20 | +""" Script for converting Fairseq Roberta Model to Gluon. """ |
| 21 | +import argparse |
| 22 | +import logging |
| 23 | +import os |
| 24 | +import sys |
| 25 | +import io |
| 26 | +import numpy as np |
| 27 | + |
| 28 | +import torch |
| 29 | +from fairseq.models.roberta import RobertaModel |
| 30 | + |
| 31 | +import mxnet as mx |
| 32 | +import gluonnlp as nlp |
| 33 | +from gluonnlp.model import BERTEncoder, BERTModel |
| 34 | +from gluonnlp.model.bert import bert_hparams |
| 35 | +from gluonnlp.data.utils import _load_pretrained_vocab |
| 36 | + |
| 37 | +from utils import get_hash, load_text_vocab, tf_vocab_to_gluon_vocab |
| 38 | + |
parser = argparse.ArgumentParser(description='Conversion script for Fairseq RoBERTa model',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--ckpt_dir', type=str, help='Full path to the roberta folder',
                    default='/home/ubuntu/roberta/roberta.base')
parser.add_argument('--model', type=str, help='Model type. ',
                    choices=['roberta_12_768_12', 'roberta_24_1024_16'],
                    default='roberta_12_768_12')
parser.add_argument('--verbose', action='store_true', help='Verbose logging')

args = parser.parse_args()

ckpt_dir = os.path.expanduser(args.ckpt_dir)

# Load the raw checkpoint on CPU so the conversion also works on machines
# without a GPU (fairseq checkpoints are typically saved from CUDA tensors,
# and torch.load would otherwise try to restore them onto a CUDA device).
ckpt = torch.load(os.path.join(ckpt_dir, 'model.pt'), map_location='cpu')
pytorch_params = ckpt['model']

if args.verbose:
    print(ckpt['args'])
    for k, v in pytorch_params.items():
        print(k, v.shape)

# Load the model in fairseq and switch to inference mode so dropout is
# disabled for the numerical comparison at the end of the script.
roberta = RobertaModel.from_pretrained(ckpt_dir)
roberta.eval()
| 63 | + |
def fairseq_vocab_to_gluon_vocab(torch_vocab):
    """Convert a fairseq ``Dictionary`` into a GluonNLP ``Vocab``.

    The fairseq dictionary stores most BPE tokens as *stringified indices*
    into the OpenAI GPT-2 vocabulary; this function resolves them back to
    their surface forms via the pretrained 'openai_webtext' vocabulary so
    the resulting Gluon vocab maps real tokens to the same ids fairseq uses.

    Parameters
    ----------
    torch_vocab : fairseq.data.Dictionary
        Dictionary attached to the loaded RoBERTa task.

    Returns
    -------
    nlp.vocab.Vocab
        Vocabulary with an identical token-to-index mapping, including the
        special ``<mask>`` token.
    """
    index_to_words = [None] * len(torch_vocab)

    bos_idx = torch_vocab.bos()
    pad_idx = torch_vocab.pad()
    eos_idx = torch_vocab.eos()
    unk_idx = torch_vocab.unk()

    index_to_words[bos_idx] = torch_vocab.symbols[bos_idx]
    index_to_words[pad_idx] = torch_vocab.symbols[pad_idx]
    index_to_words[eos_idx] = torch_vocab.symbols[eos_idx]
    index_to_words[unk_idx] = torch_vocab.symbols[unk_idx]

    specials = [bos_idx, pad_idx, eos_idx, unk_idx]

    openai_to_roberta = {}
    openai_vocab = _load_pretrained_vocab('openai_webtext', '.')

    # dict.txt lines are '<token> <count>'. Numeric tokens are references
    # into the OpenAI GPT-2 vocab; non-numeric entries (e.g. 'madeupword0000')
    # are literal tokens. The first len(specials) ids are the special symbols.
    with io.open(os.path.join(ckpt_dir, 'dict.txt'), encoding='utf-8') as f:
        for i, line in enumerate(f):
            token, _count = line.split(' ')
            try:
                int(token)  # raises ValueError for literal tokens
            except ValueError:
                index_to_words[i + len(specials)] = token
            else:
                openai_to_roberta[token] = i + len(specials)

    # Resolve the numeric references to real tokens. Only <mask> is absent
    # from the fairseq dictionary file; it is appended separately below.
    for idx, token in enumerate(openai_vocab.idx_to_token):
        if str(idx) in openai_to_roberta:
            index_to_words[openai_to_roberta[str(idx)]] = token
        else:
            assert token == u'<mask>', token

    mask_idx = torch_vocab.index(u'<mask>')
    index_to_words[mask_idx] = torch_vocab.string([mask_idx])
    assert None not in index_to_words
    word2idx = {}
    for idx, token in enumerate(index_to_words):
        word2idx[token] = idx

    vocab = nlp.vocab.Vocab(word2idx, token_to_idx=word2idx,
                            unknown_token=index_to_words[unk_idx],
                            padding_token=index_to_words[pad_idx],
                            bos_token=index_to_words[bos_idx],
                            eos_token=index_to_words[eos_idx],
                            mask_token=u'<mask>')
    return vocab
| 111 | + |
vocab = fairseq_vocab_to_gluon_vocab(roberta.task.dictionary)

# RoBERTa reuses the BERT hyper-parameter registry in GluonNLP; pick the
# architecture matching --model.
predefined_args = bert_hparams[args.model]

# BERT encoder
encoder = BERTEncoder(attention_cell=predefined_args['attention_cell'],
                      num_layers=predefined_args['num_layers'], units=predefined_args['units'],
                      hidden_size=predefined_args['hidden_size'],
                      max_length=predefined_args['max_length'],
                      num_heads=predefined_args['num_heads'], scaled=predefined_args['scaled'],
                      dropout=predefined_args['dropout'],
                      use_residual=predefined_args['use_residual'],
                      layer_norm_eps=predefined_args['layer_norm_eps'])

# BERT model. RoBERTa has no pooler, no next-sentence classifier and no
# token-type (segment) embedding, hence the three use_* flags are off.
bert = BERTModel(encoder, len(vocab),
                 units=predefined_args['units'], embed_size=predefined_args['embed_size'],
                 embed_dropout=predefined_args['embed_dropout'],
                 word_embed=predefined_args['word_embed'], use_pooler=False,
                 use_token_type_embed=False, use_classifier=False)

bert.initialize(init=mx.init.Normal(0.02))

# Run one dummy batch so parameter shapes are materialized before set_data
# below (MXNet defers parameter allocation until the first forward pass).
ones = mx.nd.ones((2, 8))
out = bert(ones, None, mx.nd.array([5, 6]), mx.nd.array([[1], [2]]))
params = bert._collect_params_with_prefix()
| 138 | + |
| 139 | + |
| 140 | + |
# Translation table from Gluon parameter-name fragments to the matching
# fairseq fragments. Insertion order is significant: the conversion loop
# applies every substitution sequentially, so specific fragments (e.g.
# 'encoder.layer_norm') must come before generic ones ('layer_norm.gamma').
mapping = {
    'decoder.2': 'decoder.lm_head.layer_norm',
    'decoder.0': 'decoder.lm_head.dense',
    'decoder.3': 'decoder.lm_head',
    'encoder.layer_norm': 'decoder.sentence_encoder.emb_layer_norm',
    'encoder.position_weight': 'decoder.sentence_encoder.embed_positions.weight',
    'encoder.transformer_cells': 'decoder.sentence_encoder.layers',
    'attention_cell.proj_key.': 'self_attn.in_proj_',
    'attention_cell.proj_value.': 'self_attn.in_proj_',
    'attention_cell.proj_query.': 'self_attn.in_proj_',
    'ffn.ffn_1': 'fc1',
    'ffn.ffn_2': 'fc2',
    'layer_norm.gamma': 'layer_norm.weight',
    'layer_norm.beta': 'layer_norm.bias',
    'ffn.layer_norm': 'final_layer_norm',
    'word_embed.0.weight': 'decoder.sentence_encoder.embed_tokens.weight',
}

# Per-layer fragments: each transformer layer contributes its attention
# layer norm and its attention output projection. 24 covers both the
# 12- and 24-layer models; fragments that never match are simply inert.
for layer_idx in range(24):
    for gluon_tpl, torch_tpl in (('{}.layer_norm', '{}.self_attn_layer_norm'),
                                 ('{}.proj', '{}.self_attn.out_proj')):
        mapping[gluon_tpl.format(layer_idx)] = torch_tpl.format(layer_idx)
| 162 | + |
# set parameter data
loaded_params = {}
visited_pytorch_params = {}
for name in params:
    # Translate the Gluon parameter name into the fairseq one by applying
    # every fragment substitution from `mapping` in insertion order.
    pytorch_name = name
    for source, dest in mapping.items():
        pytorch_name = pytorch_name.replace(source, dest)

    assert pytorch_name in pytorch_params.keys(), 'Key ' + pytorch_name + ' for ' + name + ' not found.'
    torch_arr = pytorch_params[pytorch_name].cpu()
    # fairseq positional embedding starts with index 2
    if pytorch_name == 'decoder.sentence_encoder.embed_positions.weight':
        torch_arr = torch_arr[2:]

    arr = mx.nd.array(torch_arr)
    if 'attention_cell.proj' in name:
        # fairseq stores Q/K/V as a single fused in_proj tensor; take the
        # third corresponding to this projection. Assumes query, key, value
        # order in the fused tensor -- TODO confirm against fairseq's
        # MultiheadAttention layout.
        unfused = ['query', 'key', 'value']
        arrs = arr.split(num_outputs=3, axis=0)
        for i, p in enumerate(unfused):
            if p in name:
                arr = arrs[i]
    else:
        assert arr.shape == params[name].shape, (arr.shape, params[name].shape, name, pytorch_name)
    params[name].set_data(arr)
    loaded_params[name] = True
    visited_pytorch_params[pytorch_name] = True

# Every Gluon parameter must have been filled and every fairseq parameter
# consumed -- otherwise the two architectures do not line up.
assert len(params) == len(loaded_params)
assert len(visited_pytorch_params) == len(pytorch_params), "Gluon model does not match PyTorch model. " \
    "Please fix the BERTModel hyperparameters\n" + str(len(visited_pytorch_params)) + ' v.s. ' + str(len(pytorch_params))
| 193 | + |
| 194 | + |
# Sanity check: push the same sentence (including non-ASCII characters)
# through both implementations and compare tokenization and features.
texts = 'Hello world. abc, def and 中文!'
torch_tokens = roberta.encode(texts)

torch_features = roberta.extract_features(torch_tokens)
pytorch_out = torch_features.detach().numpy()

mx_tokenizer = nlp.data.GPT2BPETokenizer()
mx_tokens = [vocab.bos_token] + mx_tokenizer(texts) + [vocab.eos_token]
mx_data = vocab[mx_tokens]
print(mx_tokens)
print(vocab[mx_tokens])
print(torch_tokens)
# The converted vocabulary must reproduce fairseq's token ids exactly.
assert mx_data == torch_tokens.tolist()

mx_out = bert(mx.nd.array([mx_data]))
print('stdev = ', np.std(mx_out.asnumpy() - pytorch_out))
# NOTE(review): the first, looser tolerance check is subsumed by the
# stricter 5e-6 check below; it only provides an earlier, coarser failure.
mx.test_utils.assert_almost_equal(mx_out.asnumpy(), pytorch_out, atol=1e-3, rtol=1e-3)
mx.test_utils.assert_almost_equal(mx_out.asnumpy(), pytorch_out, atol=5e-6, rtol=5e-6)

# Persist the converted parameters and vocabulary next to the checkpoint.
bert.save_parameters(os.path.join(ckpt_dir, args.model + '.params'))
with io.open(os.path.join(ckpt_dir, args.model + '.vocab'), 'w', encoding='utf-8') as f:
    f.write(vocab.to_json())
0 commit comments