@@ -1836,13 +1836,13 @@ export class MT5ForConditionalGeneration extends MT5PreTrainedModel {
18361836 }
18371837
18381838 /**
1839- * Generates the start beams for the given input tokens and output sequence length.
1840- *
1841- * @param {any[] } inputs The input sequence.
1842- * @param {number } numOutputTokens The desired length of the output sequence.
1843- * @param {...* } args Additional arguments to pass to the `seq2seqStartBeams` function.
1844- * @returns {any[] } An array of `Beam` objects representing the start beams.
1845- */
1839+ * Generates the start beams for the given input tokens and output sequence length.
1840+ *
1841+ * @param {any[] } inputs The input sequence.
1842+ * @param {number } numOutputTokens The desired length of the output sequence.
1843+ * @param {...* } args Additional arguments to pass to the `seq2seqStartBeams` function.
1844+ * @returns {any[] } An array of `Beam` objects representing the start beams.
1845+ */
18461846 getStartBeams ( inputs , numOutputTokens , ...args ) {
18471847 return seq2seqStartBeams ( this , inputs , numOutputTokens ) ;
18481848 }
@@ -1860,16 +1860,16 @@ export class MT5ForConditionalGeneration extends MT5PreTrainedModel {
18601860 * Updates the given beam with the new predicted token.
18611861 * @param {any } beam The beam to update.
18621862 * @param {number } newTokenId The index of the predicted token.
1863- */
1863+ */
18641864 updateBeam ( beam , newTokenId ) {
18651865 beam . output_token_ids = [ ...beam . output_token_ids , newTokenId ] ;
18661866 }
18671867
18681868 /**
1869- * Runs the forward pass of the model on the given inputs.
1870- * @param {any } model_inputs The model inputs.
1871- * @returns {Promise<any> } A Promise that resolves to the model outputs.
1872- */
1869+ * Runs the forward pass of the model on the given inputs.
1870+ * @param {any } model_inputs The model inputs.
1871+ * @returns {Promise<any> } A Promise that resolves to the model outputs.
1872+ */
18731873 async forward ( model_inputs ) {
18741874 return await seq2seqForward ( this , model_inputs ) ;
18751875 }
@@ -2884,6 +2884,94 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel {
28842884}
28852885//////////////////////////////////////////////////
28862886

//////////////////////////////////////////////////
// LLaMA models

/**
 * The bare LLaMA Model outputting raw hidden-states without any specific head on top.
 */
export class LlamaPreTrainedModel extends PreTrainedModel {
    /**
     * Creates a new instance of the `LlamaPreTrainedModel` class.
     * @param {Object} config The model configuration object.
     * @param {Object} session The ONNX session object.
     */
    constructor(config, session) {
        super(config, session);

        // LLaMA configs typically don't ship a pad_token_id; fall back to
        // eos_token_id, but keep an explicit pad_token_id if one is present
        // (matches both null and undefined).
        if (this.config.pad_token_id == null) {
            this.config.pad_token_id = this.config.eos_token_id;
        }

        // Cached attention geometry, used when building past key/value caches.
        this.num_heads = this.config.num_attention_heads;
        this.num_layers = this.config.num_hidden_layers;
        this.dim_kv = this.config.hidden_size / this.num_heads;
    }
}
/**
 * The bare LLaMA Model outputting raw hidden-states without any specific head on top.
 */
export class LlamaModel extends LlamaPreTrainedModel {
    /**
     * Throws an error indicating that the current model class is not compatible with `.generate()`,
     * as it doesn't have a language model head.
     *
     * @throws {Error} The current model class is not compatible with `.generate()`
     *
     * @param {...any} args Arguments passed to the generate function
     * @returns {Promise<any>}
     */
    async generate(...args) {
        const message = "The current model class (LlamaModel) is not compatible with `.generate()`, as it doesn't have a language model head. Please use one of the following classes instead: {'LlamaForCausalLM'}";
        throw new Error(message);
    }
}
2930+
/**
 * LLaMA model with a language modeling head on top, for causal language modeling.
 */
export class LlamaForCausalLM extends LlamaPreTrainedModel {
    /**
     * Initializes and returns the beam for text generation task
     * @param {Tensor} inputTokenIds The input token ids.
     * @param {number} numOutputTokens The number of tokens to be generated.
     * @param {Tensor} inputs_attention_mask Optional input attention mask.
     * @returns {any} A Beam object representing the initialized beam.
     */
    getStartBeams(inputTokenIds, numOutputTokens, inputs_attention_mask) {
        return decoderStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask);
    }

    /**
     * Runs a single step of the beam search generation algorithm.
     * @param {any} beam The current beam being generated.
     * @returns {Promise<any>} The updated beam after a single generation step.
     */
    async runBeam(beam) {
        // Returning the promise directly is equivalent to `return await` here.
        return decoderRunBeam(this, beam);
    }

    /**
     * Updates the given beam with the new generated token id.
     * @param {any} beam The Beam object representing the beam.
     * @param {number} newTokenId The new generated token id to be added to the beam.
     */
    updateBeam(beam, newTokenId) {
        return decoderUpdatebeam(beam, newTokenId);
    }

    /**
     * Forward pass for the model.
     * @param {Object} model_inputs The inputs for the model.
     * @returns {Promise<any>} The output tensor of the model.
     */
    async forward(model_inputs) {
        return decoderForward(this, model_inputs);
    }
}
2972+ //////////////////////////////////////////////////
2973+
2974+
28872975//////////////////////////////////////////////////
28882976export class ViTPreTrainedModel extends PreTrainedModel { }
28892977export class ViTForImageClassification extends ViTPreTrainedModel {
@@ -3260,6 +3348,7 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
32603348 [ 'gpt_bigcode' , GPTBigCodeModel ] ,
32613349 [ 'gpt_neo' , GPTNeoModel ] ,
32623350 [ 'codegen' , CodeGenModel ] ,
3351+ [ 'llama' , LlamaModel ] ,
32633352] ) ;
32643353
32653354const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map ( [
@@ -3300,6 +3389,7 @@ const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([
33003389 [ 'gpt_bigcode' , GPTBigCodeForCausalLM ] ,
33013390 [ 'gpt_neo' , GPTNeoForCausalLM ] ,
33023391 [ 'codegen' , CodeGenForCausalLM ] ,
3392+ [ 'llama' , LlamaForCausalLM ] ,
33033393] ) ;
33043394
33053395const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map ( [
0 commit comments