diff --git a/src/llm_chat.ts b/src/llm_chat.ts index 5f8ecf00..2c7d62a6 100644 --- a/src/llm_chat.ts +++ b/src/llm_chat.ts @@ -50,6 +50,10 @@ export class LLMChatPipeline { private image_embed: tvmjs.PackedFunc | undefined; private embed: tvmjs.PackedFunc; private fapplyBitmask: tvmjs.PackedFunc; + private fapplyPenalty: tvmjs.PackedFunc; + private fapplyLogitBias: tvmjs.PackedFunc; + private fsoftmaxWithTemperature: tvmjs.PackedFunc; + // Functions related to PagedKVCache private fclearKVCaches: tvmjs.PackedFunc; private fKVCacheAddSequence: tvmjs.PackedFunc; @@ -190,6 +194,15 @@ export class LLMChatPipeline { this.fapplyBitmask = this.tvm.detachFromCurrentScope( this.vm.getFunction("apply_bitmask_inplace"), ); + this.fapplyPenalty = this.tvm.detachFromCurrentScope( + this.vm.getFunction("apply_penalty_inplace"), + ); + this.fapplyLogitBias = this.tvm.detachFromCurrentScope( + this.vm.getFunction("apply_logit_bias_inplace"), + ); + this.fsoftmaxWithTemperature = this.tvm.detachFromCurrentScope( + this.vm.getFunction("softmax_with_temperature"), + ); try { this.image_embed = this.tvm.detachFromCurrentScope( this.vm.getFunction("image_embed"), @@ -1091,68 +1104,113 @@ export class LLMChatPipeline { if (this.logitProcessor !== undefined) { logitsOnCPUArray = this.logitProcessor.processLogits(logitsOnCPUArray); } + if (_hasValue(logit_bias)) { - for (const tokenID in logit_bias) { - const curBias = logit_bias[tokenID]; - const curTokenID = parseInt(tokenID); - if (curTokenID > vocab_size) { - throw Error( - "Token " + - curTokenID + - " in logit_bias exceeds vocab_size " + - vocab_size, - ); - } - logitsOnCPUArray[curTokenID] += curBias; + this.tvm.beginScope(); + const numTokens = Object.keys(logit_bias ?? {}).length; + const pos2seq_id = new Int32Array(numTokens).fill(0); + const tokenIds = new Int32Array(numTokens); + const tokenLogitBias = new Float32Array(numTokens); + + const logitBiasKeys = Object.keys(logit_bias ?? {}); + for (let index = 0; index < numTokens; index++) { + const tokenId = parseInt(logitBiasKeys[index]); + tokenIds[index] = tokenId; + tokenLogitBias[index] = logit_bias![tokenId]; } + + const pos2seqIdsArray = this.tvm + .empty([numTokens], "int32", this.device) + .copyFrom(pos2seq_id); + + const tokenIdsArray = this.tvm + .empty([numTokens], "int32", this.device) + .copyFrom(tokenIds); + + const tokenLogitBiasArray = this.tvm + .empty([numTokens], "float32", this.device) + .copyFrom(tokenLogitBias); + + const logitsOnGPU = this.tvm + .empty([1, this.fullVocabSize], "float32", this.device) + .copyFrom(logitsOnCPUArray); + + this.fapplyLogitBias( + logitsOnGPU, + pos2seqIdsArray, + tokenIdsArray, + tokenLogitBiasArray, + ); + this.updateLogitsOnCPU(logitsOnGPU); + this.tvm.endScope(); } - this.logitsOnCPU.copyFrom(logitsOnCPUArray); + await this.device.sync(); } // 3. Apply penalties to logits - if (_hasValue(frequency_penalty) && _hasValue(presence_penalty)) { - // 3.1. Use frequency and presence penalty + if ( + frequency_penalty != 0.0 || + presence_penalty != 0.0 || + repetition_penalty != 1.0 + ) { this.tvm.beginScope(); - // Both `keys()` and `values()` are in insertion order. const appearedTokens = [...this.appearedTokensFreq.keys()]; const appearedTokensFreqs = [...this.appearedTokensFreq.values()]; - const appeared_tokens_ndarray = this.tvm.empty( - [1, appearedTokens.length], - "int32", - this.tvm.cpu(), - ); - const appeared_tokens_freqs_ndarray = this.tvm.empty( - [1, appearedTokensFreqs.length], - "int32", - this.tvm.cpu(), - ); - appeared_tokens_ndarray.copyFrom(appearedTokens); - appeared_tokens_freqs_ndarray.copyFrom(appearedTokensFreqs); - this.tvm.applyPresenceAndFrequencyPenalty( - this.logitsOnCPU, - appeared_tokens_ndarray, - appeared_tokens_freqs_ndarray, - presence_penalty!, - frequency_penalty!, - ); - this.tvm.endScope(); - } else if (repetition_penalty != 1.0) { - // 3.2. Use repetition penalty - this.tvm.beginScope(); - const appearedTokens = [...this.appearedTokensFreq.keys()]; - const appeared_tokens_ndarray = this.tvm.empty( - [1, appearedTokens.length], - "int32", - this.tvm.cpu(), - ); - appeared_tokens_ndarray.copyFrom(appearedTokens); - this.tvm.applyRepetitionPenalty( - this.logitsOnCPU, - appeared_tokens_ndarray, + + const numTokens = appearedTokens.length; + + const seqIdsArray = this.tvm + .empty([1], "int32", this.device) + .copyFrom([0]); + + const pos2seq_id = new Int32Array(numTokens).fill(0); + const tokenIds = new Int32Array(numTokens).fill(0); + const tokenCnt = new Int32Array(numTokens).fill(0); + const penalties = new Float32Array([ + presence_penalty, + frequency_penalty, repetition_penalty, - ); + ]); + const paddedPenalties = new Float32Array(3); + paddedPenalties.set(penalties); + + tokenIds.set(appearedTokens); + tokenCnt.set(appearedTokensFreqs); + + const pos2seqIdsArray = this.tvm + .empty([numTokens], "int32", this.device) + .copyFrom(pos2seq_id); + + const tokenIdsArray = this.tvm + .empty([numTokens], "int32", this.device) + .copyFrom(tokenIds); + + const tokenCntArray = this.tvm + .empty([numTokens], "int32", this.device) + .copyFrom(tokenCnt); + + const penaltiesArray = this.tvm + .empty([1, 3], "float32", this.device) + .copyFrom(paddedPenalties); + + const logitsOnGPU = this.tvm + .empty([1, this.fullVocabSize], "float32", this.device) + .copyFrom(this.logitsOnCPU.toArray()); + + if (numTokens > 0) { + this.fapplyPenalty( + logitsOnGPU, + seqIdsArray, + pos2seqIdsArray, + tokenIdsArray, + tokenCntArray, + penaltiesArray, + ); + } + this.updateLogitsOnCPU(logitsOnGPU); this.tvm.endScope(); } + await this.device.sync(); // 4. Sample token from logits // If logprobs, need the actual distribution via softmax, otherwise directly sample from logits @@ -1160,11 +1218,41 @@ export class LLMChatPipeline { if (logprobs) { // Inplace transform logitsOnCPU to a distribution temperature = Math.max(1e-6, temperature); // to prevent division by zero - this.tvm.applySoftmaxWithTemperature(this.logitsOnCPU, temperature); - sampledToken = this.tvm.sampleTopPFromProb(this.logitsOnCPU, top_p); - this.tokenLogprobArray.push( - this.getTokenLogprob(sampledToken, top_logprobs!), - ); + + const numSeqs = 1; + const numTokens = this.appearedTokensFreq.size; + + if (numTokens > 0) { + const temperatures = new Float32Array([temperature]); + + this.tvm.beginScope(); + const temperaturesArray = this.tvm + .empty([numSeqs], "float32", this.device) + .copyFrom(temperatures); + + const logitsOnGPU = this.tvm + .empty([numSeqs, 1, this.fullVocabSize], "float32", this.device) + .copyFrom(this.logitsOnCPU.toArray()); + + const probs = this.fsoftmaxWithTemperature( + logitsOnGPU, + temperaturesArray, + ); + this.updateLogitsOnCPU(probs); + this.tvm.endScope(); + await this.device.sync(); + + sampledToken = this.tvm.sampleTopPFromProb(this.logitsOnCPU, top_p); + this.tokenLogprobArray.push( + this.getTokenLogprob(sampledToken, top_logprobs!), + ); + } else { + this.tvm.applySoftmaxWithTemperature(this.logitsOnCPU, temperature); + sampledToken = this.tvm.sampleTopPFromProb(this.logitsOnCPU, top_p); + this.tokenLogprobArray.push( + this.getTokenLogprob(sampledToken, top_logprobs!), + ); + } } else { // temperature being 0 is allowed here, equivalent to argmax sampledToken = this.tvm.sampleTopPFromLogits(