
Commit 1e157ba

xenova and Aschen authored
Add support for Deberta models (#244)
* add documentation for zero shot classification
* add multi_label example
* review comments
* edit examples data
* Add deberta and deberta-v2 model definitions
* Update model mapping
* Implement missing `Strip` normalizer
* Add deberta and deberta-v2 tokenizers
* Add fast path to `Strip` normalizer
* Add token types to deberta tokenizer output
* Update supported_models.py
* Fix default Precompiled normalization
* Update supported models list
* Update JSDoc
* Support `not_entailment` label
* Update multi-label example JSDoc

---------

Co-authored-by: Aschen <[email protected]>
1 parent db7d0f0 commit 1e157ba
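
The `Strip` normalizer work listed above ("Implement missing `Strip` normalizer", "Add fast path to `Strip` normalizer") lands in files not expanded in this view. A minimal sketch of what such a normalizer could look like, assuming a `Normalizer` base class with a `this.config` field and the `strip_left`/`strip_right` options used by the tokenizers library (illustrative only, not the commit's exact code):

```javascript
// Illustrative sketch, not the commit's exact code.
class Strip extends Normalizer {
    normalize(text) {
        if (this.config.strip_left && this.config.strip_right) {
            // Fast path: strip both ends with a single built-in call.
            text = text.trim();
        } else {
            if (this.config.strip_left) {
                text = text.trimStart();
            }
            if (this.config.strip_right) {
                text = text.trimEnd();
            }
        }
        return text;
    }
}
```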

File tree

7 files changed: +282 additions, -13 deletions


README.md

Lines changed: 2 additions & 0 deletions
@@ -258,6 +258,8 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 1. **[BERT](https://huggingface.co/docs/transformers/model_doc/bert)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
 1. **[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)** (from OpenAI) released with the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever.
 1. **[CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen)** (from Salesforce) released with the paper [A Conversational Paradigm for Program Synthesis](https://arxiv.org/abs/2203.13474) by Erik Nijkamp, Bo Pang, Hiroaki Hayashi, Lifu Tu, Huan Wang, Yingbo Zhou, Silvio Savarese, Caiming Xiong.
+1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
+1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
 1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
 1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT.
 1. **[FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)** (from Google AI) released in the repository [google-research/t5x](https://github.com/google-research/t5x/blob/main/docs/models.md#flan-t5-checkpoints) by Hyung Won Chung, Le Hou, Shayne Longpre, Barret Zoph, Yi Tay, William Fedus, Eric Li, Xuezhi Wang, Mostafa Dehghani, Siddhartha Brahma, Albert Webson, Shixiang Shane Gu, Zhuyun Dai, Mirac Suzgun, Xinyun Chen, Aakanksha Chowdhery, Sharan Narang, Gaurav Mishra, Adams Yu, Vincent Zhao, Yanping Huang, Andrew Dai, Hongkun Yu, Slav Petrov, Ed H. Chi, Jeff Dean, Jacob Devlin, Adam Roberts, Denny Zhou, Quoc V. Le, and Jason Wei

docs/snippets/6_supported-models.snippet

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,8 @@
 1. **[BERT](https://huggingface.co/docs/transformers/model_doc/bert)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
 1. **[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)** (from OpenAI) released with the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever.
 1. **[CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen)** (from Salesforce) released with the paper [A Conversational Paradigm for Program Synthesis](https://arxiv.org/abs/2203.13474) by Erik Nijkamp, Bo Pang, Hiroaki Hayashi, Lifu Tu, Huan Wang, Yingbo Zhou, Silvio Savarese, Caiming Xiong.
+1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
+1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
 1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
 1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT.
 1. **[FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)** (from Google AI) released in the repository [google-research/t5x](https://github.com/google-research/t5x/blob/main/docs/models.md#flan-t5-checkpoints) by Hyung Won Chung, Le Hou, Shayne Longpre, Barret Zoph, Yi Tay, William Fedus, Eric Li, Xuezhi Wang, Mostafa Dehghani, Siddhartha Brahma, Albert Webson, Shixiang Shane Gu, Zhuyun Dai, Mirac Suzgun, Xinyun Chen, Aakanksha Chowdhery, Sharan Narang, Gaurav Mishra, Adams Yu, Vincent Zhao, Yanping Huang, Andrew Dai, Hongkun Yu, Slav Petrov, Ed H. Chi, Jeff Dean, Jacob Devlin, Adam Roberts, Denny Zhou, Quoc V. Le, and Jason Wei

scripts/supported_models.py

Lines changed: 19 additions & 1 deletion
@@ -82,6 +82,24 @@
         'Salesforce/codegen-350M-multi',
         'Salesforce/codegen-350M-nl',
     ],
+    'deberta': [
+        'cross-encoder/nli-deberta-base',
+        'Narsil/deberta-large-mnli-zero-cls',
+    ],
+    'deberta-v2': [
+        'cross-encoder/nli-deberta-v3-xsmall',
+        'cross-encoder/nli-deberta-v3-small',
+        'cross-encoder/nli-deberta-v3-base',
+        'cross-encoder/nli-deberta-v3-large',
+        'MoritzLaurer/DeBERTa-v3-xsmall-mnli-fever-anli-ling-binary',
+        'MoritzLaurer/DeBERTa-v3-base-mnli',
+        'MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli',
+        'MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli',
+        'MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7',
+        'navteca/nli-deberta-v3-xsmall',
+        'sileod/deberta-v3-base-tasksource-nli',
+        'sileod/deberta-v3-large-tasksource-nli',
+    ],
     'detr': [
         'facebook/detr-resnet-50',
         'facebook/detr-resnet-101',
@@ -133,7 +151,7 @@
         # https://github.com/huggingface/optimum/issues/1027
         # 'google/mobilebert-uncased',
     ],
-    'mobilevit':[
+    'mobilevit': [
         'apple/mobilevit-small',
         'apple/mobilevit-x-small',
         'apple/mobilevit-xx-small',
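
The commit message also mentions adding token types to the deberta tokenizer output; those tokenizer changes are in files not shown in this view. A hedged usage sketch (model name taken from the pipeline examples below; the `text_pair` option is assumed to behave as it does for the library's BERT-style tokenizers):

```javascript
import { AutoTokenizer } from '@xenova/transformers';

// Sketch: DeBERTa tokenizer output should now include token_type_ids
// alongside input_ids and attention_mask.
let tokenizer = await AutoTokenizer.from_pretrained('Xenova/nli-deberta-v3-xsmall');
let { input_ids, token_type_ids } = tokenizer(
    'A man is eating food.',                   // premise
    { text_pair: 'A man is eating a meal.' },  // hypothesis
);
```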

src/models.js

Lines changed: 152 additions & 0 deletions
@@ -1283,6 +1283,148 @@ export class BertForQuestionAnswering extends BertPreTrainedModel {
 }
 //////////////////////////////////////////////////
 
+//////////////////////////////////////////////////
+// DeBERTa models
+export class DebertaPreTrainedModel extends PreTrainedModel { }
+
+/**
+ * The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.
+ */
+export class DebertaModel extends DebertaPreTrainedModel { }
+
+/**
+ * DeBERTa Model with a `language modeling` head on top.
+ */
+export class DebertaForMaskedLM extends DebertaPreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     *
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise<MaskedLMOutput>} An object containing the model's output logits for masked language modeling.
+     */
+    async _call(model_inputs) {
+        return new MaskedLMOutput(await super._call(model_inputs));
+    }
+}
+
+/**
+ * DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled output)
+ */
+export class DebertaForSequenceClassification extends DebertaPreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     *
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for sequence classification.
+     */
+    async _call(model_inputs) {
+        return new SequenceClassifierOutput(await super._call(model_inputs));
+    }
+}
+
+/**
+ * DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks.
+ */
+export class DebertaForTokenClassification extends DebertaPreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     *
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for token classification.
+     */
+    async _call(model_inputs) {
+        return new TokenClassifierOutput(await super._call(model_inputs));
+    }
+}
+
+/**
+ * DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+ * layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ */
+export class DebertaForQuestionAnswering extends DebertaPreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     *
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise<QuestionAnsweringModelOutput>} An object containing the model's output logits for question answering.
+     */
+    async _call(model_inputs) {
+        return new QuestionAnsweringModelOutput(await super._call(model_inputs));
+    }
+}
+//////////////////////////////////////////////////
+
+//////////////////////////////////////////////////
+// DeBERTa-v2 models
+export class DebertaV2PreTrainedModel extends PreTrainedModel { }
+
+/**
+ * The bare DeBERTa-V2 Model transformer outputting raw hidden-states without any specific head on top.
+ */
+export class DebertaV2Model extends DebertaV2PreTrainedModel { }
+
+/**
+ * DeBERTa-V2 Model with a `language modeling` head on top.
+ */
+export class DebertaV2ForMaskedLM extends DebertaV2PreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     *
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise<MaskedLMOutput>} An object containing the model's output logits for masked language modeling.
+     */
+    async _call(model_inputs) {
+        return new MaskedLMOutput(await super._call(model_inputs));
+    }
+}
+
+/**
+ * DeBERTa-V2 Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled output)
+ */
+export class DebertaV2ForSequenceClassification extends DebertaV2PreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     *
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for sequence classification.
+     */
+    async _call(model_inputs) {
+        return new SequenceClassifierOutput(await super._call(model_inputs));
+    }
+}
+
+/**
+ * DeBERTa-V2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks.
+ */
+export class DebertaV2ForTokenClassification extends DebertaV2PreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     *
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for token classification.
+     */
+    async _call(model_inputs) {
+        return new TokenClassifierOutput(await super._call(model_inputs));
+    }
+}
+
+/**
+ * DeBERTa-V2 Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+ * layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ */
+export class DebertaV2ForQuestionAnswering extends DebertaV2PreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     *
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise<QuestionAnsweringModelOutput>} An object containing the model's output logits for question answering.
+     */
+    async _call(model_inputs) {
+        return new QuestionAnsweringModelOutput(await super._call(model_inputs));
+    }
+}
+//////////////////////////////////////////////////
+
 //////////////////////////////////////////////////
 // DistilBert models
 export class DistilBertPreTrainedModel extends PreTrainedModel { }
@@ -3089,6 +3231,8 @@ export class PretrainedMixin {
 
 const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
     ['bert', BertModel],
+    ['deberta', DebertaModel],
+    ['deberta-v2', DebertaV2Model],
     ['mpnet', MPNetModel],
     ['albert', AlbertModel],
     ['distilbert', DistilBertModel],
@@ -3120,6 +3264,8 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
 
 const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([
     ['bert', BertForSequenceClassification],
+    ['deberta', DebertaForSequenceClassification],
+    ['deberta-v2', DebertaV2ForSequenceClassification],
     ['mpnet', MPNetForSequenceClassification],
     ['albert', AlbertForSequenceClassification],
     ['distilbert', DistilBertForSequenceClassification],
@@ -3132,6 +3278,8 @@ const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([
 
 const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([
     ['bert', BertForTokenClassification],
+    ['deberta', DebertaForTokenClassification],
+    ['deberta-v2', DebertaV2ForTokenClassification],
     ['mpnet', MPNetForTokenClassification],
     ['distilbert', DistilBertForTokenClassification],
     ['roberta', RobertaForTokenClassification],
@@ -3156,6 +3304,8 @@ const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([
 
 const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([
     ['bert', BertForMaskedLM],
+    ['deberta', DebertaForMaskedLM],
+    ['deberta-v2', DebertaV2ForMaskedLM],
     ['mpnet', MPNetForMaskedLM],
     ['albert', AlbertForMaskedLM],
     ['distilbert', DistilBertForMaskedLM],
@@ -3167,6 +3317,8 @@ const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([
 
 const MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = new Map([
     ['bert', BertForQuestionAnswering],
+    ['deberta', DebertaForQuestionAnswering],
+    ['deberta-v2', DebertaV2ForQuestionAnswering],
     ['mpnet', MPNetForQuestionAnswering],
    ['albert', AlbertForQuestionAnswering],
     ['distilbert', DistilBertForQuestionAnswering],
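
With these mappings in place, a `config.json` whose `model_type` is `deberta` or `deberta-v2` resolves to the new classes through the `Auto*` factories. A minimal usage sketch, assuming the `Xenova/nli-deberta-v3-xsmall` conversion referenced in the pipeline docs below:

```javascript
import { AutoTokenizer, AutoModelForSequenceClassification } from '@xenova/transformers';

// Sketch: model_type 'deberta-v2' now maps to DebertaV2ForSequenceClassification.
let tokenizer = await AutoTokenizer.from_pretrained('Xenova/nli-deberta-v3-xsmall');
let model = await AutoModelForSequenceClassification.from_pretrained('Xenova/nli-deberta-v3-xsmall');

let inputs = tokenizer('I love transformers!');
let { logits } = await model(inputs);  // SequenceClassifierOutput
```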

src/pipelines.js

Lines changed: 36 additions & 7 deletions
@@ -460,17 +460,17 @@ export class TranslationPipeline extends Text2TextGenerationPipeline {
  * **Example:** Text generation with `Xenova/distilgpt2` (default settings).
  * ```javascript
  * let text = 'I enjoy walking with my cute dog,';
- * let generator = await pipeline('text-generation', 'Xenova/distilgpt2');
- * let output = await generator(text);
+ * let classifier = await pipeline('text-generation', 'Xenova/distilgpt2');
+ * let output = await classifier(text);
  * console.log(output);
  * // [{ generated_text: "I enjoy walking with my cute dog, and I love to play with the other dogs." }]
  * ```
  *
  * **Example:** Text generation with `Xenova/distilgpt2` (custom settings).
  * ```javascript
  * let text = 'Once upon a time, there was';
- * let generator = await pipeline('text-generation', 'Xenova/distilgpt2');
- * let output = await generator(text, {
+ * let classifier = await pipeline('text-generation', 'Xenova/distilgpt2');
+ * let output = await classifier(text, {
  *     temperature: 2,
  *     max_new_tokens: 10,
  *     repetition_penalty: 1.5,
@@ -489,8 +489,8 @@ export class TranslationPipeline extends Text2TextGenerationPipeline {
  * **Example:** Run code generation with `Xenova/codegen-350M-mono`.
  * ```javascript
  * let text = 'def fib(n):';
- * let generator = await pipeline('text-generation', 'Xenova/codegen-350M-mono');
- * let output = await generator(text, {
+ * let classifier = await pipeline('text-generation', 'Xenova/codegen-350M-mono');
+ * let output = await classifier(text, {
  *     max_new_tokens: 40,
  * });
  * console.log(output[0].generated_text);
@@ -550,6 +550,35 @@ export class TextGenerationPipeline extends Pipeline {
  * trained on NLI (natural language inference) tasks. Equivalent of `text-classification`
  * pipelines, but these models don't require a hardcoded number of potential classes, they
  * can be chosen at runtime. It usually means it's slower but it is **much** more flexible.
+ *
+ * **Example:** Zero shot classification with `Xenova/mobilebert-uncased-mnli`.
+ * ```javascript
+ * let text = 'Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app.';
+ * let labels = [ 'mobile', 'billing', 'website', 'account access' ];
+ * let classifier = await pipeline('zero-shot-classification', 'Xenova/mobilebert-uncased-mnli');
+ * let output = await classifier(text, labels);
+ * console.log(output);
+ * // {
+ * //   sequence: 'Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app.',
+ * //   labels: [ 'mobile', 'website', 'billing', 'account access' ],
+ * //   scores: [ 0.5562091040482018, 0.1843621307860853, 0.13942646639336376, 0.12000229877234923 ]
+ * // }
+ * ```
+ *
+ * **Example:** Zero shot classification with `Xenova/nli-deberta-v3-xsmall` (multi-label).
+ * ```javascript
+ * let text = 'I have a problem with my iphone that needs to be resolved asap!';
+ * let labels = [ 'urgent', 'not urgent', 'phone', 'tablet', 'computer' ];
+ * let classifier = await pipeline('zero-shot-classification', 'Xenova/nli-deberta-v3-xsmall');
+ * let output = await classifier(text, labels, { multi_label: true });
+ * console.log(output);
+ * // {
+ * //   sequence: 'I have a problem with my iphone that needs to be resolved asap!',
+ * //   labels: [ 'urgent', 'phone', 'computer', 'tablet', 'not urgent' ],
+ * //   scores: [ 0.9958870956360275, 0.9923963400697035, 0.002333537946160235, 0.0015134138567598765, 0.0010699384208377163 ]
+ * // }
+ * ```
+ *
  * @extends Pipeline
  */
 export class ZeroShotClassificationPipeline extends Pipeline {
@@ -576,7 +605,7 @@ export class ZeroShotClassificationPipeline extends Pipeline {
             this.entailment_id = 2;
         }
 
-        this.contradiction_id = this.label2id['contradiction'];
+        this.contradiction_id = this.label2id['contradiction'] ?? this.label2id['not_entailment'];
         if (this.contradiction_id === undefined) {
             console.warn("Could not find 'contradiction' in label2id mapping. Using 0 as contradiction_id.");
             this.contradiction_id = 0;
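
The `?? this.label2id['not_entailment']` fallback is what the "Support `not_entailment` label" bullet refers to: binary NLI models emit only entailment/not_entailment, so `not_entailment` stands in for the missing contradiction label. A sketch of the scoring it feeds (assumed helper, not the pipeline's exact code): in multi-label mode, each candidate label is scored by a softmax over just these two logits.

```javascript
// Assumed helper, not the pipeline's exact code: probability that the
// hypothesis "This example is {label}." holds, given the model's NLI logits.
function entailmentScore(logits, entailment_id, contradiction_id) {
    const e = Math.exp(logits[entailment_id]);
    const c = Math.exp(logits[contradiction_id]);  // contradiction or not_entailment
    return e / (e + c);  // two-class softmax, as used for multi_label
}
```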
