diff --git a/src/engine.ts b/src/engine.ts index 47bd7b64..ff956332 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -455,6 +455,7 @@ export class MLCEngine implements MLCEngineInterface { } await this.decode(pipeline, genConfig); } + await pipeline.flushDeferredTokens?.(genConfig); return pipeline.getMessage(); } @@ -617,6 +618,9 @@ export class MLCEngine implements MLCEngineInterface { } } + // Flush any remaining deferred tokens + await pipeline.flushDeferredTokens?.(genConfig); + // Reset seed -- we do not want this seed to affect future requests if (request.seed !== null && request.seed !== undefined) { pipeline.setSeed(Date.now()); diff --git a/src/llm_chat.ts b/src/llm_chat.ts index 70941449..c3aeeeb2 100644 --- a/src/llm_chat.ts +++ b/src/llm_chat.ts @@ -35,6 +35,124 @@ import { type ImageURL = ChatCompletionContentPartImage.ImageURL; +/** Check if a value is defined and non-null (avoids falsy 0 issue). */ +function _hasValue(value: any): boolean { + return value !== undefined && value !== null; +} + +interface SamplingParams { + temperature: number; + top_p: number; + repetition_penalty: number; + frequency_penalty: number; + presence_penalty: number; + logit_bias: Record | undefined; +} + +/** + * Manages GPU-resident sampled tokens to avoid per-token GPU→CPU sync. + * Accumulates sampled tokens on GPU and flushes to CPU in batches. + */ +class DeferredSampler { + private tvm: tvmjs.Instance; + private device: tvmjs.DLDevice; + /** The last sampled token as a GPU tensor (shape [1], int32). */ + lastTokenDevice: tvmjs.Tensor | null = null; + /** Host-side readback slots, one per deferred step. */ + private hostSlots: tvmjs.Tensor[]; + /** Number of tokens accumulated since last flush. */ + private pendingCount = 0; + /** How many tokens to accumulate before flushing. 
*/ + readonly submitInterval: number; + + constructor(tvm: tvmjs.Instance, device: tvmjs.DLDevice, submitInterval = 4) { + this.tvm = tvm; + this.device = device; + this.submitInterval = submitInterval; + // Pre-allocate host readback slots + this.hostSlots = []; + for (let i = 0; i < submitInterval; i++) { + this.hostSlots.push( + tvm.detachFromCurrentScope(tvm.empty([1], "int32", tvm.cpu())), + ); + } + } + + /** + * Store a newly sampled GPU token. Queues async GPU→CPU copy into the next slot. + * Returns true if a flush (sync + process) is needed. + */ + push(sampledTokenDevice: tvmjs.Tensor): boolean { + // Dispose previous GPU token if different + if ( + this.lastTokenDevice !== null && + this.lastTokenDevice !== sampledTokenDevice + ) { + this.lastTokenDevice.dispose(); + } + this.lastTokenDevice = sampledTokenDevice; + // Queue async copy into the current slot (no sync) + this.hostSlots[this.pendingCount].copyFrom(sampledTokenDevice); + this.pendingCount++; + return this.pendingCount >= this.submitInterval; + } + + /** + * After device.sync(), read all pending tokens from host slots. + * Returns array of token IDs and resets pending count. + */ + flush(): number[] { + const tokens: number[] = []; + for (let i = 0; i < this.pendingCount; i++) { + tokens.push((this.hostSlots[i].toArray() as Int32Array)[0]); + } + this.pendingCount = 0; + this.lastTokenDevice = null; + return tokens; + } + + /** Check if deferred mode should be used for the given generation config. 
*/ + static canDefer( + genConfig: GenerationConfig | undefined, + logitProcessor: LogitProcessor | undefined, + ): boolean { + if (logitProcessor !== undefined) return false; + if (genConfig?.logprobs) return false; + if ( + genConfig?.response_format?.type === "json_object" || + genConfig?.response_format?.type === "grammar" || + genConfig?.response_format?.type === "structural_tag" + ) { + return false; + } + return true; + } + + get hasPendingToken(): boolean { + return this.lastTokenDevice !== null; + } + + get pendingTokenCount(): number { + return this.pendingCount; + } + + reset(): void { + if (this.lastTokenDevice !== null) { + this.lastTokenDevice.dispose(); + this.lastTokenDevice = null; + } + this.pendingCount = 0; + } + + dispose(): void { + this.reset(); + for (const slot of this.hostSlots) { + slot.dispose(); + } + this.hostSlots = []; + } +} + export class LLMChatPipeline { private config: ChatConfig; private tokenizer: Tokenizer; @@ -85,6 +203,19 @@ export class LLMChatPipeline { private finishReason: ChatCompletionFinishReason | undefined = undefined; // frequency of appeared token ids till now (refresh after PrefillStep); token_id mapped to freq private appearedTokensFreq = new Map(); + // Pre-allocated parallel arrays mirroring appearedTokensFreq to avoid per-token spread copies + private penaltyTokenIds: Int32Array = new Int32Array(64); + private penaltyTokenCnts: Int32Array = new Int32Array(64); + private penaltyNumTokens = 0; + private penaltyArraysDirty = true; // rebuild arrays only when a new unique token appears + // Pre-allocated sampling tensors reused every decode step + private temperaturesDevice!: tvmjs.Tensor; + private sampledTokensHost!: tvmjs.Tensor; + private topPHost: Float32Array = new Float32Array(1); + private penaltyPos2seqIds: Int32Array = new Int32Array(64); + private penaltySeqIdsDevice!: tvmjs.Tensor; + private penaltiesDevice!: tvmjs.Tensor; + private penaltiesHost: Float32Array = new Float32Array(3); private 
conversation: Conversation; // The logprob information of all tokens for this current round (cleared upon each prefillStep) // Cleared & updated at the exact same spots as `outputMessage`. Only updated when @@ -142,6 +273,9 @@ export class LLMChatPipeline { private curRoundGrammarInitTotalTime = 0; // Total time of getting next bitmask and accepting token in seconds private curRoundGrammarPerTokenTotalTime = 0; + // Deferred GPU sampling + private deferredSampler: DeferredSampler; + // Instance variables for supporting sampling on WebGPU private sampleIndices: Int32Array; private sampleIndicesDevice: tvmjs.Tensor; @@ -195,6 +329,7 @@ export class LLMChatPipeline { // 1. Create VM and get the core functions tvm.beginScope(); + this.deferredSampler = new DeferredSampler(this.tvm, this.device); this.vm = this.tvm.detachFromCurrentScope( this.tvm.createVirtualMachine(this.device), ); @@ -339,11 +474,30 @@ export class LLMChatPipeline { this.tvm.empty([numProbs], "float32", this.device), ); + // Pre-allocate sampling tensors reused every decode step + this.temperaturesDevice = this.tvm.detachFromCurrentScope( + this.tvm.empty([numSamples], "float32", this.device), + ); + this.sampledTokensHost = this.tvm.detachFromCurrentScope( + this.tvm.empty([numSamples], "int32", this.tvm.cpu()), + ); + this.penaltySeqIdsDevice = this.tvm.detachFromCurrentScope( + this.tvm.empty([1], "int32", this.device).copyFrom([0]), + ); + this.penaltiesDevice = this.tvm.detachFromCurrentScope( + this.tvm.empty([1, 3], "float32", this.device), + ); + tvm.endScope(); } dispose() { // TODO: Do we need to dispose all PackedFuncs here? 
+ this.deferredSampler?.dispose(); + this.temperaturesDevice?.dispose(); + this.sampledTokensHost?.dispose(); + this.penaltySeqIdsDevice?.dispose(); + this.penaltiesDevice?.dispose(); this.grammarMatcher?.dispose(); this.params.dispose(); this.decoding.dispose(); @@ -386,8 +540,11 @@ export class LLMChatPipeline { if (!keepStats) { this.resetRuntimeStats(); } + this.deferredSampler?.reset(); this.resetKVCache(); this.filledKVCacheLength = 0; + this.penaltyNumTokens = 0; + this.penaltyArraysDirty = true; this.logitProcessor?.resetState(); this.tvm.endScope(); } @@ -587,8 +744,11 @@ export class LLMChatPipeline { const tstart = performance.now(); // cleanup the per convo states + this.deferredSampler?.reset(); this.outputIds = []; this.appearedTokensFreq.clear(); + this.penaltyNumTokens = 0; + this.penaltyArraysDirty = true; this.outputMessage = ""; this.tokenLogprobArray = []; this.curRoundDecodingTotalTokens = 0; @@ -767,33 +927,101 @@ export class LLMChatPipeline { const tstart = performance.now(); - this.tvm.beginScope(); - const chunk: Array> = [ - this.outputIds.slice(this.outputIds.length - 1), - ]; - const chunkLen = chunk.length; - const prevFilledLen = this.filledKVCacheLength; - const logits = this.tvm.detachFromCurrentScope( - await this.embedAndForward(chunk, chunkLen), - ); - if (this.filledKVCacheLength !== prevFilledLen + chunkLen) { - throw new Error( - "Internal Error: filledKVCacheLength does not match expected value.", + // Determine if we can use the deferred (GPU-resident token) path + const canDefer = DeferredSampler.canDefer(genConfig, this.logitProcessor); + const useDeferred = canDefer && this.deferredSampler.hasPendingToken; + + let logits: tvmjs.Tensor; + if (useDeferred) { + // Fast path: use the GPU-resident token directly, skip CPU round-trip + this.tvm.beginScope(); + const prevFilledLen = this.filledKVCacheLength; + logits = this.tvm.detachFromCurrentScope( + await this.embedGPUTokenAndForward( + this.deferredSampler.lastTokenDevice!, 
+ ), ); + if (this.filledKVCacheLength !== prevFilledLen + 1) { + throw new Error( + "Internal Error: filledKVCacheLength does not match expected value.", + ); + } + this.tvm.endScope(); + } else { + // Standard path: read token from outputIds on CPU + if (this.deferredSampler.hasPendingToken) { + // We had a deferred token but can't defer anymore — flush it + await this.flushDeferredTokens(genConfig); + } + this.tvm.beginScope(); + const chunk: Array> = [ + this.outputIds.slice(this.outputIds.length - 1), + ]; + const chunkLen = chunk.length; + const prevFilledLen = this.filledKVCacheLength; + logits = this.tvm.detachFromCurrentScope( + await this.embedAndForward(chunk, chunkLen), + ); + if (this.filledKVCacheLength !== prevFilledLen + chunkLen) { + throw new Error( + "Internal Error: filledKVCacheLength does not match expected value.", + ); + } + this.tvm.endScope(); } - this.tvm.endScope(); - // sample from logits - const nextToken = await this.sampleTokenFromLogits(logits, genConfig); - logits.dispose(); - const tend = performance.now(); + // Sample from logits + if (canDefer) { + // Deferred path: sample and keep token on GPU, skip sync + const sampledTokenDevice = await this.sampleTokenFromLogitsDeferred( + logits, + genConfig, + ); + logits.dispose(); - this.decodingTotalTime += (tend - tstart) / 1e3; - this.decodingTotalTokens += 1; - this.curRoundDecodingTotalTokens += 1; - this.curRoundDecodingTotalTime += (tend - tstart) / 1e3; + const needsFlush = this.deferredSampler.push(sampledTokenDevice); + const tend = performance.now(); + this.decodingTotalTime += (tend - tstart) / 1e3; + this.decodingTotalTokens += 1; + this.curRoundDecodingTotalTokens += 1; + this.curRoundDecodingTotalTime += (tend - tstart) / 1e3; - this.processNextToken(nextToken, genConfig); + if (needsFlush) { + await this.flushDeferredTokens(genConfig); + } + } else { + // Standard path: sample with full sync + const nextToken = await this.sampleTokenFromLogits(logits, genConfig); + 
logits.dispose();
+      const tend = performance.now();
+
+      this.decodingTotalTime += (tend - tstart) / 1e3;
+      this.decodingTotalTokens += 1;
+      this.curRoundDecodingTotalTokens += 1;
+      this.curRoundDecodingTotalTime += (tend - tstart) / 1e3;
+
+      this.processNextToken(nextToken, genConfig);
+    }
+  }
+
+  /**
+   * Flush deferred tokens: sync GPU, read back all accumulated tokens, and process them.
+   * Called every `submitInterval` tokens or when switching out of deferred mode.
+   */
+  async flushDeferredTokens(genConfig?: GenerationConfig): Promise<void> {
+    if (this.deferredSampler.pendingTokenCount === 0) return;
+
+    // One sync completes all pending GPU→CPU copies
+    await this.device.sync();
+
+    // Read all accumulated tokens from host slots
+    const tokens = this.deferredSampler.flush();
+
+    // Process each token (updates outputIds, outputMessage, stop checks)
+    for (const token of tokens) {
+      if (this.stopTriggered) break;
+      this.processNextToken(token, genConfig);
+    }
   }
 
   /**
@@ -863,12 +1091,22 @@
     }
     if (!this.stopTriggered) {
       this.outputIds.push(nextToken);
-      // Update token appearance frequency
+      // Update token appearance frequency and parallel penalty arrays
      const curFreq = this.appearedTokensFreq.get(nextToken);
      if (curFreq !== undefined) {
        this.appearedTokensFreq.set(nextToken, curFreq + 1);
+        // Repeated token — update count in-place, no rebuild needed
+        // Find index and update (linear scan is fine, unique token count is small)
+        for (let i = 0; i < this.penaltyNumTokens; i++) {
+          if (this.penaltyTokenIds[i] === nextToken) {
+            this.penaltyTokenCnts[i] = curFreq + 1;
+            break;
+          }
+        }
      } else {
        this.appearedTokensFreq.set(nextToken, 1);
+        // New unique token — mark dirty so arrays are rebuilt next penalty pass
+        this.penaltyArraysDirty = true;
      }
    }
 
@@ -1097,6 +1335,30 @@
     return logits;
   }
 
+  /**
+   * Embed a GPU-resident token tensor and forward through the decoder.
+   * Used by deferred sampling to avoid GPU→CPU→GPU round-trip.
+   */
+  private async embedGPUTokenAndForward(
+    tokenDevice: tvmjs.Tensor,
+  ): Promise<tvmjs.Tensor> {
+    this.tvm.beginScope();
+    const embedding = this.embed!(tokenDevice, this.params);
+    const allEmbeddings = embedding.view([1].concat(embedding.shape));
+
+    const inputLenShape = this.tvm.makeShapeTuple([1]);
+    const seqIdsTuple = this.tvm.makeShapeTuple([0]);
+    this.fKVCacheBeginForward!(this.kvCache, seqIdsTuple, inputLenShape);
+    const retValue = this.decoding(allEmbeddings, this.kvCache, this.params);
+    this.fKVCacheEndForward!(this.kvCache);
+    this.filledKVCacheLength += 1;
+
+    const logits = this.tvm.detachFromCurrentScope(retValue.get(0));
+    this.tvm.endScope();
+    this.tvm.attachToCurrentScope(logits);
+    return logits;
+  }
+
   // NOTE: caller must call device.sync()
   private updateLogitsOnCPU(logits: tvmjs.Tensor): tvmjs.Tensor {
     if (this.logitsOnCPU == undefined) {
@@ -1112,95 +1374,204 @@
     return this.logitsOnCPU;
   }
 
-  private async sampleTokenFromLogits(
-    logitsOnGPU: tvmjs.Tensor,
-    genConfig?: GenerationConfig,
-  ) {
-    // 0. Get value of temperature, top_p, and various penalties, possibly overridden by genConfig
-    // Also load other genConfig items like logit_bias. Consume all fields of `genConfig` here.
-    function _hasValue(value: any): boolean {
-      // if we use `if value` directly, `value` being 0 evaluates to false, violating semantics
-      return value !== undefined && value !== null;
-    }
+  /**
+   * Extract and validate sampling parameters from genConfig, falling back to this.config defaults.
+ */ + private extractSamplingParams(genConfig?: GenerationConfig): SamplingParams { let temperature: number = this.config.temperature; let top_p: number = this.config.top_p; let repetition_penalty: number = this.config.repetition_penalty; let frequency_penalty: number = this.config.frequency_penalty; let presence_penalty: number = this.config.presence_penalty; let logit_bias: Record | undefined = undefined; - let logprobs: boolean | undefined = undefined; - let top_logprobs: number | undefined = undefined; - let response_format: ResponseFormat | undefined = undefined; if (genConfig !== undefined) { - if (_hasValue(genConfig.temperature)) { + if (_hasValue(genConfig.temperature)) temperature = genConfig.temperature!; - } - if (_hasValue(genConfig.top_p)) { - top_p = genConfig.top_p!; - } - // TODO: setting top_p to 1.0 by default might run into issues since - // top_p masking in relax uses < instead of <= - // Set default top_p to 1.0 if not set - if (!_hasValue(top_p)) { - top_p = 1.0; - } - if (_hasValue(genConfig.repetition_penalty)) { + if (_hasValue(genConfig.top_p)) top_p = genConfig.top_p!; + if (!_hasValue(top_p)) top_p = 1.0; + if (_hasValue(genConfig.repetition_penalty)) repetition_penalty = genConfig.repetition_penalty!; - } - if (_hasValue(genConfig.frequency_penalty)) { + if (_hasValue(genConfig.frequency_penalty)) frequency_penalty = genConfig.frequency_penalty!; - } - if (_hasValue(genConfig.presence_penalty)) { + if (_hasValue(genConfig.presence_penalty)) presence_penalty = genConfig.presence_penalty!; - } - // If only one of frequency or presence penalty is set, make the other one 0.0 - if (_hasValue(frequency_penalty) && !_hasValue(presence_penalty)) { + if (_hasValue(frequency_penalty) && !_hasValue(presence_penalty)) presence_penalty = 0.0; - } - if (_hasValue(presence_penalty) && !_hasValue(frequency_penalty)) { - frequency_penalty = 0.0; - } - if (!_hasValue(frequency_penalty)) { + if (_hasValue(presence_penalty) && 
!_hasValue(frequency_penalty)) frequency_penalty = 0.0; - } - if (!_hasValue(presence_penalty)) { - presence_penalty = 0.0; - } - if (_hasValue(genConfig.logit_bias)) { - logit_bias = genConfig.logit_bias!; - } - if (_hasValue(genConfig.logprobs)) { - logprobs = genConfig.logprobs!; - } - if (_hasValue(genConfig.top_logprobs)) { - top_logprobs = genConfig.top_logprobs!; - } - if (_hasValue(genConfig.response_format)) { - response_format = genConfig.response_format!; - } - } - // Check range validity - if (top_p <= 0 || top_p > 1) { - throw new RangeError("top_p", 0, 1); - } - if (temperature < 0) { - throw new MinValueError("temperature", 0); + if (!_hasValue(frequency_penalty)) frequency_penalty = 0.0; + if (!_hasValue(presence_penalty)) presence_penalty = 0.0; + if (_hasValue(genConfig.logit_bias)) logit_bias = genConfig.logit_bias!; } - if (repetition_penalty <= 0) { + + // Validate ranges + if (top_p <= 0 || top_p > 1) throw new RangeError("top_p", 0, 1); + if (temperature < 0) throw new MinValueError("temperature", 0); + if (repetition_penalty <= 0) throw new MinValueError("repetition_penalty", 0); - } if ( frequency_penalty && (frequency_penalty < -2.0 || frequency_penalty > 2.0) - ) { + ) throw new RangeError("frequency_penalty", -2.0, 2.0); + if (presence_penalty && (presence_penalty < -2.0 || presence_penalty > 2.0)) + throw new RangeError("presence_penalty", -2.0, 2.0); + + return { + temperature, + top_p, + repetition_penalty, + frequency_penalty, + presence_penalty, + logit_bias, + }; + } + + /** + * Apply logit_bias and repetition/frequency/presence penalties to logits on GPU. 
+ */ + private applyPenaltiesAndLogitBias( + logitsOnGPU: tvmjs.Tensor, + params: SamplingParams, + genConfig?: GenerationConfig, + ): void { + const { + logit_bias, + frequency_penalty, + presence_penalty, + repetition_penalty, + } = params; + + // Apply logit_bias on GPU + if (_hasValue(logit_bias)) { + const logitBiasBegin = performance.now(); + + const numTokens = Object.keys(logit_bias ?? {}).length; + const pos2seqIds = new Int32Array(numTokens).fill(0); + const tokenIds = new Int32Array(numTokens); + const tokenLogitBias = new Float32Array(numTokens); + const logitBiasKeys = Object.keys(logit_bias ?? {}); + for (let index = 0; index < numTokens; index++) { + const tokenId = parseInt(logitBiasKeys[index]); + tokenIds[index] = tokenId; + tokenLogitBias[index] = logit_bias![tokenId]; + } + + this.tvm.beginScope(); + const pos2seqIdsDevice = this.tvm + .empty([numTokens], "int32", this.device) + .copyFrom(pos2seqIds); + const tokenIdsDevice = this.tvm + .empty([numTokens], "int32", this.device) + .copyFrom(tokenIds); + const tokenLogitBiasDevice = this.tvm + .empty([numTokens], "float32", this.device) + .copyFrom(tokenLogitBias); + this.fapplyLogitBias( + logitsOnGPU.view([1, this.fullVocabSize]), + pos2seqIdsDevice, + tokenIdsDevice, + tokenLogitBiasDevice, + ); + this.tvm.endScope(); + + if (genConfig?.enable_latency_breakdown) { + const logitBiasEnd = performance.now(); + this.curRoundLatencyBreakdown.logitBiasTime.push( + (logitBiasEnd - logitBiasBegin) / 1e3, + ); + } } + + // Apply penalties on GPU if ( - presence_penalty && - (presence_penalty < -2.0 || presence_penalty > 2.0) + frequency_penalty != 0.0 || + presence_penalty != 0.0 || + repetition_penalty != 1.0 ) { - throw new RangeError("presence_penalty", -2.0, 2.0); + if (this.penaltyArraysDirty) { + const size = this.appearedTokensFreq.size; + if (size > this.penaltyTokenIds.length) { + const newLen = Math.max(size, this.penaltyTokenIds.length * 2); + this.penaltyTokenIds = new Int32Array(newLen); + 
this.penaltyTokenCnts = new Int32Array(newLen); + } + let i = 0; + for (const [id, freq] of this.appearedTokensFreq) { + this.penaltyTokenIds[i] = id; + this.penaltyTokenCnts[i] = freq; + i++; + } + this.penaltyNumTokens = size; + this.penaltyArraysDirty = false; + } + + const numTokens = this.penaltyNumTokens; + if (numTokens > 0) { + const penaltyBegin = performance.now(); + + const tokenIds = this.penaltyTokenIds.subarray(0, numTokens); + const tokenCnt = this.penaltyTokenCnts.subarray(0, numTokens); + if (numTokens > this.penaltyPos2seqIds.length) { + this.penaltyPos2seqIds = new Int32Array( + Math.max(numTokens, this.penaltyPos2seqIds.length * 2), + ); + } + + this.penaltiesHost[0] = presence_penalty; + this.penaltiesHost[1] = frequency_penalty; + this.penaltiesHost[2] = repetition_penalty; + this.penaltiesDevice.copyFrom(this.penaltiesHost); + + this.tvm.beginScope(); + const pos2seqIdsDevice = this.tvm + .empty([numTokens], "int32", this.device) + .copyFrom(this.penaltyPos2seqIds.subarray(0, numTokens)); + const tokenIdsDevice = this.tvm + .empty([numTokens], "int32", this.device) + .copyFrom(tokenIds); + const tokenCntDevice = this.tvm + .empty([numTokens], "int32", this.device) + .copyFrom(tokenCnt); + this.fapplyPenalty( + logitsOnGPU.view([1, this.fullVocabSize]), + this.penaltySeqIdsDevice, + pos2seqIdsDevice, + tokenIdsDevice, + tokenCntDevice, + this.penaltiesDevice, + ); + this.tvm.endScope(); + + if (genConfig?.enable_latency_breakdown) { + const penaltyEnd = performance.now(); + this.curRoundLatencyBreakdown.penaltyTime.push( + (penaltyEnd - penaltyBegin) / 1e3, + ); + } + } + } + } + + private async sampleTokenFromLogits( + logitsOnGPU: tvmjs.Tensor, + genConfig?: GenerationConfig, + ) { + // 0. 
Extract and validate sampling parameters + const params = this.extractSamplingParams(genConfig); + let { temperature } = params; + const { top_p } = params; + + let logprobs: boolean | undefined = undefined; + let top_logprobs: number | undefined = undefined; + let response_format: ResponseFormat | undefined = undefined; + if (genConfig !== undefined) { + if (_hasValue(genConfig.logprobs)) logprobs = genConfig.logprobs!; + if (_hasValue(genConfig.top_logprobs)) + top_logprobs = genConfig.top_logprobs!; + if (_hasValue(genConfig.response_format)) + response_format = genConfig.response_format!; } const outputTokenBegin = performance.now(); @@ -1283,120 +1654,9 @@ export class LLMChatPipeline { } } - // 2. Apply logit_bias on GPU - if (_hasValue(logit_bias)) { - const logitBiasBegin = performance.now(); - - const numTokens = Object.keys(logit_bias ?? {}).length; - const pos2seqIds = new Int32Array(numTokens).fill(0); - const tokenIds = new Int32Array(numTokens); - const tokenLogitBias = new Float32Array(numTokens); - - const logitBiasKeys = Object.keys(logit_bias ?? 
{}); - for (let index = 0; index < numTokens; index++) { - const tokenId = parseInt(logitBiasKeys[index]); - tokenIds[index] = tokenId; - tokenLogitBias[index] = logit_bias![tokenId]; - } - - this.tvm.beginScope(); - - const pos2seqIdsDevice = this.tvm - .empty([numTokens], "int32", this.device) - .copyFrom(pos2seqIds); - - const tokenIdsDevice = this.tvm - .empty([numTokens], "int32", this.device) - .copyFrom(tokenIds); - - const tokenLogitBiasDevice = this.tvm - .empty([numTokens], "float32", this.device) - .copyFrom(tokenLogitBias); - - this.fapplyLogitBias( - logitsOnGPU.view([1, this.fullVocabSize]), - pos2seqIdsDevice, - tokenIdsDevice, - tokenLogitBiasDevice, - ); - - this.tvm.endScope(); - - if (genConfig?.enable_latency_breakdown) { - const logitBiasEnd = performance.now(); - const logitBiasTimeSpent = (logitBiasEnd - logitBiasBegin) / 1e3; - this.curRoundLatencyBreakdown.logitBiasTime.push(logitBiasTimeSpent); - } - } - - // 3. Apply penalties to logits on GPU - if ( - frequency_penalty != 0.0 || - presence_penalty != 0.0 || - repetition_penalty != 1.0 - ) { - const appearedTokens = [...this.appearedTokensFreq.keys()]; - const appearedTokensFreqs = [...this.appearedTokensFreq.values()]; - - const numTokens = appearedTokens.length; - - if (numTokens > 0) { - const penaltyBegin = performance.now(); - - const pos2seqIds = new Int32Array(numTokens).fill(0); - const tokenIds = new Int32Array(numTokens).fill(0); - const tokenCnt = new Int32Array(numTokens).fill(0); - const penalties = new Float32Array([ - presence_penalty, - frequency_penalty, - repetition_penalty, - ]); + // 2-3. 
Apply logit_bias and penalties on GPU + this.applyPenaltiesAndLogitBias(logitsOnGPU, params, genConfig); - tokenIds.set(appearedTokens); - tokenCnt.set(appearedTokensFreqs); - - this.tvm.beginScope(); - const seqIdsArray = this.tvm - .empty([1], "int32", this.device) - .copyFrom([0]); - - const pos2seqIdsDevice = this.tvm - .empty([numTokens], "int32", this.device) - .copyFrom(pos2seqIds); - - const tokenIdsDevice = this.tvm - .empty([numTokens], "int32", this.device) - .copyFrom(tokenIds); - - const tokenCntDevice = this.tvm - .empty([numTokens], "int32", this.device) - .copyFrom(tokenCnt); - - const penaltiesDevice = this.tvm - .empty([1, 3], "float32", this.device) - .copyFrom(penalties); - - this.fapplyPenalty( - logitsOnGPU.view([1, this.fullVocabSize]), - seqIdsArray, - pos2seqIdsDevice, - tokenIdsDevice, - tokenCntDevice, - penaltiesDevice, - ); - - this.tvm.endScope(); - - if (genConfig?.enable_latency_breakdown) { - const penaltyEnd = performance.now(); - const penaltyTimeSpent = (penaltyEnd - penaltyBegin) / 1e3; - this.curRoundLatencyBreakdown.penaltyTime.push(penaltyTimeSpent); - } - } - } - - // TODO: Explore usage of multinomial sampling kernel (currently blocked due to usage - // of i8) for cases where top_p is not set // 4. 
Sample token from logits const sampleBegin = performance.now(); @@ -1406,12 +1666,9 @@ export class LLMChatPipeline { const numSeqs = 1; const numProbs = 1; - const temperatures = new Float32Array([temperature]); - this.tvm.beginScope(); - const temperaturesDevice = this.tvm - .empty([numSeqs], "float32", this.device) - .copyFrom(temperatures); + this.temperaturesDevice.copyFrom(new Float32Array([temperature])); + const temperaturesDevice = this.temperaturesDevice; let probs = this.fsoftmaxWithTemperature( logitsOnGPU.view([numSeqs, numProbs, this.fullVocabSize]), @@ -1425,12 +1682,9 @@ export class LLMChatPipeline { const uniformSamplesDevice = this.tvm.uniform([1], 0.0, 1.0, this.device); - const topPHost = new Float32Array(numProbs).fill(-1); const topPValue = Math.max(top_p, 1e-5); - this.sampleIndices.forEach((row) => { - topPHost[row] = topPValue; - }); - this.topPDevice.copyFrom(topPHost); + this.topPHost[0] = topPValue; + this.topPDevice.copyFrom(this.topPHost); const sampledTokensDevice = this.tvm.detachFromCurrentScope( this.fsampleWithTopP( @@ -1441,11 +1695,8 @@ export class LLMChatPipeline { this.topPDevice, ), ); - const sampledTokensHost = this.tvm.detachFromCurrentScope( - this.tvm - .empty([numSeqs], "int32", this.tvm.cpu()) - .copyFrom(sampledTokensDevice), - ); + this.sampledTokensHost.copyFrom(sampledTokensDevice); + const sampledTokensHost = this.sampledTokensHost; if (logprobs && top_logprobs! > 0) { this.updateLogitsOnCPU(probs); } @@ -1492,6 +1743,58 @@ export class LLMChatPipeline { return sampledToken; } + /** + * Deferred variant of sampleTokenFromLogits: performs all GPU-side sampling + * (penalties, softmax, argsort, top-p) but does NOT sync or read back to CPU. + * Returns the sampled token as a GPU tensor (shape [1], int32). + * + * Precondition: no grammar, no logitProcessor, no logprobs (checked by DeferredSampler.canDefer). 
+   */
+  private async sampleTokenFromLogitsDeferred(
+    logitsOnGPU: tvmjs.Tensor,
+    genConfig?: GenerationConfig,
+  ): Promise<tvmjs.Tensor> {
+    const params = this.extractSamplingParams(genConfig);
+    let { temperature } = params;
+    const { top_p } = params;
+
+    // Apply logit_bias and penalties on GPU (no latency breakdown for deferred path)
+    this.applyPenaltiesAndLogitBias(logitsOnGPU, params);
+
+    // Sample token on GPU — no sync
+    temperature = Math.max(1e-6, temperature);
+    const numSeqs = 1;
+    const numProbs = 1;
+
+    this.tvm.beginScope();
+    this.temperaturesDevice.copyFrom(new Float32Array([temperature]));
+    let probs = this.fsoftmaxWithTemperature(
+      logitsOnGPU.view([numSeqs, numProbs, this.fullVocabSize]),
+      this.temperaturesDevice,
+    );
+    probs = probs.view([numProbs, this.fullVocabSize]);
+    const argsortResults = this.fargsortProbs(probs);
+    const sortedProbsDevice = argsortResults.get(0);
+    const sortedIndicesDevice = argsortResults.get(1);
+    const uniformSamplesDevice = this.tvm.uniform([1], 0.0, 1.0, this.device);
+    const topPValue = Math.max(top_p, 1e-5);
+    this.topPHost[0] = topPValue;
+    this.topPDevice.copyFrom(this.topPHost);
+    const sampledTokensDevice = this.tvm.detachFromCurrentScope(
+      this.fsampleWithTopP(
+        sortedProbsDevice,
+        sortedIndicesDevice,
+        uniformSamplesDevice,
+        this.sampleIndicesDevice,
+        this.topPDevice,
+      ),
+    );
+    this.tvm.endScope();
+
+    // Do NOT sync — return GPU tensor directly
+    return sampledTokensDevice;
+  }
+
   /**
    * Return the an array of a mixture of token arrays and imageURLs (which cannot be represented
    * as tokens). Also return the number of tokens this represents.