
Commit c367f9d

Add support for Blenderbot and BlenderbotSmall (#292)
* Add support for `Blenderbot` models (closes #37, references #29)
* Add support for `BlenderbotTokenizer`
* Add blenderbot to supported models
* Add support for `BlenderbotSmallTokenizer`
* Add custom tests for blenderbot-small
* Add support for `BlenderbotSmall` models
* Update list of supported models
* Improve `addPastKeyValues` function
* Allow skipping of adding encoder past key values
1 parent c453e6b commit c367f9d

File tree

6 files changed: +181 −49 lines

README.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -258,6 +258,8 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 1. **[BART](https://huggingface.co/docs/transformers/model_doc/bart)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer.
 1. **[BEiT](https://huggingface.co/docs/transformers/model_doc/beit)** (from Microsoft) released with the paper [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254) by Hangbo Bao, Li Dong, Furu Wei.
 1. **[BERT](https://huggingface.co/docs/transformers/model_doc/bert)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
+1. **[Blenderbot](https://huggingface.co/docs/transformers/model_doc/blenderbot)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
+1. **[BlenderbotSmall](https://huggingface.co/docs/transformers/model_doc/blenderbot-small)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
 1. **[BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/).
 1. **[CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert)** (from Inria/Facebook/Sorbonne) released with the paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot.
 1. **[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)** (from OpenAI) released with the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever.
```

docs/snippets/6_supported-models.snippet

Lines changed: 2 additions & 0 deletions

```diff
@@ -5,6 +5,8 @@
 1. **[BART](https://huggingface.co/docs/transformers/model_doc/bart)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer.
 1. **[BEiT](https://huggingface.co/docs/transformers/model_doc/beit)** (from Microsoft) released with the paper [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254) by Hangbo Bao, Li Dong, Furu Wei.
 1. **[BERT](https://huggingface.co/docs/transformers/model_doc/bert)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
+1. **[Blenderbot](https://huggingface.co/docs/transformers/model_doc/blenderbot)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
+1. **[BlenderbotSmall](https://huggingface.co/docs/transformers/model_doc/blenderbot-small)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
 1. **[BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/).
 1. **[CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert)** (from Inria/Facebook/Sorbonne) released with the paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot.
 1. **[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)** (from OpenAI) released with the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever.
```

scripts/supported_models.py

Lines changed: 10 additions & 10 deletions

```diff
@@ -97,16 +97,16 @@
         'bert-base-chinese',
         'emilyalsentzer/Bio_ClinicalBERT',
     ],
-    # 'blenderbot': [
-    #     # Text2text generation (TODO add conversational)
-    #     'facebook/blenderbot-400M-distill',
-    #     'facebook/blenderbot-1B-distill',
-    # ],
-    # 'blenderbot-small': [
-    #     # Text2text generation (TODO add conversational)
-    #     'facebook/blenderbot-90M',  # DEPRECATED
-    #     'facebook/blenderbot_small-90M',
-    # ],
+    'blenderbot': [
+        # Text2text generation (TODO add conversational)
+        'facebook/blenderbot-400M-distill',
+        # 'facebook/blenderbot-1B-distill',
+    ],
+    'blenderbot-small': [
+        # Text2text generation (TODO add conversational)
+        # 'facebook/blenderbot-90M',  # DEPRECATED
+        'facebook/blenderbot_small-90M',
+    ],
     'bloom': [
         # Text generation
         'bigscience/bloom-560m',
```

src/models.js

Lines changed: 113 additions & 39 deletions

```diff
@@ -310,7 +310,6 @@ function boolTensor(value) {
  * @private
  */
 async function seq2seqForward(self, model_inputs) {
-    const add_decoder_pkv = self.add_decoder_pkv ?? true;
 
     let { encoder_outputs, past_key_values } = model_inputs;
 
@@ -327,7 +326,7 @@ async function seq2seqForward(self, model_inputs) {
     if (self.decoder_merged_session.inputNames.includes('encoder_attention_mask')) {
         decoderFeeds.encoder_attention_mask = model_inputs.attention_mask
     }
-    self.addPastKeyValues(decoderFeeds, past_key_values, add_decoder_pkv);
+    self.addPastKeyValues(decoderFeeds, past_key_values);
 
     const decoderResults = await sessionRun(self.decoder_merged_session, decoderFeeds);
     let logits = decoderResults.logits;
@@ -1199,57 +1198,51 @@ export class PreTrainedModel extends Callable {
      *
      * @param {Object} decoderFeeds The decoder feeds object to add past key values to.
      * @param {Object} pastKeyValues An object containing past key values.
-     * @param {boolean} [hasDecoder=false] Whether the model has a decoder.
      */
-    addPastKeyValues(decoderFeeds, pastKeyValues, hasDecoder = false) {
+    addPastKeyValues(decoderFeeds, pastKeyValues) {
         if (pastKeyValues) {
             Object.assign(decoderFeeds, pastKeyValues)
         } else {
             // TODO support batches (i.e., batch_size > 1)
-            if (hasDecoder) {
+            // @ts-ignore
+            if (this.config.is_encoder_decoder && (this.add_encoder_pkv ?? true)) {
                 // @ts-ignore
                 let encoder_dims = [1, this.num_encoder_heads, 0, this.encoder_dim_kv];
-                // @ts-ignore
-                for (let i = 0; i < this.num_encoder_layers; ++i) {
-                    decoderFeeds[`past_key_values.${i}.encoder.key`] = new Tensor('float32', [], encoder_dims)
-                    decoderFeeds[`past_key_values.${i}.encoder.value`] = new Tensor('float32', [], encoder_dims)
-                }
-
                 // @ts-ignore
                 let decoder_dims = [1, this.num_decoder_heads, 0, this.decoder_dim_kv];
                 // @ts-ignore
                 for (let i = 0; i < this.num_decoder_layers; ++i) {
+                    decoderFeeds[`past_key_values.${i}.encoder.key`] = new Tensor('float32', [], encoder_dims)
+                    decoderFeeds[`past_key_values.${i}.encoder.value`] = new Tensor('float32', [], encoder_dims)
                     decoderFeeds[`past_key_values.${i}.decoder.key`] = new Tensor('float32', [], decoder_dims)
                     decoderFeeds[`past_key_values.${i}.decoder.value`] = new Tensor('float32', [], decoder_dims)
                 }
+            } else if (this.config.multi_query) { // e.g., for `gpt_bigcode`
+                // @ts-ignore
+                let dims = [1, 0, 2 * this.dim_kv]
+                // @ts-ignore
+                for (let i = 0; i < this.num_layers; ++i) {
+                    decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor('float32', [], dims)
+                }
+            } else if (this.config.model_type === 'bloom') {
+                // NOTE: Custom implementation for Bloom
 
-            } else {
-                if (this.config.multi_query) {
-                    // @ts-ignore
-                    let dims = [1, 0, 2 * this.dim_kv]
-                    // @ts-ignore
-                    for (let i = 0; i < this.num_layers; ++i) {
-                        decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor('float32', [], dims)
-                    }
-                } else if (this.config.model_type === 'bloom') {
-                    // Custom implementation for Bloom
-                    // @ts-ignore
-                    let keyDims = [1 * this.num_heads, this.dim_kv, 0] // [batch_size x num_heads,64,past_sequence_length]
-                    // @ts-ignore
-                    let valueDims = [1 * this.num_heads, 0, this.dim_kv] // [batch_size x num_heads,past_sequence_length,64]
-                    // @ts-ignore
-                    for (let i = 0; i < this.num_layers; ++i) {
-                        decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], keyDims)
-                        decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], valueDims)
-                    }
-                } else {
-                    // @ts-ignore
-                    let dims = [1, this.num_heads, 0, this.dim_kv]
-                    // @ts-ignore
-                    for (let i = 0; i < this.num_layers; ++i) {
-                        decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims)
-                        decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims)
-                    }
+                // @ts-ignore
+                let keyDims = [1 * this.num_heads, this.dim_kv, 0] // [batch_size x num_heads,64,past_sequence_length]
+                // @ts-ignore
+                let valueDims = [1 * this.num_heads, 0, this.dim_kv] // [batch_size x num_heads,past_sequence_length,64]
+                // @ts-ignore
+                for (let i = 0; i < this.num_layers; ++i) {
+                    decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], keyDims)
+                    decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], valueDims)
+                }
+            } else { // Decoder-only
+                // @ts-ignore
+                let dims = [1, this.num_heads, 0, this.dim_kv]
+                // @ts-ignore
+                for (let i = 0; i < this.num_layers; ++i) {
+                    decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims)
+                    decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims)
                 }
             }
         }
```
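The refactored `addPastKeyValues` above dispatches on the model's config rather than a `hasDecoder` flag. The encoder-decoder and decoder-only branches can be sketched with a plain object standing in for the ONNX `Tensor` (the `emptyTensor` helper and the config values below are illustrative stand-ins, and the `multi_query`/Bloom branches are omitted for brevity):

```javascript
// Stand-in for `new Tensor('float32', [], dims)`: an empty tensor whose
// shape has past_sequence_length = 0, so the first decoder step has no cache.
function emptyTensor(dims) {
    return { type: 'float32', data: [], dims };
}

// Simplified sketch of the encoder-decoder vs. decoder-only branches.
function addPastKeyValues(model, decoderFeeds) {
    if (model.config.is_encoder_decoder && (model.add_encoder_pkv ?? true)) {
        // Encoder-decoder (e.g. Blenderbot): one encoder and one decoder
        // key/value pair per decoder layer.
        const encoder_dims = [1, model.num_encoder_heads, 0, model.encoder_dim_kv];
        const decoder_dims = [1, model.num_decoder_heads, 0, model.decoder_dim_kv];
        for (let i = 0; i < model.num_decoder_layers; ++i) {
            decoderFeeds[`past_key_values.${i}.encoder.key`] = emptyTensor(encoder_dims);
            decoderFeeds[`past_key_values.${i}.encoder.value`] = emptyTensor(encoder_dims);
            decoderFeeds[`past_key_values.${i}.decoder.key`] = emptyTensor(decoder_dims);
            decoderFeeds[`past_key_values.${i}.decoder.value`] = emptyTensor(decoder_dims);
        }
    } else {
        // Decoder-only: a single key/value pair per layer.
        const dims = [1, model.num_heads, 0, model.dim_kv];
        for (let i = 0; i < model.num_layers; ++i) {
            decoderFeeds[`past_key_values.${i}.key`] = emptyTensor(dims);
            decoderFeeds[`past_key_values.${i}.value`] = emptyTensor(dims);
        }
    }
}

// Hypothetical Blenderbot-like model with 2 decoder layers:
const model = {
    config: { is_encoder_decoder: true },
    num_encoder_heads: 32, encoder_dim_kv: 40,
    num_decoder_heads: 32, decoder_dim_kv: 40,
    num_decoder_layers: 2,
};
const feeds = {};
addPastKeyValues(model, feeds);
console.log(Object.keys(feeds).length); // 2 layers × 4 feeds = 8
```

Keying the decision off `config.is_encoder_decoder` is what lets `seq2seqForward` drop its extra `add_decoder_pkv` argument: the model instance now carries everything needed to pick the right feed layout.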
```diff
@@ -2033,6 +2026,83 @@ export class MBartForSequenceClassification extends MBartPreTrainedModel {
 
 //////////////////////////////////////////////////
 
+
+//////////////////////////////////////////////////
+// Blenderbot models
+export class BlenderbotPreTrainedModel extends PreTrainedModel { };
+
+/**
+ * The bare Blenderbot Model outputting raw hidden-states without any specific head on top.
+ */
+export class BlenderbotModel extends BlenderbotPreTrainedModel { }
+
+/**
+ * The Blenderbot Model with a language modeling head. Can be used for summarization.
+ */
+export class BlenderbotForConditionalGeneration extends BlenderbotPreTrainedModel {
+
+    /**
+     * Creates a new instance of the `BlenderbotForConditionalGeneration` class.
+     * @param {any} config The model configuration.
+     * @param {any} session The ONNX session containing the encoder weights.
+     * @param {any} decoder_merged_session The ONNX session containing the merged decoder weights.
+     * @param {GenerationConfig} generation_config The generation configuration.
+     */
+    constructor(config, session, decoder_merged_session, generation_config) {
+        super(config, session);
+        this.decoder_merged_session = decoder_merged_session;
+        this.generation_config = generation_config;
+
+        this.num_decoder_layers = this.config.decoder_layers;
+        this.num_decoder_heads = this.config.decoder_attention_heads;
+        this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;
+
+        this.num_encoder_layers = this.config.encoder_layers;
+        this.num_encoder_heads = this.config.encoder_attention_heads;
+        this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
+    }
+}
+//////////////////////////////////////////////////
+
+
+//////////////////////////////////////////////////
+// BlenderbotSmall models
+export class BlenderbotSmallPreTrainedModel extends PreTrainedModel { };
+
+/**
+ * The bare BlenderbotSmall Model outputting raw hidden-states without any specific head on top.
+ */
+export class BlenderbotSmallModel extends BlenderbotSmallPreTrainedModel { }
+
+/**
+ * The BlenderbotSmall Model with a language modeling head. Can be used for summarization.
+ */
+export class BlenderbotSmallForConditionalGeneration extends BlenderbotSmallPreTrainedModel {
+
+    /**
+     * Creates a new instance of the `BlenderbotSmallForConditionalGeneration` class.
+     * @param {any} config The model configuration.
+     * @param {any} session The ONNX session containing the encoder weights.
+     * @param {any} decoder_merged_session The ONNX session containing the merged decoder weights.
+     * @param {GenerationConfig} generation_config The generation configuration.
+     */
+    constructor(config, session, decoder_merged_session, generation_config) {
+        super(config, session);
+        this.decoder_merged_session = decoder_merged_session;
+        this.generation_config = generation_config;
+
+        this.num_decoder_layers = this.config.decoder_layers;
+        this.num_decoder_heads = this.config.decoder_attention_heads;
+        this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;
+
+        this.num_encoder_layers = this.config.encoder_layers;
+        this.num_encoder_heads = this.config.encoder_attention_heads;
+        this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
+    }
+}
+//////////////////////////////////////////////////
+
+
 //////////////////////////////////////////////////
 // Roberta models
 export class RobertaPreTrainedModel extends PreTrainedModel { }
```
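The constructors above derive the per-head key/value size from the config as `d_model / attention_heads`. A minimal sketch of that computation, using illustrative config values (not read from any specific checkpoint):

```javascript
// Illustrative encoder-decoder config; the field names match those the
// constructors read, but the numbers are hypothetical.
const config = {
    d_model: 1280,
    encoder_layers: 2,
    encoder_attention_heads: 32,
    decoder_layers: 12,
    decoder_attention_heads: 32,
};

// Each attention head attends over d_model / num_heads dimensions, so the
// empty past-key-value tensors are shaped [1, heads, 0, dim_kv].
const num_decoder_heads = config.decoder_attention_heads;
const decoder_dim_kv = config.d_model / num_decoder_heads;

const num_encoder_heads = config.encoder_attention_heads;
const encoder_dim_kv = config.d_model / num_encoder_heads;

console.log(decoder_dim_kv, encoder_dim_kv); // 40 40
```

Note that both subclasses store encoder and decoder dimensions separately, since `encoder_layers`/`encoder_attention_heads` need not equal their decoder counterparts.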
```diff
@@ -2458,7 +2528,7 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
  */
 export class VisionEncoderDecoderModel extends PreTrainedModel {
     main_input_name = 'pixel_values';
-    add_decoder_pkv = false;
+    add_encoder_pkv = false;
 
     /**
      * Creates a new instance of the `VisionEncoderDecoderModel` class.
```
```diff
@@ -3422,6 +3492,8 @@ const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
     ['marian', ['MarianModel', MarianModel]],
     ['whisper', ['WhisperModel', WhisperModel]],
     ['m2m_100', ['M2M100Model', M2M100Model]],
+    ['blenderbot', ['BlenderbotModel', BlenderbotModel]],
+    ['blenderbot-small', ['BlenderbotSmallModel', BlenderbotSmallModel]],
 ]);
 
```
```diff
@@ -3475,6 +3547,8 @@ const MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES = new Map([
     ['whisper', ['WhisperForConditionalGeneration', WhisperForConditionalGeneration]],
     ['marian', ['MarianMTModel', MarianMTModel]],
     ['m2m_100', ['M2M100ForConditionalGeneration', M2M100ForConditionalGeneration]],
+    ['blenderbot', ['BlenderbotForConditionalGeneration', BlenderbotForConditionalGeneration]],
+    ['blenderbot-small', ['BlenderbotSmallForConditionalGeneration', BlenderbotSmallForConditionalGeneration]],
 ]);
 
 const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([
```
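Registering a model in these Maps is what makes it loadable by `model_type`. A minimal sketch of the lookup, with class constructors replaced by plain strings (the real entries are `['Name', Class]` pairs; the single-string arrays here are a simplification):

```javascript
// Simplified stand-in for MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES: maps a
// config's model_type to the class that should handle it.
const SEQ_2_SEQ_MAPPING = new Map([
    ['marian', ['MarianMTModel']],
    ['m2m_100', ['M2M100ForConditionalGeneration']],
    ['blenderbot', ['BlenderbotForConditionalGeneration']],
    ['blenderbot-small', ['BlenderbotSmallForConditionalGeneration']],
]);

// Resolve the class name for a given model_type, or null if unsupported.
function classNameFor(model_type) {
    const entry = SEQ_2_SEQ_MAPPING.get(model_type);
    return entry ? entry[0] : null;
}

console.log(classNameFor('blenderbot-small')); // BlenderbotSmallForConditionalGeneration
```

Because the key is the Hugging Face `model_type` string, `blenderbot` and `blenderbot-small` resolve to different classes even though their implementations are nearly identical.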
