diff --git a/src/engine.ts b/src/engine.ts
index fe4c427e..eaf6e9b2 100644
--- a/src/engine.ts
+++ b/src/engine.ts
@@ -483,18 +483,15 @@ export class MLCEngine implements MLCEngineInterface {
     genConfig: GenerationConfig,
     timeReceived: number,
   ): AsyncGenerator<ChatCompletionChunk | Completion, void, void> {
-    // Since it is an async generator, we need to do fine-grained try-catch to ensure lock is
-    // released only when errors occur. Then release at the very end when no error occurs.
-    // TODO: This makes code less readable, is there a better way to do this?
     const lock = this.loadedModelIdToLock.get(model)!;
-
-    // 0. Pre-processing
-    const isChatCompletion = "messages" in request;
-    const isFunctionCalling =
-      "tools" in request &&
-      request.tools !== undefined &&
-      request.tools !== null;
     try {
+      // 0. Pre-processing
+      const isChatCompletion = "messages" in request;
+      const isFunctionCalling =
+        "tools" in request &&
+        request.tools !== undefined &&
+        request.tools !== null;
+
       if (isFunctionCalling && !isChatCompletion) {
         throw new Error(
           "Expect `chat.completions` with tools, not `completions`.",
@@ -504,130 +501,114 @@ export class MLCEngine implements MLCEngineInterface {
       if (request.seed !== null && request.seed !== undefined) {
         pipeline.setSeed(request.seed);
       }
-    } catch (err) {
-      await lock.release();
-      throw err;
-    }

-    // 1. Helper function that generates the chunk
-    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
-    const created = Date.now();
-    const id = crypto.randomUUID();
-    this.interruptSignal = false;
-    let prevMessageLength = 0; // to know where to start slicing the delta; does not count �
-
-    function _countTrailingReplacementChar(curMessage: string): number {
-      let cntr = 0;
-      for (let i = curMessage.length - 1; i >= 0; i--) {
-        if (curMessage.charAt(i) === "�") {
-          cntr += 1;
-        } else {
-          return cntr;
+      // 1. Helper function that generates the chunk
+      // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+      const created = Date.now();
+      const id = crypto.randomUUID();
+      this.interruptSignal = false;
+      let prevMessageLength = 0; // to know where to start slicing the delta; does not count �
+
+      const _countTrailingReplacementChar = (curMessage: string): number => {
+        let cntr = 0;
+        for (let i = curMessage.length - 1; i >= 0; i--) {
+          if (curMessage.charAt(i) === "�") {
+            cntr += 1;
+          } else {
+            return cntr;
+          }
         }
-      }
-      return cntr;
-    }
+        return cntr;
+      };

-    async function _getChunk(
-      selectedPipeline: LLMChatPipeline,
-    ): Promise<ChatCompletionChunk | Completion | undefined> {
-      // Remove the replacement character (U+FFFD) from the response to handle emojis.
-      // Each emoji is made up of multiples of 4 tokens; when truncated, it is displayed as �, so
-      // we skip this delta until a full emoji is rendered
-      // TODO(Charlie): This does not consider cases of � not being emoji, need to fix with Streamer
-      const curMessage = selectedPipeline.getMessage();
-      const numTrailingReplacementChar =
-        _countTrailingReplacementChar(curMessage);
-      if (numTrailingReplacementChar % 4 !== 0) {
-        return undefined;
-      }
+      const _getChunk = async (
+        selectedPipeline: LLMChatPipeline,
+      ): Promise<ChatCompletionChunk | Completion | undefined> => {
+        // Remove the replacement character (U+FFFD) from the response to handle emojis.
+        // Each emoji is made up of multiples of 4 tokens; when truncated, it is displayed as �, so
+        // we skip this delta until a full emoji is rendered
+        // TODO(Charlie): This does not consider cases of � not being emoji, need to fix with Streamer
+        const curMessage = selectedPipeline.getMessage();
+        const numTrailingReplacementChar =
+          _countTrailingReplacementChar(curMessage);
+        if (numTrailingReplacementChar % 4 !== 0) {
+          return undefined;
+        }

-      const deltaMessage = curMessage.slice(prevMessageLength);
-      prevMessageLength = curMessage.length;
-      const logprobs = request.logprobs
-        ? ({
-            content: selectedPipeline.getTokenLogprobArray().slice(-1), // always the last entry
-          } as ChatCompletionChunk.Choice.Logprobs)
-        : null;
-      if (isChatCompletion) {
-        const chunk: ChatCompletionChunk = {
-          id: id,
-          choices: [
-            {
-              delta: { content: deltaMessage, role: "assistant" },
-              finish_reason: null, // not finished yet
-              index: 0,
-              logprobs: logprobs,
-            },
-          ],
-          model: model,
-          object: "chat.completion.chunk",
-          created: created,
-        };
-        return chunk;
-      } else {
-        const chunk: Completion = {
-          id: id,
-          choices: [
-            {
-              text: deltaMessage,
-              finish_reason: null, // not finished yet
-              index: 0,
-              logprobs: logprobs,
-            },
-          ],
-          model: model,
-          object: "text_completion",
-          created: created,
-        };
-        return chunk;
-      }
-    }
+        const deltaMessage = curMessage.slice(prevMessageLength);
+        prevMessageLength = curMessage.length;
+        const logprobs = request.logprobs
+          ? ({
+              content: selectedPipeline.getTokenLogprobArray().slice(-1), // always the last entry
+            } as ChatCompletionChunk.Choice.Logprobs)
+          : null;
+        if (isChatCompletion) {
+          const chunk: ChatCompletionChunk = {
+            id: id,
+            choices: [
+              {
+                delta: { content: deltaMessage, role: "assistant" },
+                finish_reason: null, // not finished yet
+                index: 0,
+                logprobs: logprobs,
+              },
+            ],
+            model: model,
+            object: "chat.completion.chunk",
+            created: created,
+          };
+          return chunk;
+        } else {
+          const chunk: Completion = {
+            id: id,
+            choices: [
+              {
+                text: deltaMessage,
+                finish_reason: null, // not finished yet
+                index: 0,
+                logprobs: logprobs,
+              },
+            ],
+            model: model,
+            object: "text_completion",
+            created: created,
+          };
+          return chunk;
+        }
+      };

-    // 2. Auto-regressive loop
-    let curChunk;
-    try {
+      // 2. Auto-regressive loop
       await this.prefill(request, pipeline, chatConfig, genConfig);
-      curChunk = await _getChunk(pipeline); // prefill produces a chunk
-    } catch (err) {
-      await lock.release();
-      throw err;
-    }
-    if (curChunk) {
-      yield curChunk;
-    }
-
-    while (!pipeline.stopped()) {
-      if (this.interruptSignal) {
-        // TODO: should we directly release lock here and return the async
-        // generator? Though no issue observed as of now with interruptGenerate()
-        pipeline.triggerStop();
-        break;
+      let curChunk = await _getChunk(pipeline); // prefill produces a chunk
+      if (curChunk) {
+        yield curChunk;
       }
-      try {
+
+      while (!pipeline.stopped()) {
+        if (this.interruptSignal) {
+          // TODO: should we directly release lock here and return the async
+          // generator? Though no issue observed as of now with interruptGenerate()
+          pipeline.triggerStop();
+          break;
+        }
         await this.decode(pipeline, genConfig);
         curChunk = await _getChunk(pipeline);
-      } catch (err) {
-        await lock.release();
-        throw err;
-      }
-      if (curChunk) {
-        yield curChunk;
+        if (curChunk) {
+          yield curChunk;
+        }
       }
-    }

-    // Reset seed -- we do not want this seed to affect future requests
-    if (request.seed !== null && request.seed !== undefined) {
-      pipeline.setSeed(Date.now());
-    }
+      // Reset seed -- we do not want this seed to affect future requests
+      if (request.seed !== null && request.seed !== undefined) {
+        pipeline.setSeed(Date.now());
+      }

-    // 3. Last chunk empty marking the end
-    // If function calling, use the last chunk to return tool_calls
-    let finish_reason = pipeline.getFinishReason()!;
-    let tool_calls:
-      | Array<ChatCompletionChunk.Choice.Delta.ToolCall>
-      | undefined;
-    try {
+      // 3. Last chunk empty marking the end
+      // If function calling, use the last chunk to return tool_calls
+      let finish_reason = pipeline.getFinishReason()!;
+      let tool_calls:
+        | Array<ChatCompletionChunk.Choice.Delta.ToolCall>
+        | undefined;
       if (pipeline.getFinishReason() === "stop" && isFunctionCalling) {
         // If stopped due to length or abort, cannot output return tool_calls field
         finish_reason = "tool_calls";
@@ -637,108 +618,97 @@ export class MLCEngine implements MLCEngineInterface {
           /*isStreaming=*/ true,
         ) as Array<ChatCompletionChunk.Choice.Delta.ToolCall>;
       }
-    } catch (err) {
-      await lock.release();
-      throw err;
-    }
-
-    if (isChatCompletion) {
-      const lastChunk: ChatCompletionChunk = {
-        id: id,
-        choices: [
-          {
-            delta: isFunctionCalling
-              ? {
-                  role: "assistant",
-                  tool_calls: tool_calls,
-                }
-              : {},
-            finish_reason: finish_reason,
-            index: 0,
-          },
-        ],
-        model: model,
-        object: "chat.completion.chunk",
-        created: created,
-      };
-      yield lastChunk;
-    } else {
-      const lastChunk: Completion = {
-        id: id,
-        choices: [
-          {
-            text: "",
-            finish_reason: finish_reason,
-            index: 0,
-          },
-        ],
-        model: model,
-        object: "text_completion",
-        created: created,
-      };
-      yield lastChunk;
-    }
-    // 4. Usage chunk
-    if (request.stream_options?.include_usage) {
-      const usedGrammar =
-        "response_format" in request &&
-        (request.response_format?.type === "grammar" ||
-          request.response_format?.type === "json_object");
-      const completion_tokens = pipeline.getCurRoundDecodingTotalTokens();
-      const prompt_tokens = pipeline.getCurRoundPrefillTotalTokens();
-      const prefill_tokens_per_s = pipeline.getCurRoundPrefillTokensPerSec();
-      const decode_tokens_per_s = pipeline.getCurRoundDecodingTokensPerSec();
-      const grammar_init_s = pipeline.getCurRoundGrammarInitTotalTime();
-      const prefill_time = pipeline.getCurRoundPrefillTotalTime();
-      const decode_time = pipeline.getCurRoundDecodingTotalTime();
-      const grammar_per_token_s =
-        pipeline.getCurRoundGrammarPerTokenTotalTime();
-      const defaultExtra = {
-        e2e_latency_s: (Date.now() - timeReceived) / 1000,
-        prefill_tokens_per_s: prefill_tokens_per_s,
-        decode_tokens_per_s: decode_tokens_per_s,
-        time_to_first_token_s: prefill_time,
-        time_per_output_token_s: decode_time / completion_tokens,
-      };
-      const usage: CompletionUsage = {
-        completion_tokens: completion_tokens,
-        prompt_tokens: prompt_tokens,
-        total_tokens: completion_tokens + prompt_tokens,
-        extra: usedGrammar
-          ? {
-              ...defaultExtra,
-              ...{
-                grammar_init_s: grammar_init_s,
-                grammar_per_token_s: grammar_per_token_s / completion_tokens,
-              },
-            }
-          : defaultExtra,
-      };
       if (isChatCompletion) {
-        const usageChunk: ChatCompletionChunk = {
+        const lastChunk: ChatCompletionChunk = {
           id: id,
-          choices: [],
-          usage: usage,
+          choices: [
+            {
+              delta: isFunctionCalling
+                ? {
+                    role: "assistant",
+                    tool_calls: tool_calls,
+                  }
+                : {},
+              finish_reason: finish_reason,
+              index: 0,
+            },
+          ],
           model: model,
           object: "chat.completion.chunk",
           created: created,
         };
-        yield usageChunk;
+        yield lastChunk;
       } else {
-        const usageChunk: Completion = {
+        const lastChunk: Completion = {
           id: id,
-          choices: [],
-          usage: usage,
+          choices: [
+            {
+              text: "",
+              finish_reason: finish_reason,
+              index: 0,
+            },
+          ],
           model: model,
           object: "text_completion",
           created: created,
         };
-        yield usageChunk;
+        yield lastChunk;
       }
-    }
-    await lock.release();
+      // 4. Usage chunk
+      if (request.stream_options?.include_usage) {
+        const usedGrammar =
+          "response_format" in request &&
+          (request.response_format?.type === "grammar" ||
+            request.response_format?.type === "json_object");
+        const completion_tokens = pipeline.getCurRoundDecodingTotalTokens();
+        const prompt_tokens = pipeline.getCurRoundPrefillTotalTokens();
+        const prefill_tokens_per_s = pipeline.getCurRoundPrefillTokensPerSec();
+        const decode_tokens_per_s = pipeline.getCurRoundDecodingTokensPerSec();
+        const grammar_init_s = pipeline.getCurRoundGrammarInitTotalTime();
+        const prefill_time = pipeline.getCurRoundPrefillTotalTime();
+        const decode_time = pipeline.getCurRoundDecodingTotalTime();
+        const grammar_per_token_s =
+          pipeline.getCurRoundGrammarPerTokenTotalTime();
+        const defaultExtra = {
+          e2e_latency_s: (Date.now() - timeReceived) / 1000,
+          prefill_tokens_per_s: prefill_tokens_per_s,
+          decode_tokens_per_s: decode_tokens_per_s,
+          time_to_first_token_s: prefill_time,
+          time_per_output_token_s: decode_time / completion_tokens,
+        };
+        const usage: CompletionUsage = {
+          completion_tokens: completion_tokens,
+          prompt_tokens: prompt_tokens,
+          total_tokens: completion_tokens + prompt_tokens,
+          extra: usedGrammar
+            ? {
+                ...defaultExtra,
+                ...{
+                  grammar_init_s: grammar_init_s,
+                  grammar_per_token_s: grammar_per_token_s / completion_tokens,
+                },
+              }
+            : defaultExtra,
+        };
+        if (isChatCompletion) {
+          const usageChunk: ChatCompletionChunk = {
+            id: id,
+            choices: [],
+            usage: usage,
+            model: model,
+            object: "chat.completion.chunk",
+            created: created,
+          };
+          yield usageChunk;
+        } else {
+          // TODO(Charlie): support usage for completion
+        }
+      }
+    } finally {
+      await lock.release();
+    }
   }

   async interruptGenerate() {