
Commit 9a33392

New models and refactoring (#276)
* Add `CodeLlamaTokenizer`
* Add `codellama` for testing
* Update default quantization settings
* Refactor `PretrainedModel`
* Remove unnecessary error message
* Update llama-code-tokenizer test
* Add support for `GPTNeoX` models
* Fix `GPTNeoXPreTrainedModel` config
* Add support for `GPTJ` models
* Add support for `WavLM` models
* Update list of supported models
  - CodeLlama
  - GPT NeoX
  - GPT-J
  - WavLM
* Add support for XLM models
* Add support for `ResNet` models
* Add support for `BeiT` models
* Fix casing of `BeitModel`
* Remove duplicate code
* Update variable name
* Remove `ts-ignore`
* Remove unnecessary duplication
* Update demo model sizes
* [demo] Update default summarization parameters
* Update default quantization parameters for new models
* Remove duplication in mapping
* Update list of supported marian models
* Add support for `CamemBERT` models
* Add support for `MBart` models
* Add support for `OPT` models
* Add `MBartTokenizer` and `MBart50Tokenizer`
* Add example of multilingual translation with MBart models
* Add `CamembertTokenizer`
* Add support for `HerBERT` models
* Add support for `XLMTokenizer`
* Fix `fuse_unk` config
* Do not remove duplicate keys for `Unigram` models (see https://huggingface.co/camembert-base for an example of a Unigram tokenizer that has two tokens with the same value, `<unk>`)
* Update HerBERT supported model text
* Update generate_tests.py
* Update list of supported models
* Use enum object instead of classes for model types (fixes #283)
* Add link to issue
* Update dependencies for unit tests
* Add `sentencepiece` as a testing requirement
* Add `protobuf` to test dependency
* Remove duplicated models to test
1 parent 109a7f9 commit 9a33392
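
The commit message above mentions adding an example of multilingual translation with MBart models. As a rough sketch of what such usage could look like through the library's `pipeline` API (the converted checkpoint id `Xenova/mbart-large-50-many-to-many-mmt` and the `src_lang`/`tgt_lang` option names are assumptions for illustration, not taken from this diff):

```js
// Sketch only: the model id and language-code options below are assumptions,
// not confirmed by this commit.
import { pipeline } from '@xenova/transformers';

// Load a many-to-many multilingual MBart translation model.
const translator = await pipeline(
    'translation',
    'Xenova/mbart-large-50-many-to-many-mmt',
);

// Translate English to French by naming the source and target languages.
const output = await translator('Hello, how are you?', {
    src_lang: 'en_XX',
    tgt_lang: 'fr_XX',
});

console.log(output);
// e.g. [{ translation_text: 'Bonjour, comment allez-vous ?' }]
```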

File tree

14 files changed (+1145 / -959 lines)


.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ jobs:
           node-version: ${{ matrix.node-version }}
       - run: npm ci
       - run: npm run build
-      - run: pip install -r scripts/requirements.txt
+      - run: pip install -r tests/requirements.txt
 
       # Setup the testing environment
       - run: npm run generate-tests

README.md

Lines changed: 12 additions & 0 deletions
Large diffs are not rendered by default.

docs/snippets/6_supported-models.snippet

Lines changed: 12 additions & 0 deletions
Large diffs are not rendered by default.

examples/demo-site/src/index.html

Lines changed: 7 additions & 7 deletions
@@ -87,16 +87,16 @@ <h2 class="fw-bolder">Demo</h2>
     <div class="col-12 mt-1">
         <select id="task" class="form-select">
             <option value="translation" selected>
-                Translation w/ t5-small (95 MB)
+                Translation w/ t5-small (78 MB)
             </option>
             <option value="text-generation">
-                Text generation w/ distilgpt2 (122 MB)
+                Text generation w/ distilgpt2 (85 MB)
             </option>
             <option value="masked-language-modelling">
-                Masked language modelling w/ bert-base-cased (132 MB)
+                Masked language modelling w/ bert-base-cased (110 MB)
             </option>
             <option value="sequence-classification">
-                Text classification w/ bert-base-multilingual-uncased-sentiment (168 MB)
+                Text classification w/ bert-base-multilingual-uncased-sentiment (169 MB)
             </option>
             <option value="token-classification">
                 Token classification w/ Davlan/bert-base-multilingual-cased-ner-hrl (178 MB)
@@ -108,16 +108,16 @@ <h2 class="fw-bolder">Demo</h2>
                 Question answering w/ distilbert-base-uncased-distilled-squad (66 MB)
             </option>
             <option value="summarization">
-                Summarization w/ distilbart-cnn-6-6 (335 MB)
+                Summarization w/ distilbart-cnn-6-6 (284 MB)
             </option>
             <option value="code-completion">
                 Code completion w/ Salesforce/codegen-350M-mono (369 MB)
             </option>
             <option value="automatic-speech-recognition">
-                Speech to text w/ whisper-tiny.en (61 MB)
+                Speech to text w/ whisper-tiny.en (41 MB)
             </option>
             <option value="image-to-text">
-                Image to text w/ vit-gpt2-image-captioning (283 MB)
+                Image to text w/ vit-gpt2-image-captioning (246 MB)
             </option>
             <option value="image-classification">
                 Image classification w/ google/vit-base-patch16-224 (88 MB)

examples/demo-site/src/main.js

Lines changed: 7 additions & 1 deletion
@@ -124,7 +124,13 @@ const TASK_DEFAULT_PARAMS = {
         multi_label: false
     },
     'question-answering': {},
-    'summarization': DEFAULT_GREEDY_PARAMS,
+    'summarization': {
+        max_new_tokens: 50,
+        num_beams: 2,
+        temperature: 1,
+        top_k: 0,
+        do_sample: false
+    },
     'automatic-speech-recognition': DEFAULT_GREEDY_PARAMS,
     'image-to-text': DEFAULT_GREEDY_PARAMS,
     'image-classification': {},
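
The summarization task now uses explicit defaults (a 2-beam search capped at 50 new tokens) instead of the shared `DEFAULT_GREEDY_PARAMS`. As a sketch of how these options would be passed in a direct pipeline call (the model id `Xenova/distilbart-cnn-6-6` is an assumption for illustration, not taken from this diff):

```js
// Sketch: passing the new demo defaults straight to a summarization pipeline.
// The model id below is an assumption, not part of this commit.
import { pipeline } from '@xenova/transformers';

const summarizer = await pipeline('summarization', 'Xenova/distilbart-cnn-6-6');

const article = 'The tower is 324 metres tall, about the same height as an 81-storey building, ...';

const summary = await summarizer(article, {
    max_new_tokens: 50, // cap the length of the generated summary
    num_beams: 2,       // small beam search instead of greedy decoding
    temperature: 1,
    top_k: 0,
    do_sample: false,   // deterministic output
});

console.log(summary);
// e.g. [{ summary_text: '...' }]
```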

scripts/convert.py

Lines changed: 53 additions & 2 deletions
@@ -26,18 +26,62 @@
 }
 
 MODEL_SPECIFIC_QUANTIZE_PARAMS = {
-    'whisper': {
+    # Decoder-only models
+    'codegen': {
+        'per_channel': False,
+        'reduce_range': False,
+    },
+    'gpt2': {
+        'per_channel': False,
+        'reduce_range': False,
+    },
+    'gpt_bigcode': {
+        'per_channel': False,
+        'reduce_range': False,
+    },
+    'gptj': {
+        'per_channel': False,
+        'reduce_range': False,
+    },
+    'gpt-neo': {
+        'per_channel': False,
+        'reduce_range': False,
+    },
+    'gpt-neox': {
+        'per_channel': False,
+        'reduce_range': False,
+    },
+    'mpt': {
         'per_channel': False,
         'reduce_range': False,
     },
     'bloom': {
         'per_channel': False,
         'reduce_range': False,
+    },
+    'llama': {
+        'per_channel': False,
+        'reduce_range': False,
+    },
+    'opt': {
+        'per_channel': False,
+        'reduce_range': False,
+    },
+
+    # Encoder-decoder models
+    'whisper': {
+        'per_channel': False,
+        'reduce_range': False,
+    },
+    'vision-encoder-decoder': {
+        'per_channel': False,
+        'reduce_range': False,
     }
 }
 
 MODELS_WITHOUT_TOKENIZERS = [
-    'wav2vec2'
+    'wav2vec2',
+    'wavlm',
 ]
 
 
@@ -294,6 +338,13 @@ def main():
     quantize_config = MODEL_SPECIFIC_QUANTIZE_PARAMS.get(
         config.model_type, DEFAULT_QUANTIZE_PARAMS)
 
+    # Update if user specified values
+    if conv_args.per_channel is not None:
+        quantize_config['per_channel'] = conv_args.per_channel
+
+    if conv_args.reduce_range is not None:
+        quantize_config['reduce_range'] = conv_args.reduce_range
+
     quantize([
         os.path.join(output_model_folder, x)
         for x in os.listdir(output_model_folder)

scripts/extra/marian.py

Lines changed: 47 additions & 34 deletions
@@ -6,41 +6,54 @@
 # and make a pull request to this repo.
 
 SUPPORTED_HELSINKI_NLP_MODELS = [
-    'en-es', 'es-en', # English <-> Spanish
-    'en-fr', 'fr-en', # English <-> French
-    'en-hi', 'hi-en', # English <-> Hindi
-    'en-de', 'de-en', # English <-> German
-    'en-ru', 'ru-en', # English <-> Russian
-    'en-it', 'it-en', # English <-> Italian
-    'en-ar', 'ar-en', # English <-> Arabic
-    'en-zh', 'zh-en', # English <-> Chinese
-    'en-sv', 'sv-en', # English <-> Swedish
-    'en-mul', 'mul-en', # English <-> Multilingual
-    'en-nl', 'nl-en', # English <-> Dutch
-    'en-fi', 'fi-en', # English <-> Finnish
-    'en-jap', 'jap-en', # English <-> Japanese
-    'en-cs', 'cs-en', # English <-> Czech
-    'en-vi', 'vi-en', # English <-> Vietnamese
-    'en-xh', 'xh-en', # English <-> Xhosa
-    'en-hu', 'hu-en', # English <-> Hungarian
-    'en-da', 'da-en', # English <-> Danish
-    'en-id', 'id-en', # English <-> Indonesia
-    'en-uk', 'uk-en', # English <-> Ukranian
-    'en-af', 'af-en', # English <-> Afrikaans
-    'de-es', 'es-de', # German <-> Spanish
-    'fr-es', 'es-fr', # French <-> Spanish
-    'fr-de', 'de-fr', # French <-> German
-    'es-it', 'it-es', # Spanish <-> Italian
+    'en-es', 'es-en',            # English <-> Spanish
+    'en-fr', 'fr-en',            # English <-> French
+    'en-hi', 'hi-en',            # English <-> Hindi
+    'en-de', 'de-en',            # English <-> German
+    'en-ru', 'ru-en',            # English <-> Russian
+    'en-it', 'it-en',            # English <-> Italian
+    'en-ar', 'ar-en',            # English <-> Arabic
+    'en-zh', 'zh-en',            # English <-> Chinese
+    'en-sv', 'sv-en',            # English <-> Swedish
+    'en-mul', 'mul-en',          # English <-> Multilingual
+    'en-nl', 'nl-en',            # English <-> Dutch
+    'en-fi', 'fi-en',            # English <-> Finnish
+    'en-jap', 'jap-en',          # English <-> Japanese
+    'en-cs', 'cs-en',            # English <-> Czech
+    'en-vi', 'vi-en',            # English <-> Vietnamese
+    'en-xh', 'xh-en',            # English <-> Xhosa
+    'en-hu', 'hu-en',            # English <-> Hungarian
+    'en-da', 'da-en',            # English <-> Danish
+    'en-id', 'id-en',            # English <-> Indonesia
+    'en-uk', 'uk-en',            # English <-> Ukranian
+    'en-af', 'af-en',            # English <-> Afrikaans
+    'en-ROMANCE', 'ROMANCE-en',  # English <-> ROMANCE
+    'de-es', 'es-de',            # German <-> Spanish
+    'fr-es', 'es-fr',            # French <-> Spanish
+    'fr-de', 'de-fr',            # French <-> German
+    'es-it', 'it-es',            # Spanish <-> Italian
+    'es-ru', 'ru-es',            # Spanish <-> Russian
+    'fr-ru', 'ru-fr',            # French <-> Russian
+    'fr-ro', 'ro-fr',            # French <-> Romanian
+    'uk-ru', 'ru-uk',            # Ukranian <-> Russian
 
-    'en-ro', # English --> Romanian
-    'pl-en', # Poland --> English
-    'tr-en', # Turkey --> English
-    'ko-en', # Korean --> English
-
-    'es-ru', 'ru-es', # Spanish <-> Russian
-    'fr-ru', 'ru-fr', # French <-> Russian
-    'fr-ro', 'ro-fr', # French <-> Romanian
-    'uk-ru', 'ru-uk', # Ukranian <-> Russian
+    'it-fr',         # Italian --> French
+    'en-ro',         # English --> Romanian
+    'pl-en',         # Poland --> English
+    'tr-en',         # Turkey --> English
+    'ko-en',         # Korean --> English
+    'bat-en',        # Baltic --> English
+    'et-en',         # Estonian --> English
+    'fi-de',         # Finnish --> German
+    'gem-gem',       # Germanic <-> Germanic
+    'gmw-gmw',       # West Germanic <-> West Germanic
+    'da-de',         # Danish <-> German
+    'ja-en',         # Japanese --> English
+    'nl-fr',         # Netherlands --> French
+    'no-de',         # Norwegian --> German
+    'tc-big-tr-en',  # Turkish --> English
+    'th-en',         # Thai --> English
+    'en-cs',         # English --> Czech
 ]
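
Each entry in `SUPPORTED_HELSINKI_NLP_MODELS` is a language pair that corresponds to a `Helsinki-NLP/opus-mt-<pair>` checkpoint on the Hugging Face Hub. A minimal sketch of how one of the newly added pairs expands into model ids (the `Xenova/...` id for the converted model is an assumption, not part of this diff):

```js
// Sketch: expanding a supported pair into Hub model ids (assumed naming convention).
const pair = 'it-fr';                              // one of the newly added entries
const originalId = `Helsinki-NLP/opus-mt-${pair}`; // the source checkpoint to convert
const convertedId = `Xenova/opus-mt-${pair}`;      // assumed id of the converted model

console.log(originalId);  // "Helsinki-NLP/opus-mt-it-fr"
console.log(convertedId); // "Xenova/opus-mt-it-fr"
```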