
Commit 46dd490

[Llama + Llama 2] Add model support (#232)
* Add support for llama models
* Fix JSDoc
1 parent 1e157ba · commit 46dd490

File tree

1 file changed: +102 -12 lines


src/models.js

Lines changed: 102 additions & 12 deletions
@@ -1836,13 +1836,13 @@ export class MT5ForConditionalGeneration extends MT5PreTrainedModel {
     }
 
     /**
-      * Generates the start beams for the given input tokens and output sequence length.
-      *
-      * @param {any[]} inputs The input sequence.
-      * @param {number} numOutputTokens The desired length of the output sequence.
-      * @param {...*} args Additional arguments to pass to the `seq2seqStartBeams` function.
-      * @returns {any[]} An array of `Beam` objects representing the start beams.
-      */
+     * Generates the start beams for the given input tokens and output sequence length.
+     *
+     * @param {any[]} inputs The input sequence.
+     * @param {number} numOutputTokens The desired length of the output sequence.
+     * @param {...*} args Additional arguments to pass to the `seq2seqStartBeams` function.
+     * @returns {any[]} An array of `Beam` objects representing the start beams.
+     */
     getStartBeams(inputs, numOutputTokens, ...args) {
         return seq2seqStartBeams(this, inputs, numOutputTokens);
     }
@@ -1860,16 +1860,16 @@ export class MT5ForConditionalGeneration extends MT5PreTrainedModel {
      * Updates the given beam with the new predicted token.
      * @param {any} beam The beam to update.
      * @param {number} newTokenId The index of the predicted token.
-      */
+     */
     updateBeam(beam, newTokenId) {
         beam.output_token_ids = [...beam.output_token_ids, newTokenId];
     }
 
     /**
-      * Runs the forward pass of the model on the given inputs.
-      * @param {any} model_inputs The model inputs.
-      * @returns {Promise<any>} A Promise that resolves to the model outputs.
-      */
+     * Runs the forward pass of the model on the given inputs.
+     * @param {any} model_inputs The model inputs.
+     * @returns {Promise<any>} A Promise that resolves to the model outputs.
+     */
     async forward(model_inputs) {
         return await seq2seqForward(this, model_inputs);
     }
@@ -2884,6 +2884,94 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel {
 }
 //////////////////////////////////////////////////
 
+
+//////////////////////////////////////////////////
+// LLama models
+
+/**
+ * The bare LLama Model outputting raw hidden-states without any specific head on top.
+ */
+export class LlamaPreTrainedModel extends PreTrainedModel {
+    /**
+     * Creates a new instance of the `LlamaPreTrainedModel` class.
+     * @param {Object} config The model configuration object.
+     * @param {Object} session The ONNX session object.
+     */
+    constructor(config, session) {
+        super(config, session);
+
+        // config doesn't contain pad_token_id, so we assume it is the eos_token_id
+        this.config.pad_token_id = this.config.eos_token_id
+
+        this.num_heads = this.config.num_attention_heads
+        this.num_layers = this.config.num_hidden_layers
+        this.dim_kv = this.config.hidden_size / this.num_heads;
+    }
+}
+/**
+ * The bare LLaMA Model outputting raw hidden-states without any specific head on top.
+ */
+export class LlamaModel extends LlamaPreTrainedModel {
+    /**
+     * Throws an error indicating that the current model class is not compatible with `.generate()`,
+     * as it doesn't have a language model head.
+     *
+     * @throws {Error} The current model class is not compatible with `.generate()`
+     *
+     * @param {...any} args Arguments passed to the generate function
+     * @returns {Promise<any>}
+     */
+    async generate(...args) {
+        throw Error(
+            "The current model class (LlamaModel) is not compatible with `.generate()`, as it doesn't have a language model head. Please use one of the following classes instead: {'LlamaForCausalLM'}"
+        )
+    }
+}
+
+export class LlamaForCausalLM extends LlamaPreTrainedModel {
+
+    /**
+     * Initializes and returns the beam for text generation task
+     * @param {Tensor} inputTokenIds The input token ids.
+     * @param {number} numOutputTokens The number of tokens to be generated.
+     * @param {Tensor} inputs_attention_mask Optional input attention mask.
+     * @returns {any} A Beam object representing the initialized beam.
+     */
+    getStartBeams(inputTokenIds, numOutputTokens, inputs_attention_mask) {
+        return decoderStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask)
+    }
+
+    /**
+     * Runs a single step of the beam search generation algorithm.
+     * @param {any} beam The current beam being generated.
+     * @returns {Promise<any>} The updated beam after a single generation step.
+     */
+    async runBeam(beam) {
+        return await decoderRunBeam(this, beam);
+    }
+
+    /**
+     * Updates the given beam with the new generated token id.
+     * @param {any} beam The Beam object representing the beam.
+     * @param {number} newTokenId The new generated token id to be added to the beam.
+     */
+    updateBeam(beam, newTokenId) {
+        return decoderUpdatebeam(beam, newTokenId);
+    }
+
+    /**
+     * Forward pass for the model.
+     * @param {Object} model_inputs The inputs for the model.
+     * @returns {Promise<any>} The output tensor of the model.
+     */
+    async forward(model_inputs) {
+        return await decoderForward(this, model_inputs);
+    }
+
+}
+//////////////////////////////////////////////////
+
+
 //////////////////////////////////////////////////
 export class ViTPreTrainedModel extends PreTrainedModel { }
 export class ViTForImageClassification extends ViTPreTrainedModel {
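
Note: the `num_heads`, `num_layers`, and `dim_kv` fields recorded in the `LlamaPreTrainedModel` constructor above describe the per-head key/value geometry that decoder-only models in this file typically use when allocating past key/value tensors between generation steps. A minimal sketch of how they relate, using LLaMA-7B-like configuration values (illustrative only, not code from this diff):

// Illustrative only: constructor-derived fields for a LLaMA-7B-style config
// (32 attention heads, 32 hidden layers, hidden size 4096).
const config = { num_attention_heads: 32, num_hidden_layers: 32, hidden_size: 4096 };

const num_heads = config.num_attention_heads;  // 32
const num_layers = config.num_hidden_layers;   // 32
const dim_kv = config.hidden_size / num_heads; // 4096 / 32 = 128 (per-head key/value size)

console.log({ num_heads, num_layers, dim_kv }); // { num_heads: 32, num_layers: 32, dim_kv: 128 }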
@@ -3260,6 +3348,7 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
     ['gpt_bigcode', GPTBigCodeModel],
     ['gpt_neo', GPTNeoModel],
     ['codegen', CodeGenModel],
+    ['llama', LlamaModel],
 ]);
 
 const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([
@@ -3300,6 +3389,7 @@ const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([
     ['gpt_bigcode', GPTBigCodeForCausalLM],
     ['gpt_neo', GPTNeoForCausalLM],
     ['codegen', CodeGenForCausalLM],
+    ['llama', LlamaForCausalLM],
 ]);
 
 const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([
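
With `llama` registered in both `MODEL_MAPPING_NAMES_DECODER_ONLY` and `MODEL_WITH_LM_HEAD_MAPPING_NAMES`, a checkpoint whose `config.model_type` is `llama` can be loaded through the library's existing auto classes and pipelines. A minimal usage sketch (the model id below is a placeholder; it assumes an ONNX-converted LLaMA checkpoint is available on the Hub):

import { pipeline } from '@xenova/transformers';

// 'your-username/llama-onnx' is a hypothetical id; any ONNX-converted
// checkpoint with config.model_type === 'llama' should resolve to
// LlamaForCausalLM via the text-generation pipeline.
const generator = await pipeline('text-generation', 'your-username/llama-onnx');

const output = await generator('Once upon a time,', { max_new_tokens: 30 });
console.log(output);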
