Skip to content

Commit 9077c21

Browse files
authored
Add support for BLOOM models (#273)
* Add support for Bloom models * Update `BloomTokenizer` to fix the default (invalid) regex * Update supported models * Update default quantization settings for bloom models * Fix `use_cache_branch`
1 parent 62159eb commit 9077c21

File tree

6 files changed

+125
-6
lines changed

6 files changed

+125
-6
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
256256
1. **[ALBERT](https://huggingface.co/docs/transformers/model_doc/albert)** (from Google Research and the Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.
257257
1. **[BART](https://huggingface.co/docs/transformers/model_doc/bart)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer.
258258
1. **[BERT](https://huggingface.co/docs/transformers/model_doc/bert)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
259+
1. **[BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/).
259260
1. **[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)** (from OpenAI) released with the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever.
260261
1. **[CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen)** (from Salesforce) released with the paper [A Conversational Paradigm for Program Synthesis](https://arxiv.org/abs/2203.13474) by Erik Nijkamp, Bo Pang, Hiroaki Hayashi, Lifu Tu, Huan Wang, Yingbo Zhou, Silvio Savarese, Caiming Xiong.
261262
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.

docs/snippets/6_supported-models.snippet

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
1. **[ALBERT](https://huggingface.co/docs/transformers/model_doc/albert)** (from Google Research and the Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.
55
1. **[BART](https://huggingface.co/docs/transformers/model_doc/bart)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer.
66
1. **[BERT](https://huggingface.co/docs/transformers/model_doc/bert)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
7+
1. **[BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/).
78
1. **[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)** (from OpenAI) released with the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever.
89
1. **[CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen)** (from Salesforce) released with the paper [A Conversational Paradigm for Program Synthesis](https://arxiv.org/abs/2203.13474) by Erik Nijkamp, Bo Pang, Hiroaki Hayashi, Lifu Tu, Huan Wang, Yingbo Zhou, Silvio Savarese, Caiming Xiong.
910
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.

scripts/convert.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
'whisper': {
3030
'per_channel': False,
3131
'reduce_range': False,
32+
},
33+
'bloom': {
34+
'per_channel': False,
35+
'reduce_range': False,
3236
}
3337
}
3438

scripts/supported_models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@
6767
'unitary/toxic-bert',
6868
],
6969
# TODO:
70-
# 'bloom':[
71-
# 'bigscience/bloom-560m',
72-
# 'bigscience/bloomz-560m',
73-
# ],
70+
'bloom': [
71+
'bigscience/bloom-560m',
72+
# 'bigscience/bloomz-560m',
73+
],
7474
# TODO:
7575
# 'blenderbot-small': [
7676
# 'facebook/blenderbot_small-90M',

src/models.js

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ async function decoderForward(self, model_inputs) {
462462
let decoderFeeds = {
463463
input_ids: input_ids,
464464
attention_mask: attention_mask ?? prepareAttentionMask(self, input_ids),
465-
use_cache_branch: boolTensor(past_key_values !== null)
465+
use_cache_branch: boolTensor(!!past_key_values)
466466
}
467467

468468
self.addPastKeyValues(decoderFeeds, past_key_values);
@@ -1178,6 +1178,17 @@ export class PreTrainedModel extends Callable {
11781178
for (let i = 0; i < this.num_layers; ++i) {
11791179
decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor('float32', [], dims)
11801180
}
1181+
} else if (this.config.model_type === 'bloom') {
1182+
// Custom implementation for Bloom
1183+
// @ts-ignore
1184+
let keyDims = [1 * this.num_heads, this.dim_kv, 0] // [batch_size x num_heads,64,past_sequence_length]
1185+
// @ts-ignore
1186+
let valueDims = [1 * this.num_heads, 0, this.dim_kv] // [batch_size x num_heads,past_sequence_length,64]
1187+
// @ts-ignore
1188+
for (let i = 0; i < this.num_layers; ++i) {
1189+
decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], keyDims)
1190+
decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], valueDims)
1191+
}
11811192
} else {
11821193
// @ts-ignore
11831194
let dims = [1, this.num_heads, 0, this.dim_kv]
@@ -2660,6 +2671,9 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel {
26602671
// TODO
26612672
// }
26622673
//////////////////////////////////////////////////
2674+
2675+
//////////////////////////////////////////////////
2676+
// GPTNeo models
26632677
export class GPTNeoPreTrainedModel extends PreTrainedModel {
26642678
/**
26652679
* Creates a new instance of the `GPTNeoPreTrainedModel` class.
@@ -2985,6 +2999,92 @@ export class LlamaForCausalLM extends LlamaPreTrainedModel {
29852999
//////////////////////////////////////////////////
29863000

29873001

3002+
//////////////////////////////////////////////////
3003+
// Bloom models
3004+
/**
3005+
* The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).
3006+
*/
3007+
export class BloomPreTrainedModel extends PreTrainedModel {
3008+
/**
3009+
* Creates a new instance of the `BloomPreTrainedModel` class.
3010+
* @param {Object} config The configuration of the model.
3011+
* @param {any} session The ONNX session containing the model weights.
3012+
*/
3013+
constructor(config, session) {
3014+
super(config, session);
3015+
3016+
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
3017+
this.config.pad_token_id = this.config.eos_token_id
3018+
3019+
this.num_heads = this.config.n_head
3020+
this.num_layers = this.config.n_layer
3021+
this.dim_kv = this.config.hidden_size / this.num_heads;
3022+
}
3023+
}
3024+
3025+
/**
3026+
* The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.
3027+
*/
3028+
export class BloomModel extends BloomPreTrainedModel {
3029+
3030+
/**
3031+
* BloomModel is not compatible with `.generate()`, as it doesn't have a language model head.
3032+
* @param {...any} args
3033+
* @throws {Error}
3034+
* @returns {Promise<any>}
3035+
*/
3036+
async generate(...args) {
3037+
throw Error(
3038+
"The current model class (BloomModel) is not compatible with `.generate()`, as it doesn't have a language model head. Please use one of the following classes instead: {'BloomForCausalLM'}"
3039+
)
3040+
}
3041+
}
3042+
3043+
/**
3044+
* The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).
3045+
*/
3046+
export class BloomForCausalLM extends BloomPreTrainedModel {
3047+
3048+
/**
3049+
* Initializes and returns the beam for text generation task
3050+
* @param {Tensor} inputTokenIds The input token ids.
3051+
* @param {number} numOutputTokens The number of tokens to be generated.
3052+
* @param {Tensor} inputs_attention_mask Optional input attention mask.
3053+
* @returns {any} A Beam object representing the initialized beam.
3054+
*/
3055+
getStartBeams(inputTokenIds, numOutputTokens, inputs_attention_mask) {
3056+
return decoderStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask)
3057+
}
3058+
3059+
/**
3060+
* Runs a single step of the beam search generation algorithm.
3061+
* @param {any} beam The current beam being generated.
3062+
* @returns {Promise<any>} The updated beam after a single generation step.
3063+
*/
3064+
async runBeam(beam) {
3065+
return await decoderRunBeam(this, beam);
3066+
}
3067+
3068+
/**
3069+
* Updates the given beam with the new generated token id.
3070+
* @param {any} beam The Beam object representing the beam.
3071+
* @param {number} newTokenId The new generated token id to be added to the beam.
3072+
*/
3073+
updateBeam(beam, newTokenId) {
3074+
return decoderUpdatebeam(beam, newTokenId);
3075+
}
3076+
3077+
/**
3078+
* Forward pass for the model.
3079+
* @param {Object} model_inputs The inputs for the model.
3080+
* @returns {Promise<any>} The output tensor of the model.
3081+
*/
3082+
async forward(model_inputs) {
3083+
return await decoderForward(this, model_inputs);
3084+
}
3085+
}
3086+
//////////////////////////////////////////////////
3087+
29883088
//////////////////////////////////////////////////
29893089
export class ViTPreTrainedModel extends PreTrainedModel { }
29903090
export class ViTModel extends ViTPreTrainedModel { }
@@ -3478,6 +3578,7 @@ const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
34783578

34793579

34803580
const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
3581+
['bloom', BloomModel],
34813582
['gpt2', GPT2Model],
34823583
['gpt_bigcode', GPTBigCodeModel],
34833584
['gpt_neo', GPTNeoModel],
@@ -3519,6 +3620,7 @@ const MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES = new Map([
35193620
]);
35203621

35213622
const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([
3623+
['bloom', BloomForCausalLM],
35223624
['gpt2', GPT2LMHeadModel],
35233625
['gpt_bigcode', GPTBigCodeForCausalLM],
35243626
['gpt_neo', GPTNeoForCausalLM],

src/tokenizers.js

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2566,7 +2566,18 @@ export class GPT2Tokenizer extends PreTrainedTokenizer { }
25662566
export class BartTokenizer extends PreTrainedTokenizer { }
25672567
export class RobertaTokenizer extends PreTrainedTokenizer { }
25682568

2569-
export class BloomTokenizer extends PreTrainedTokenizer { }
2569+
export class BloomTokenizer extends PreTrainedTokenizer {
2570+
constructor(tokenizerJSON, tokenizerConfig) {
2571+
// Override the default (invalid) regex of the pretokenizer.
2572+
// For more information, see https://github.com/xenova/transformers.js/issues/94
2573+
const splitChars = '.,!?\u2026\u3002\uff0c\u3001\u0964\u06d4\u060c';
2574+
const patternObject = tokenizerJSON.pre_tokenizer?.pretokenizers[0]?.pattern;
2575+
if (patternObject && patternObject.Regex === ` ?[^(\\s|[${splitChars}])]+`) {
2576+
patternObject.Regex = ` ?[^\\s${splitChars}]+`;
2577+
}
2578+
super(tokenizerJSON, tokenizerConfig);
2579+
}
2580+
}
25702581
export class LlamaTokenizer extends PreTrainedTokenizer { }
25712582

25722583
export class XLMRobertaTokenizer extends PreTrainedTokenizer { }

0 commit comments

Comments
 (0)