Commit 57f2b5c

Add support for MPT models (Fixes #166) (#272)
* Add support for MPT models
* Fix `use_cache_branch`
* Update list of supported models
1 parent 96b9143 commit 57f2b5c

4 files changed: +92 -1 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -274,6 +274,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (from CMU/Google Brain) released with the paper [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou.
 1. **[MobileViT](https://huggingface.co/docs/transformers/model_doc/mobilevit)** (from Apple) released with the paper [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178) by Sachin Mehta and Mohammad Rastegari.
 1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
+1. **[MPT](https://huggingface.co/docs/transformers/model_doc/mpt)** (from MosaicML) released with the repository [llm-foundry](https://github.com/mosaicml/llm-foundry/) by the MosaicML NLP Team.
 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.

docs/snippets/6_supported-models.snippet

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
 1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (from CMU/Google Brain) released with the paper [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou.
 1. **[MobileViT](https://huggingface.co/docs/transformers/model_doc/mobilevit)** (from Apple) released with the paper [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178) by Sachin Mehta and Mohammad Rastegari.
 1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
+1. **[MPT](https://huggingface.co/docs/transformers/model_doc/mpt)** (from MosaicML) released with the repository [llm-foundry](https://github.com/mosaicml/llm-foundry/) by the MosaicML NLP Team.
 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.

scripts/supported_models.py

Lines changed: 3 additions & 0 deletions
@@ -169,6 +169,9 @@
         'apple/deeplabv3-mobilevit-x-small',
         'apple/deeplabv3-mobilevit-xx-small',
     ],
+    'mpt': [
+        'efederici/ipt-350m',
+    ],
     'mpnet': [
         'sentence-transformers/all-mpnet-base-v2',
         'sentence-transformers/nli-mpnet-base-v2',

src/models.js

Lines changed: 87 additions & 1 deletion
@@ -2998,7 +2998,6 @@ export class LlamaForCausalLM extends LlamaPreTrainedModel {
 }
 //////////////////////////////////////////////////
 
-
 //////////////////////////////////////////////////
 // Bloom models
 /**
@@ -3085,6 +3084,91 @@ export class BloomForCausalLM extends BloomPreTrainedModel {
 }
 //////////////////////////////////////////////////
 
+//////////////////////////////////////////////////
+// MPT models
+export class MptPreTrainedModel extends PreTrainedModel {
+    /**
+     * Creates a new instance of the `MptPreTrainedModel` class.
+     * @param {Object} config The model configuration object.
+     * @param {Object} session The ONNX session object.
+     */
+    constructor(config, session) {
+        super(config, session);
+
+        // config doesn't contain pad_token_id, so we assume it is the eos_token_id
+        this.config.pad_token_id = this.config.eos_token_id
+
+        this.num_heads = this.config.n_heads
+        this.num_layers = this.config.n_layers
+        this.dim_kv = this.config.d_model / this.num_heads;
+    }
+}
+
+/**
+ * The bare Mpt Model transformer outputting raw hidden-states without any specific head on top.
+ */
+export class MptModel extends MptPreTrainedModel {
+    /**
+     * Throws an error indicating that the current model class is not compatible with `.generate()`,
+     * as it doesn't have a language model head.
+     *
+     * @throws {Error} The current model class is not compatible with `.generate()`
+     *
+     * @param {...any} args Arguments passed to the generate function
+     * @returns {Promise<any>}
+     */
+    async generate(...args) {
+        throw Error(
+            "The current model class (MptModel) is not compatible with `.generate()`, as it doesn't have a language model head. Please use one of the following classes instead: {'MptForCausalLM'}"
+        )
+    }
+}
+
+/**
+ * The MPT Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).
+ */
+export class MptForCausalLM extends MptPreTrainedModel {
+
+    /**
+     * Initializes and returns the beam for text generation task
+     * @param {Tensor} inputTokenIds The input token ids.
+     * @param {number} numOutputTokens The number of tokens to be generated.
+     * @param {Tensor} inputs_attention_mask Optional input attention mask.
+     * @returns {any} A Beam object representing the initialized beam.
+     */
+    getStartBeams(inputTokenIds, numOutputTokens, inputs_attention_mask) {
+        return decoderStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask)
+    }
+
+    /**
+     * Runs a single step of the beam search generation algorithm.
+     * @param {any} beam The current beam being generated.
+     * @returns {Promise<any>} The updated beam after a single generation step.
+     */
+    async runBeam(beam) {
+        return await decoderRunBeam(this, beam);
+    }
+
+    /**
+     * Updates the given beam with the new generated token id.
+     * @param {any} beam The Beam object representing the beam.
+     * @param {number} newTokenId The new generated token id to be added to the beam.
+     */
+    updateBeam(beam, newTokenId) {
+        return decoderUpdatebeam(beam, newTokenId);
+    }
+
+    /**
+     * Forward pass for the model.
+     * @param {Object} model_inputs The inputs for the model.
+     * @returns {Promise<any>} The output tensor of the model.
+     */
+    async forward(model_inputs) {
+        return await decoderForward(this, model_inputs);
+    }
+}
+//////////////////////////////////////////////////
+
 //////////////////////////////////////////////////
 export class ViTPreTrainedModel extends PreTrainedModel { }
 export class ViTModel extends ViTPreTrainedModel { }
@@ -3584,6 +3668,7 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
     ['gpt_neo', GPTNeoModel],
     ['codegen', CodeGenModel],
     ['llama', LlamaModel],
+    ['mpt', MptModel],
 ]);
 
 const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([
@@ -3626,6 +3711,7 @@ const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([
     ['gpt_neo', GPTNeoForCausalLM],
     ['codegen', CodeGenForCausalLM],
     ['llama', LlamaForCausalLM],
+    ['mpt', MptForCausalLM],
 ]);
 
 const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([
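With these mappings in place, the `mpt` model type resolves to `MptModel` / `MptForCausalLM`, so an MPT-style checkpoint can be used through the library's existing `pipeline` helper like any other decoder-only model. Below is a minimal sketch: the model id is the test entry registered in `scripts/supported_models.py`, and whether ready-converted ONNX weights are published under that exact id on the Hub is an assumption.

```js
import { pipeline } from '@xenova/transformers';

// 'efederici/ipt-350m' is the MPT-style test model from scripts/supported_models.py;
// it is assumed here that converted ONNX weights are available for it.
const generator = await pipeline('text-generation', 'efederici/ipt-350m');

// Generate a short continuation of the prompt.
const output = await generator('My favourite music is', {
    max_new_tokens: 30,
});
console.log(output);
```

Generation itself reuses the shared decoder helpers (`decoderStartBeams`, `decoderRunBeam`, `decoderForward`) wired up in `MptForCausalLM` above, so no MPT-specific generation logic is needed.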
