diff --git a/python-lib/macro/model_configurations.py b/python-lib/macro/model_configurations.py
index 4e47f2d..fa02acb 100644
--- a/python-lib/macro/model_configurations.py
+++ b/python-lib/macro/model_configurations.py
@@ -1,3 +1,15 @@
+from transformers import (tokenization_bert,
+                          tokenization_gpt2,
+                          tokenization_transfo_xl,
+                          tokenization_xlnet,
+                          tokenization_roberta,
+                          tokenization_distilbert,
+                          tokenization_ctrl,
+                          tokenization_camembert,
+                          tokenization_albert,
+                          tokenization_t5,
+                          tokenization_bart)
+
 NON_TRANSFORMER_MODELS = ["word2vec","fasttext","glove","elmo","use"]
 TRANSFORMERS_MODELS = ['bert-base-uncased', 'bert-large-uncased', 'bert-base-cased', 'bert-large-cased', 'bert-base-multilingual-uncased', 'bert-base-multilingual-cased', 'bert-base-chinese', 'bert-base-german-cased', 'bert-large-uncased-whole-word-masking', 'bert-large-cased-whole-word-masking', 'bert-large-uncased-whole-word-masking-finetuned-squad', 'bert-large-cased-whole-word-masking-finetuned-squad', 'bert-base-cased-finetuned-mrpc', 'bert-base-german-dbmdz-cased', 'bert-base-german-dbmdz-uncased', 'cl-tohoku/bert-base-japanese', 'cl-tohoku/bert-base-japanese-whole-word-masking', 'cl-tohoku/bert-base-japanese-char', 'cl-tohoku/bert-base-japanese-char-whole-word-masking', 'TurkuNLP/bert-base-finnish-cased-v1', 'TurkuNLP/bert-base-finnish-uncased-v1', 'wietsedv/bert-base-dutch-cased', 'facebook/bart-large', 'facebook/bart-large-mnli', 'facebook/bart-large-cnn', 'facebook/mbart-large-en-ro', 'openai-gpt', 'transfo-xl-wt103', 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl', 'distilgpt2', 'ctrl', 'xlnet-base-cased', 'xlnet-large-cased', 'xlm-mlm-en-2048', 'xlm-mlm-ende-1024', 'xlm-mlm-enfr-1024', 'xlm-mlm-enro-1024', 'xlm-mlm-tlm-xnli15-1024', 'xlm-mlm-xnli15-1024', 'xlm-clm-enfr-1024', 'xlm-clm-ende-1024', 'xlm-mlm-17-1280', 'xlm-mlm-100-1280', 'roberta-base', 'roberta-large', 'roberta-large-mnli', 'distilroberta-base', 'roberta-base-openai-detector', 'roberta-large-openai-detector', 'distilbert-base-uncased', 'distilbert-base-uncased-distilled-squad', 'distilbert-base-cased', 'distilbert-base-cased-distilled-squad', 'distilbert-base-german-cased', 'distilbert-base-multilingual-cased', 'albert-base-v1', 'albert-large-v1', 'albert-xlarge-v1', 'albert-xxlarge-v1', 'albert-base-v2', 'albert-large-v2', 'albert-xlarge-v2', 'albert-xxlarge-v2', 'camembert-base', 't5-small', 't5-base', 't5-large', 'xlm-roberta-base', 'xlm-roberta-large', 'flaubert/flaubert_small_cased', 'flaubert/flaubert_base_uncased', 'flaubert/flaubert_base_cased', 'flaubert/flaubert_large_cased', 'allenai/longformer-base-4096', 'allenai/longformer-large-4096']
 MODEL_CONFIFURATIONS = {
@@ -1077,4 +1089,58 @@
 }
 
+BART_URL_MAP = {
+    'vocab_file': {model: tokenization_bart.vocab_url for model in tokenization_bart._all_bart_models},
+    'merges_file': {model: tokenization_bart.merges_url for model in tokenization_bart._all_bart_models}
+}
+BART_FILE_NAMES = tokenization_roberta.VOCAB_FILES_NAMES
+
+
+TOKENIZER_CONFIGURATIONS = {
+    'BERT': {
+        'url_map': tokenization_bert.PRETRAINED_VOCAB_FILES_MAP,
+        'file_names': tokenization_bert.VOCAB_FILES_NAMES
+    },
+    'GPT-2': {
+        'url_map': tokenization_gpt2.PRETRAINED_VOCAB_FILES_MAP,
+        'file_names': tokenization_gpt2.VOCAB_FILES_NAMES
+    },
+    'CamemBERT': {
+        'url_map': tokenization_camembert.PRETRAINED_VOCAB_FILES_MAP,
+        'file_names': tokenization_camembert.VOCAB_FILES_NAMES
+    },
+    'ALBERT': {
+        'url_map': tokenization_albert.PRETRAINED_VOCAB_FILES_MAP,
+        'file_names': tokenization_albert.VOCAB_FILES_NAMES
+    },
+    'CTRL': {
+        'url_map': tokenization_ctrl.PRETRAINED_VOCAB_FILES_MAP,
+        'file_names': tokenization_ctrl.VOCAB_FILES_NAMES
+    },
+    'DistilBERT': {
+        'url_map': tokenization_distilbert.PRETRAINED_VOCAB_FILES_MAP,
+        'file_names': tokenization_distilbert.VOCAB_FILES_NAMES
+    },
+    'RoBERTa': {
+        'url_map': tokenization_roberta.PRETRAINED_VOCAB_FILES_MAP,
+        'file_names': tokenization_roberta.VOCAB_FILES_NAMES
+    },
+    'Bart': {
+        'url_map': BART_URL_MAP,
+        'file_names': BART_FILE_NAMES
+    },
+    'T5': {
+        'url_map': tokenization_t5.PRETRAINED_VOCAB_FILES_MAP,
+        'file_names': tokenization_t5.VOCAB_FILES_NAMES
+    },
+    'Transformer-XL': {
+        'url_map': tokenization_transfo_xl.PRETRAINED_VOCAB_FILES_MAP,
+        'file_names': tokenization_transfo_xl.VOCAB_FILES_NAMES
+    },
+    'XLNet': {
+        'url_map': tokenization_xlnet.PRETRAINED_VOCAB_FILES_MAP,
+        'file_names': tokenization_xlnet.VOCAB_FILES_NAMES
+    },
+}
+
diff --git a/python-lib/macro/model_downloaders.py b/python-lib/macro/model_downloaders.py
index 927f359..f4ec298 100644
--- a/python-lib/macro/model_downloaders.py
+++ b/python-lib/macro/model_downloaders.py
@@ -8,7 +8,7 @@
 import io
 import zipfile
 import logging
-from macro.model_configurations import MODEL_CONFIFURATIONS
+from macro.model_configurations import MODEL_CONFIFURATIONS, TOKENIZER_CONFIGURATIONS
 import time
 from transformers.file_utils import (S3_BUCKET_PREFIX,
                                      CLOUDFRONT_DISTRIB_PREFIX,
@@ -21,7 +21,7 @@
 WORD2VEC_BASE_URL = "http://vectors.nlpl.eu/repository/20/{}.zip"
 FASTTEXT_BASE_URL = "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.{}.300.vec.gz"
 
-HG_FILENAMES = ["pytorch_model.bin","config.json","vocab.txt"]
+HG_FILENAMES = ["pytorch_model.bin", "config.json"]
 
 class BaseDownloader(object):
     def __init__(self,folder,macro_inputs,proxy,progress_callback):
@@ -293,6 +293,16 @@ def run(self):
                 bytes_so_far = self.download_plain(response, bytes_so_far)
             elif response.status_code == 404:
                 pass
+        tokenizer_dict = TOKENIZER_CONFIGURATIONS[self.embedding_family]
+        for file_type in tokenizer_dict['url_map'].keys():
+            self.archive_name = self.language + '/' + self.embedding_family + '/' + self.model_shortcut_name.replace("/","_") + '/' + tokenizer_dict['file_names'][file_type]
+            download_link = tokenizer_dict['url_map'][file_type][self.model_shortcut_name]
+            response = self.get_stream(download_link)
+            if response.status_code == 200:
+                bytes_so_far = self.download_plain(response, bytes_so_far)
+            elif response.status_code == 404:
+                pass
+
 
     def get_file_size(self, response=None):
         total_size = 0
@@ -303,6 +313,14 @@ def get_file_size(self, response=None):
                 total_size += int(response.headers.get('content-length'))
             elif response.status_code == 404:
                 total_size += 0
+        tokenizer_dict = TOKENIZER_CONFIGURATIONS[self.embedding_family]
+        for file_type in tokenizer_dict['url_map'].keys():
+            download_link = tokenizer_dict['url_map'][file_type][self.model_shortcut_name]
+            response = self.get_stream(download_link)
+            if response.status_code == 200:
+                total_size += int(response.headers.get('content-length'))
+            elif response.status_code == 404:
+                total_size += 0
         return total_size
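
For reference, the following is a minimal sketch (not part of the diff) of how the new TOKENIZER_CONFIGURATIONS mapping can be consumed to resolve the tokenizer files for one pretrained shortcut name, mirroring the lookup logic added to the downloader above. The helper iter_tokenizer_downloads and the example shortcut name are illustrative assumptions, not code from this repository.

# Minimal sketch, assuming the structure added in model_configurations.py:
# 'url_map' maps a file type to {shortcut_name: url}, and 'file_names' maps the
# same file type to the local file name expected by the tokenizer.
from macro.model_configurations import TOKENIZER_CONFIGURATIONS

def iter_tokenizer_downloads(embedding_family, model_shortcut_name):
    """Yield (local_file_name, download_url) pairs for a model's tokenizer files."""
    tokenizer_dict = TOKENIZER_CONFIGURATIONS[embedding_family]
    for file_type, url_map in tokenizer_dict['url_map'].items():
        # Skip file types this particular pretrained model does not provide.
        if model_shortcut_name in url_map:
            yield tokenizer_dict['file_names'][file_type], url_map[model_shortcut_name]

# Hypothetical usage: list the tokenizer files the downloader would fetch for GPT-2.
for file_name, url in iter_tokenizer_downloads('GPT-2', 'gpt2'):
    print(file_name, "<-", url)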