
Commit 3c60cdb

[Refactor] Refactor for compatibility with TVM FFI updates (#730)
1 parent e4b4dc2 commit 3c60cdb

7 files changed: +397 −358 lines changed


package-lock.json

Lines changed: 352 additions & 313 deletions
Some generated files are not rendered by default.

package.json

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,9 @@
   "license": "Apache-2.0",
   "homepage": "https://github.com/mlc-ai/web-llm",
   "devDependencies": {
+    "@mlc-ai/web-runtime": "^0.23.0-dev0",
     "@mlc-ai/web-tokenizers": "^0.1.6",
+    "@mlc-ai/web-xgrammar": "0.1.0",
     "@next/eslint-plugin-next": "^14.2.3",
     "@rollup/plugin-commonjs": "^20.0.0",
     "@rollup/plugin-node-resolve": "^13.0.4",
@@ -50,8 +52,6 @@
     "rollup-plugin-typescript2": "^0.34.1",
     "ts-jest": "^29.1.2",
     "tslib": "^2.3.1",
-    "@mlc-ai/web-runtime": "0.18.0-dev2",
-    "@mlc-ai/web-xgrammar": "0.1.0",
     "typescript": "^4.9.5"
   },
   "dependencies": {

src/cache_util.ts

Lines changed: 3 additions & 3 deletions
@@ -29,7 +29,7 @@ export async function hasModelInCache(
   const modelRecord = findModelRecord(modelId, appConfig);
   const modelUrl = cleanModelUrl(modelRecord.model);
   const cacheType = appConfig.useIndexedDBCache ? "indexeddb" : "cache";
-  return tvmjs.hasNDArrayInCache(modelUrl, "webllm/model", cacheType);
+  return tvmjs.hasTensorInCache(modelUrl, "webllm/model", cacheType);
 }
 
 export async function deleteModelAllInfoInCache(
@@ -60,10 +60,10 @@ export async function deleteModelInCache(
   const modelUrl = cleanModelUrl(modelRecord.model);
   let modelCache: tvmjs.ArtifactCacheTemplate;
   if (appConfig.useIndexedDBCache) {
-    tvmjs.deleteNDArrayCache(modelUrl, "webllm/model", "indexeddb");
+    tvmjs.deleteTensorCache(modelUrl, "webllm/model", "indexeddb");
     modelCache = new tvmjs.ArtifactIndexedDBCache("webllm/model");
   } else {
-    tvmjs.deleteNDArrayCache(modelUrl, "webllm/model", "cache");
+    tvmjs.deleteTensorCache(modelUrl, "webllm/model", "cache");
     modelCache = new tvmjs.ArtifactCache("webllm/model");
   }
   await modelCache.deleteInCache(new URL("tokenizer.model", modelUrl).href);
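For downstream code that calls these cache helpers directly, the rename is mechanical. A minimal sketch using the renamed functions with the same three-argument form shown in this hunk; the wrapper function itself is hypothetical:

import * as tvmjs from "@mlc-ai/web-runtime";

// Sketch: hasNDArrayInCache -> hasTensorInCache, deleteNDArrayCache -> deleteTensorCache.
async function evictModelIfCached(
  modelUrl: string,
  useIndexedDBCache: boolean,
): Promise<void> {
  const cacheType = useIndexedDBCache ? "indexeddb" : "cache";
  if (await tvmjs.hasTensorInCache(modelUrl, "webllm/model", cacheType)) {
    // Mirrors the diff above, which does not await the delete call.
    tvmjs.deleteTensorCache(modelUrl, "webllm/model", cacheType);
  }
}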

src/embedding.ts

Lines changed: 2 additions & 2 deletions
@@ -204,7 +204,7 @@ export class EmbeddingPipeline {
     maskNDArray = maskNDArray.view([curBatchSize, maxInputSize]);
 
     // 3.5 Actual forwarding on GPU, logits of shape (curBatchSize, maxInputSize, hidden_size)
-    const logitsCurBatchOnGPU: tvmjs.NDArray = this.prefill(
+    const logitsCurBatchOnGPU: tvmjs.Tensor = this.prefill(
       inputNDArray,
       maskNDArray,
       this.params,
@@ -213,7 +213,7 @@ export class EmbeddingPipeline {
 
     // 3.6 Copy logits to CPU, flatten to curBatchSize * maxInputSize * hidden_size
     const hidden_size = logitsCurBatchOnGPU.shape[2];
-    let logitsCurBatchOnCPU: tvmjs.NDArray = this.tvm.empty(
+    let logitsCurBatchOnCPU: tvmjs.Tensor = this.tvm.empty(
       logitsCurBatchOnGPU.shape,
       logitsCurBatchOnGPU.dtype,
       this.tvm.cpu(),
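Only the type annotation changes here; the device-to-host copy pattern is unchanged. A minimal sketch of that pattern with the renamed tvmjs.Tensor type, assuming toArray() and the device sync() behave as in prior releases (the function and variable names are illustrative):

import * as tvmjs from "@mlc-ai/web-runtime";

// Illustrative sketch: allocate a CPU tensor, copy GPU logits into it, read back.
async function logitsToHost(
  tvm: tvmjs.Instance,
  onGPU: tvmjs.Tensor,
): Promise<Float32Array> {
  const onCPU: tvmjs.Tensor = tvm.empty(onGPU.shape, onGPU.dtype, tvm.cpu());
  onCPU.copyFrom(onGPU);      // enqueue the GPU -> CPU copy
  await tvm.webgpu().sync();  // wait until the copy has completed
  return onCPU.toArray() as Float32Array;
}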

src/engine.ts

Lines changed: 1 addition & 1 deletion
@@ -367,7 +367,7 @@ export class MLCEngine implements MLCEngineInterface {
       this.logger,
     );
     const cacheType = this.appConfig.useIndexedDBCache ? "indexeddb" : "cache";
-    await tvm.fetchNDArrayCache(
+    await tvm.fetchTensorCache(
       modelUrl,
       tvm.webgpu(),
       "webllm/model",

src/llm_chat.ts

Lines changed: 36 additions & 36 deletions
@@ -65,7 +65,7 @@ export class LLMChatPipeline {
   // parameter states
   private params: tvmjs.TVMObject;
   private kvCache: tvmjs.TVMObject;
-  private logitsOnCPU?: tvmjs.NDArray = undefined;
+  private logitsOnCPU?: tvmjs.Tensor = undefined;
   private filledKVCacheLength = 0;
 
   // meta data
@@ -224,7 +224,7 @@ export class LLMChatPipeline {
     // 2. Get json stored in the vm's metadata function
     const fgetMetadata = this.vm.getFunction("_metadata");
     const ret_value = fgetMetadata();
-    const metadataStr = this.tvm.detachFromCurrentScope(ret_value).toString();
+    const metadataStr = ret_value.toString();
     const metadata = JSON.parse(metadataStr);
 
     // 3. Load parameters by name
@@ -671,7 +671,7 @@ export class LLMChatPipeline {
 
     // 2. Prefill each chunk
     this.tvm.beginScope();
-    let logits: tvmjs.NDArray;
+    let logits: tvmjs.Tensor;
     for (let i = 0; i < chunks.length; i++) {
       const chunk = chunks[i];
       const chunkLen = chunkLens[i];
@@ -860,7 +860,7 @@ export class LLMChatPipeline {
    * @note precondition: inputTokens.length <= prefillChunkSize, since we take care of
    * chunking in `getChunkedPrefillInputData()`.
    */
-  private getTokensEmbeddings(inputTokens: number[]): tvmjs.NDArray {
+  private getTokensEmbeddings(inputTokens: number[]): tvmjs.Tensor {
     this.tvm.beginScope();
     if (inputTokens.length > this.prefillChunkSize) {
       throw new Error(
@@ -873,7 +873,7 @@ export class LLMChatPipeline {
       this.device,
     );
     inputData.copyFrom(inputTokens);
-    const embed: tvmjs.NDArray = this.tvm.detachFromCurrentScope(
+    const embed: tvmjs.Tensor = this.tvm.detachFromCurrentScope(
       this.embed!(inputData, this.params),
     );
     this.tvm.endScope();
@@ -886,9 +886,9 @@ export class LLMChatPipeline {
    */
   private async getImageEmbeddings(
     inputImage: ImageURL,
-  ): Promise<tvmjs.NDArray> {
+  ): Promise<tvmjs.Tensor> {
     this.tvm.beginScope();
-    // 1. Transform ImageURL into image input in NDArray
+    // 1. Transform ImageURL into image input in TVMArray
     const url = inputImage.url;
     // url starting with `data:image` and `http` share the same loading method
     const imgData: ImageData = await getImageDataFromURL(url);
@@ -900,7 +900,7 @@ export class LLMChatPipeline {
       .view([1, imgData.height, imgData.width, 3]); // NHWC
 
     // 2. Call image embed kernel
-    const embed: tvmjs.NDArray = this.tvm.detachFromCurrentScope(
+    const embed: tvmjs.Tensor = this.tvm.detachFromCurrentScope(
       this.image_embed!(pixelArray, this.params),
     );
     if (embed.shape[0] !== IMAGE_EMBED_SIZE) {
@@ -920,14 +920,14 @@ export class LLMChatPipeline {
    *
    * @param inputData data to embed and forward
    * @param inputDataLen length of this inputData, should smaller than prefill chunk size.
-   * @returns The logits returned by this forward as tvmjs.NDArray on GPU.
+   * @returns The logits returned by this forward as tvmjs.Tensor on GPU.
    *
    * @note Precondition: inputData's data length is smaller than prefill chunk size
    */
   private async embedAndForward(
     inputData: Array<Array<number> | ImageURL>,
     inputDataLen: number,
-  ): Promise<tvmjs.NDArray> {
+  ): Promise<tvmjs.Tensor> {
     if (inputDataLen > this.prefillChunkSize) {
       throw new Error(
         "InternalError: expect inputDataLen <= this.prefillChunkSize.",
@@ -938,18 +938,18 @@ export class LLMChatPipeline {
 
     // 1. Embed all inputData
     this.tvm.beginScope();
-    const embeddings: tvmjs.NDArray[] = [];
+    const embeddings: tvmjs.Tensor[] = [];
     for (let i = 0; i < inputData.length; i++) {
       const data = inputData[i];
       if (Array.isArray(data)) {
-        embeddings.push(this.getTokensEmbeddings(data));
+        embeddings.push(await this.getTokensEmbeddings(data));
       } else {
         embeddings.push(await this.getImageEmbeddings(data));
       }
     }
 
     // 2. Concatenate embeddings
-    let allEmbeddings: tvmjs.NDArray;
+    let allEmbeddings: tvmjs.Tensor;
     if (embeddings.length === 1) {
       allEmbeddings = embeddings[0];
     } else {
@@ -983,7 +983,7 @@ export class LLMChatPipeline {
   }
 
   // NOTE: caller must call device.sync()
-  private updateLogitsOnCPU(logits: tvmjs.NDArray): tvmjs.NDArray {
+  private updateLogitsOnCPU(logits: tvmjs.Tensor): tvmjs.Tensor {
     if (this.logitsOnCPU == undefined) {
       this.logitsOnCPU = this.tvm.detachFromCurrentScope(
         this.tvm.empty(logits.shape, logits.dtype, this.tvm.cpu()),
@@ -998,7 +998,7 @@ export class LLMChatPipeline {
   }
 
   private async sampleTokenFromLogits(
-    logitsOnGPU: tvmjs.NDArray,
+    logitsOnGPU: tvmjs.Tensor,
     genConfig?: GenerationConfig,
   ) {
     // 0. Get value of temperature, top_p, and various penalties, possibly overridden by genConfig
@@ -1160,7 +1160,7 @@ export class LLMChatPipeline {
       const logitBiasBegin = performance.now();
 
       const numTokens = Object.keys(logit_bias ?? {}).length;
-      const pos2seq_id = new Int32Array(numTokens).fill(0);
+      const pos2seqIds = new Int32Array(numTokens).fill(0);
       const tokenIds = new Int32Array(numTokens);
       const tokenLogitBias = new Float32Array(numTokens);
 
@@ -1173,23 +1173,23 @@ export class LLMChatPipeline {
 
       this.tvm.beginScope();
 
-      const pos2seqIdsArray = this.tvm
+      const pos2seqIdsDevice = this.tvm
         .empty([numTokens], "int32", this.device)
-        .copyFrom(pos2seq_id);
+        .copyFrom(pos2seqIds);
 
-      const tokenIdsArray = this.tvm
+      const tokenIdsDevice = this.tvm
         .empty([numTokens], "int32", this.device)
         .copyFrom(tokenIds);
 
-      const tokenLogitBiasArray = this.tvm
+      const tokenLogitBiasDevice = this.tvm
         .empty([numTokens], "float32", this.device)
         .copyFrom(tokenLogitBias);
 
       this.fapplyLogitBias(
         logitsOnGPU.view([1, this.fullVocabSize]),
-        pos2seqIdsArray,
-        tokenIdsArray,
-        tokenLogitBiasArray,
+        pos2seqIdsDevice,
+        tokenIdsDevice,
+        tokenLogitBiasDevice,
       );
 
       this.tvm.endScope();
@@ -1215,7 +1215,7 @@ export class LLMChatPipeline {
     if (numTokens > 0) {
       const penaltyBegin = performance.now();
 
-      const pos2seq_id = new Int32Array(numTokens).fill(0);
+      const pos2seqIds = new Int32Array(numTokens).fill(0);
       const tokenIds = new Int32Array(numTokens).fill(0);
       const tokenCnt = new Int32Array(numTokens).fill(0);
       const penalties = new Float32Array([
@@ -1232,29 +1232,29 @@ export class LLMChatPipeline {
         .empty([1], "int32", this.device)
         .copyFrom([0]);
 
-      const pos2seqIdsArray = this.tvm
+      const pos2seqIdsDevice = this.tvm
         .empty([numTokens], "int32", this.device)
-        .copyFrom(pos2seq_id);
+        .copyFrom(pos2seqIds);
 
-      const tokenIdsArray = this.tvm
+      const tokenIdsDevice = this.tvm
         .empty([numTokens], "int32", this.device)
         .copyFrom(tokenIds);
 
-      const tokenCntArray = this.tvm
+      const tokenCntDevice = this.tvm
         .empty([numTokens], "int32", this.device)
         .copyFrom(tokenCnt);
 
-      const penaltiesArray = this.tvm
+      const penaltiesDevice = this.tvm
         .empty([1, 3], "float32", this.device)
         .copyFrom(penalties);
 
       this.fapplyPenalty(
         logitsOnGPU.view([1, this.fullVocabSize]),
         seqIdsArray,
-        pos2seqIdsArray,
-        tokenIdsArray,
-        tokenCntArray,
-        penaltiesArray,
+        pos2seqIdsDevice,
+        tokenIdsDevice,
+        tokenCntDevice,
+        penaltiesDevice,
       );
 
       this.tvm.endScope();
@@ -1280,13 +1280,13 @@ export class LLMChatPipeline {
       const temperatures = new Float32Array([temperature]);
 
       this.tvm.beginScope();
-      const temperaturesArray = this.tvm
+      const temperaturesDevice = this.tvm
        .empty([numSeqs], "float32", this.device)
        .copyFrom(temperatures);
 
      const probs = this.fsoftmaxWithTemperature(
        logitsOnGPU.view([numSeqs, 1, this.fullVocabSize]),
-        temperaturesArray,
+        temperaturesDevice,
      );
      this.updateLogitsOnCPU(probs);
      this.tvm.endScope();
@@ -1458,7 +1458,7 @@ export class LLMChatPipeline {
     const chunkLens: Array<number> = retGetChunks[1];
 
     // 2. Prefill each chunk
-    let logitsOnGPU: tvmjs.NDArray;
+    let logitsOnGPU: tvmjs.Tensor;
     for (let i = 0; i < chunks.length; i++) {
       const chunk = chunks[i];
       const chunkLen = chunkLens[i];
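Most of the changes above are the NDArray → Tensor type rename plus an *Array → *Device suffix change for device-side buffers. A minimal sketch of the scope idiom these call sites follow (beginScope, empty/copyFrom, detachFromCurrentScope, endScope), using only calls that appear in this diff; the function name and the use of tvm.webgpu() as the target device are illustrative:

import * as tvmjs from "@mlc-ai/web-runtime";

// Illustrative only: upload host token ids into a device-side tvmjs.Tensor.
function uploadTokens(tvm: tvmjs.Instance, tokens: number[]): tvmjs.Tensor {
  tvm.beginScope();
  // Allocate an int32 buffer on the device and copy the host data in.
  const onDevice = tvm
    .empty([tokens.length], "int32", tvm.webgpu())
    .copyFrom(tokens);
  // Detach so the tensor survives endScope(); other intermediates are released.
  const kept: tvmjs.Tensor = tvm.detachFromCurrentScope(onDevice);
  tvm.endScope();
  return kept;
}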

tsconfig.json

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
     "strict": true,
     "moduleResolution": "Node",
     "esModuleInterop": true,
-    "lib": ["dom", "WebWorker"]
+    "lib": ["dom", "WebWorker", "es2022"]
   },
   "typeRoots": ["./node_modules/@webgpu/types", "./node_modules/@types"],
   "include": ["src"],
