From e7edbfd6809c3163c894666575012f5f5eccbc0e Mon Sep 17 00:00:00 2001 From: Akaash Parthasarathy Date: Fri, 24 Oct 2025 15:35:54 -0400 Subject: [PATCH 1/3] Add support for cross-origin storage caching --- src/cache_util.ts | 169 ++++++++++++++++++---- src/config.ts | 10 ++ src/cross_origin_storage.ts | 225 ++++++++++++++++++++++++++++++ src/cross_origin_storage_cache.ts | 92 ++++++++++++ src/engine.ts | 27 ++-- src/utils.ts | 3 + 6 files changed, 484 insertions(+), 42 deletions(-) create mode 100644 src/cross_origin_storage.ts create mode 100644 src/cross_origin_storage_cache.ts diff --git a/src/cache_util.ts b/src/cache_util.ts index 22245bdc..027e3fbf 100644 --- a/src/cache_util.ts +++ b/src/cache_util.ts @@ -4,10 +4,139 @@ import { ChatConfig, ModelRecord, prebuiltAppConfig, + getCacheBackend, } from "./config"; import { cleanModelUrl } from "./support"; import { ModelNotFoundError, UnsupportedTokenizerFilesError } from "./error"; import { Tokenizer } from "@mlc-ai/web-tokenizers"; +import CrossOriginStorage from "./cross_origin_storage"; +import CrossOriginStorageCache from "./cross_origin_storage_cache"; + +type CacheScope = "webllm/model" | "webllm/config" | "webllm/wasm"; + +let crossOriginUnavailableLogged = false; + +function shouldUseCrossOrigin(appConfig: AppConfig): boolean { + return ( + getCacheBackend(appConfig) === "cross-origin" && + CrossOriginStorage.isAvailable() + ); +} + +export function getArtifactCache( + scope: CacheScope, + appConfig: AppConfig, + logger: (msg: string) => void = console.warn, +): tvmjs.ArtifactCacheTemplate { + const backend = getCacheBackend(appConfig); + if (backend === "cross-origin") { + if (CrossOriginStorage.isAvailable()) { + return new CrossOriginStorageCache(scope); + } + // Fallback to Cache API + if (!crossOriginUnavailableLogged) { + logger( + "Cross-origin storage backend requested but unavailable; falling back to Cache API.", + ); + crossOriginUnavailableLogged = true; + } + } + if (backend === "indexeddb") { + return new tvmjs.ArtifactIndexedDBCache(scope); + } + return new tvmjs.ArtifactCache(scope); +} + +async function hasTensorCache( + cache: tvmjs.ArtifactCacheTemplate, + tensorCacheUrl: string, +): Promise { + const jsonUrl = new URL("tensor-cache.json", tensorCacheUrl).href; + const hasManifest = await cache.hasAllKeys([jsonUrl]); + if (!hasManifest) { + return false; + } + const manifest = await cache.fetchWithCache(jsonUrl, "json"); + const records = manifest?.records ?? []; + if (!Array.isArray(records) || records.length === 0) { + return false; + } + const shardUrls = records.map( + (entry: { dataPath: string }) => + new URL(entry.dataPath, tensorCacheUrl).href, + ); + return cache.hasAllKeys(shardUrls); +} + +async function deleteTensorCacheEntries( + cache: tvmjs.ArtifactCacheTemplate, + tensorCacheUrl: string, +): Promise { + const jsonUrl = new URL("tensor-cache.json", tensorCacheUrl).href; + const hasManifest = await cache.hasAllKeys([jsonUrl]); + if (!hasManifest) { + return; + } + let manifest: { records?: Array<{ dataPath: string }> }; + try { + manifest = await cache.fetchWithCache(jsonUrl, "json"); + } catch (err) { + return; + } + const records = manifest?.records ?? []; + await Promise.all( + records.map(async (entry) => { + if (!entry?.dataPath) { + return; + } + const dataUrl = new URL(entry.dataPath, tensorCacheUrl).href; + await cache.deleteInCache(dataUrl); + }), + ); + await cache.deleteInCache(jsonUrl); +} + +export async function fetchModelArtifacts( + tvm: tvmjs.Instance, + tensorCacheUrl: string, + device: tvmjs.DLDevice, + appConfig: AppConfig, + signal?: AbortSignal, +): Promise { + if (!shouldUseCrossOrigin(appConfig)) { + const backend = getCacheBackend(appConfig); + const cacheType = backend === "indexeddb" ? "indexeddb" : "cache"; + return tvm.fetchTensorCache( + tensorCacheUrl, + device, + "webllm/model", + cacheType, + signal, + ); + } + + const artifactCache = getArtifactCache("webllm/model", appConfig); + const jsonUrl = new URL("tensor-cache.json", tensorCacheUrl).href; + const manifest = await artifactCache.fetchWithCache(jsonUrl, "json", signal); + const records = ( + Array.isArray(manifest?.records) ? manifest.records : [] + ) as Array; + await (tvm as any).fetchTensorCacheInternal( + tensorCacheUrl, + records, + device, + artifactCache, + signal, + ); + if (manifest?.metadata !== undefined) { + const runtime = tvm as any; + runtime.cacheMetadata = { + ...runtime.cacheMetadata, + ...(manifest.metadata as Record), + }; + } + return manifest; +} function findModelRecord(modelId: string, appConfig?: AppConfig): ModelRecord { const matchedItem = appConfig?.model_list.find( @@ -28,7 +157,12 @@ export async function hasModelInCache( } const modelRecord = findModelRecord(modelId, appConfig); const modelUrl = cleanModelUrl(modelRecord.model); - const cacheType = appConfig.useIndexedDBCache ? "indexeddb" : "cache"; + if (shouldUseCrossOrigin(appConfig)) { + const cache = getArtifactCache("webllm/model", appConfig); + return hasTensorCache(cache, modelUrl); + } + const backend = getCacheBackend(appConfig); + const cacheType = backend === "indexeddb" ? "indexeddb" : "cache"; return tvmjs.hasTensorInCache(modelUrl, "webllm/model", cacheType); } @@ -58,13 +192,13 @@ export async function deleteModelInCache( } const modelRecord = findModelRecord(modelId, appConfig); const modelUrl = cleanModelUrl(modelRecord.model); - let modelCache: tvmjs.ArtifactCacheTemplate; - if (appConfig.useIndexedDBCache) { - tvmjs.deleteTensorCache(modelUrl, "webllm/model", "indexeddb"); - modelCache = new tvmjs.ArtifactIndexedDBCache("webllm/model"); + const modelCache = getArtifactCache("webllm/model", appConfig); + if (shouldUseCrossOrigin(appConfig)) { + await deleteTensorCacheEntries(modelCache, modelUrl); } else { - tvmjs.deleteTensorCache(modelUrl, "webllm/model", "cache"); - modelCache = new tvmjs.ArtifactCache("webllm/model"); + const backend = getCacheBackend(appConfig); + const cacheType = backend === "indexeddb" ? "indexeddb" : "cache"; + await tvmjs.deleteTensorCache(modelUrl, "webllm/model", cacheType); } await modelCache.deleteInCache(new URL("tokenizer.model", modelUrl).href); await modelCache.deleteInCache(new URL("tokenizer.json", modelUrl).href); @@ -79,12 +213,7 @@ export async function deleteChatConfigInCache( appConfig = prebuiltAppConfig; } const modelRecord = findModelRecord(modelId, appConfig); - let configCache: tvmjs.ArtifactCacheTemplate; - if (appConfig.useIndexedDBCache) { - configCache = new tvmjs.ArtifactIndexedDBCache("webllm/config"); - } else { - configCache = new tvmjs.ArtifactCache("webllm/config"); - } + const configCache = getArtifactCache("webllm/config", appConfig); const modelUrl = cleanModelUrl(modelRecord.model); const configUrl = new URL("mlc-chat-config.json", modelUrl).href; await configCache.deleteInCache(configUrl); @@ -99,12 +228,7 @@ export async function deleteModelWasmInCache( appConfig = prebuiltAppConfig; } const modelRecord = findModelRecord(modelId, appConfig); - let wasmCache: tvmjs.ArtifactCacheTemplate; - if (appConfig.useIndexedDBCache) { - wasmCache = new tvmjs.ArtifactIndexedDBCache("webllm/wasm"); - } else { - wasmCache = new tvmjs.ArtifactCache("webllm/wasm"); - } + const wasmCache = getArtifactCache("webllm/wasm", appConfig); await wasmCache.deleteInCache(modelRecord.model_lib); } @@ -122,12 +246,7 @@ export async function asyncLoadTokenizer( appConfig: AppConfig, logger: (msg: string) => void = console.log, ): Promise { - let modelCache: tvmjs.ArtifactCacheTemplate; - if (appConfig.useIndexedDBCache) { - modelCache = new tvmjs.ArtifactIndexedDBCache("webllm/model"); - } else { - modelCache = new tvmjs.ArtifactCache("webllm/model"); - } + const modelCache = getArtifactCache("webllm/model", appConfig, logger); if (config.tokenizer_files.includes("tokenizer.json")) { const url = new URL("tokenizer.json", baseUrl).href; diff --git a/src/config.ts b/src/config.ts index cb24d143..99d4c70a 100644 --- a/src/config.ts +++ b/src/config.ts @@ -275,9 +275,19 @@ export interface ModelRecord { * * @note Note that the Cache API is more well-tested in WebLLM as of now. */ +export type CacheBackend = "cache" | "indexeddb" | "cross-origin"; + export interface AppConfig { model_list: Array; useIndexedDBCache?: boolean; + cacheBackend?: CacheBackend; +} + +export function getCacheBackend(appConfig: AppConfig): CacheBackend { + if (appConfig.cacheBackend !== undefined) { + return appConfig.cacheBackend; + } + return appConfig.useIndexedDBCache ? "indexeddb" : "cache"; } /** diff --git a/src/cross_origin_storage.ts b/src/cross_origin_storage.ts new file mode 100644 index 00000000..9d4c2659 --- /dev/null +++ b/src/cross_origin_storage.ts @@ -0,0 +1,225 @@ +const HASH_ALGORITHM = "SHA-256"; +const HASH_MATCH_REGEX = /[A-Fa-f0-9]{64}/; + +export interface CrossOriginHashDescriptor { + algorithm: string; + value: string; +} + +interface CrossOriginStorageHandle { + getFile(): Promise; + createWritable(): Promise; +} + +interface CrossOriginStorageAPI { + requestFileHandles( + descriptors: CrossOriginHashDescriptor[], + options?: { create?: boolean }, + ): Promise; + removeFileHandles?(descriptors: CrossOriginHashDescriptor[]): Promise; +} + +type RequestLike = string | URL | Request | { url?: string }; + +declare global { + interface Navigator { + crossOriginStorage?: CrossOriginStorageAPI; + } +} + +export default class CrossOriginStorage { + private hashCache: Map; + + constructor() { + this.hashCache = new Map(); + } + + static isAvailable(): boolean { + return ( + typeof navigator !== "undefined" && + "crossOriginStorage" in navigator && + navigator.crossOriginStorage !== undefined + ); + } + + async match(request: RequestLike): Promise { + const url = this.normalizeRequest(request); + const hash = await this.resolveHashDescriptor(url); + if (!hash) { + return undefined; + } + try { + const api = this.getApi(); + if (!api) { + return undefined; + } + const handles = await api.requestFileHandles([hash]); + const handle = handles[0]; + if (!handle) { + return undefined; + } + const blob = await handle.getFile(); + return new Response(blob); + } catch { + return undefined; + } + } + + async put(request: RequestLike, response: Response): Promise { + const url = this.normalizeRequest(request); + const blob = await response.blob(); + const hash = await this.getBlobHash(blob); + const api = this.getApi(); + if (!api) { + throw new Error("Cross-origin storage API unavailable."); + } + const handles = await api.requestFileHandles([hash], { create: true }); + const handle = handles[0]; + if (!handle) { + throw new Error("Cross-origin storage API returned no handles."); + } + const writableStream = await handle.createWritable(); + await writableStream.write(blob); + await writableStream.close(); + this.hashCache.set(url, hash); + } + + async delete(request: RequestLike): Promise { + const url = this.normalizeRequest(request); + const hash = await this.resolveHashDescriptor(url); + if (!hash) { + return; + } + const api = this.getApi(); + if (api && typeof api.removeFileHandles === "function") { + await api.removeFileHandles([hash]); + } + this.hashCache.delete(url); + } + + private getApi(): CrossOriginStorageAPI | undefined { + if (!CrossOriginStorage.isAvailable()) { + return undefined; + } + return navigator.crossOriginStorage; + } + + private normalizeRequest(request: RequestLike): string { + if (typeof request === "string") { + return request; + } + if (request instanceof URL) { + return request.href; + } + if (request instanceof Request) { + return request.url; + } + if (request && typeof request.url === "string") { + return request.url; + } + throw new Error("CrossOriginStorage: Unsupported request type."); + } + + private async resolveHashDescriptor( + url: string, + ): Promise { + const cached = this.hashCache.get(url); + if (cached) { + return cached; + } + const hashValue = await this.getFileHash(url); + if (!hashValue) { + return null; + } + const descriptor: CrossOriginHashDescriptor = { + algorithm: HASH_ALGORITHM, + value: hashValue, + }; + this.hashCache.set(url, descriptor); + return descriptor; + } + + // Gets the SHA-256 hash for large resources using request metadata. + private async getFileHash(url: string): Promise { + const metadataHash = await this.extractHashFromHead(url); + if (metadataHash) { + return metadataHash; + } + if (/\/resolve\/main\//.test(url)) { + const pointerHash = await this.extractHashFromPointer(url); + if (pointerHash) { + return pointerHash; + } + } + return null; + } + + private async extractHashFromHead(url: string): Promise { + try { + const response = await fetch(url, { method: "HEAD" }); + if (!response.ok) { + return null; + } + const headerNames = [ + "x-linked-etag", + "x-linked-hash", + "x-amz-meta-sha256", + "x-oss-meta-sha256", + "x-sha256", + "etag", + ]; + for (const name of headerNames) { + const value = response.headers.get(name); + const hash = this.extractSha256(value); + if (hash) { + return hash; + } + } + } catch { + // Swallow errors; fall back to other strategies. + } + return null; + } + + private async extractHashFromPointer(url: string): Promise { + try { + const rawUrl = url.replace(/\/resolve\//, "/raw/"); + const response = await fetch(rawUrl, { + headers: { Range: "bytes=0-1023" }, + }); + if (!response.ok) { + return null; + } + const text = await response.text(); + if (!text.includes("oid sha256:")) { + return null; + } + const match = text.match(/oid sha256:([A-Fa-f0-9]+)/); + return match ? match[1] : null; + } catch { + return null; + } + } + + private extractSha256(value: string | null): string | null { + if (!value) { + return null; + } + const match = value.match(HASH_MATCH_REGEX); + return match ? match[0].toLowerCase() : null; + } + + private async getBlobHash(blob: Blob): Promise { + const arrayBuffer = await blob.arrayBuffer(); + const hashBuffer = await crypto.subtle.digest(HASH_ALGORITHM, arrayBuffer); + const hashArray = Array.from(new Uint8Array(hashBuffer)); + const hashHex = hashArray + .map((byte) => byte.toString(16).padStart(2, "0")) + .join(""); + + return { + algorithm: HASH_ALGORITHM, + value: hashHex, + }; + } +} diff --git a/src/cross_origin_storage_cache.ts b/src/cross_origin_storage_cache.ts new file mode 100644 index 00000000..4bfab8f3 --- /dev/null +++ b/src/cross_origin_storage_cache.ts @@ -0,0 +1,92 @@ +import * as tvmjs from "@mlc-ai/web-runtime"; +import CrossOriginStorage from "./cross_origin_storage"; + +type StoreType = string | undefined; + +const DEFAULT_FETCH_OPTIONS: RequestInit = { method: "GET" }; + +export class CrossOriginStorageCache implements tvmjs.ArtifactCacheTemplate { + private storage: CrossOriginStorage; + + constructor( + _scope: string, + storage: CrossOriginStorage = new CrossOriginStorage(), + ) { + this.storage = storage; + } + + async fetchWithCache( + url: string, + storetype?: StoreType, + signal?: AbortSignal, + ): Promise { + const cachedResponse = await this.storage.match(url); + if (cachedResponse !== undefined) { + return this.responseToStoreType(cachedResponse, storetype); + } + + await this.addToCache(url, storetype, signal); + const hydrated = await this.storage.match(url); + if (hydrated === undefined) { + throw new Error(`CrossOriginStorageCache: failed to hydrate ${url}`); + } + return this.responseToStoreType(hydrated, storetype); + } + + async addToCache( + url: string, + _storetype?: StoreType, + signal?: AbortSignal, + ): Promise { + const existing = await this.storage.match(url); + if (existing !== undefined) { + return; + } + const request = new Request( + url, + signal ? { ...DEFAULT_FETCH_OPTIONS, signal } : DEFAULT_FETCH_OPTIONS, + ); + const response = await fetch(request); + if (!response.ok) { + throw new Error( + `CrossOriginStorageCache: Unable to fetch ${url}, received status ${response.status}`, + ); + } + const cloned = response.clone(); + await this.storage.put(url, cloned); + } + + async hasAllKeys(keys: string[]): Promise { + const results = await Promise.all( + keys.map(async (key) => { + const cached = await this.storage.match(key); + return cached !== undefined; + }), + ); + return results.every((item) => item); + } + + async deleteInCache(_url: string): Promise { + // no delete API currently provided by Cross-Origin Storage + return; + } + + private async responseToStoreType( + response: Response, + storetype?: StoreType, + ): Promise { + if (storetype === undefined) { + return response; + } + const format = storetype.toLowerCase(); + if (format === "json") { + return response.json(); + } + if (format === "arraybuffer") { + return response.arrayBuffer(); + } + return response; + } +} + +export default CrossOriginStorageCache; diff --git a/src/engine.ts b/src/engine.ts index 47bd7b64..aa8ca064 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -69,7 +69,11 @@ import { SpecifiedModelNotFoundError, ModelNotLoadedError, } from "./error"; -import { asyncLoadTokenizer } from "./cache_util"; +import { + asyncLoadTokenizer, + fetchModelArtifacts, + getArtifactCache, +} from "./cache_util"; import { EmbeddingPipeline } from "./embedding"; /** @@ -260,12 +264,7 @@ export class MLCEngine implements MLCEngineInterface { this.loadedModelIdToModelType.set(modelId, modelType); // instantiate cache - let configCache: tvmjs.ArtifactCacheTemplate; - if (this.appConfig.useIndexedDBCache) { - configCache = new tvmjs.ArtifactIndexedDBCache("webllm/config"); - } else { - configCache = new tvmjs.ArtifactCache("webllm/config"); - } + const configCache = getArtifactCache("webllm/config", this.appConfig); // load config const configUrl = new URL("mlc-chat-config.json", modelUrl).href; @@ -281,12 +280,7 @@ export class MLCEngine implements MLCEngineInterface { this.loadedModelIdToChatConfig.set(modelId, curModelConfig); // load tvm wasm - let wasmCache: tvmjs.ArtifactCacheTemplate; - if (this.appConfig.useIndexedDBCache) { - wasmCache = new tvmjs.ArtifactIndexedDBCache("webllm/wasm"); - } else { - wasmCache = new tvmjs.ArtifactCache("webllm/wasm"); - } + const wasmCache = getArtifactCache("webllm/wasm", this.appConfig); const wasmUrl = modelRecord.model_lib; if (wasmUrl === undefined) { @@ -367,12 +361,11 @@ export class MLCEngine implements MLCEngineInterface { this.appConfig, this.logger, ); - const cacheType = this.appConfig.useIndexedDBCache ? "indexeddb" : "cache"; - await tvm.fetchTensorCache( + await fetchModelArtifacts( + tvm, modelUrl, tvm.webgpu(), - "webllm/model", - cacheType, + this.appConfig, this.reloadController?.signal, ); diff --git a/src/utils.ts b/src/utils.ts index 7c688927..2c91b1e8 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -80,6 +80,9 @@ export function areAppConfigsEqual( if (config1.useIndexedDBCache !== config2.useIndexedDBCache) { return false; } + if (config1.cacheBackend !== config2.cacheBackend) { + return false; + } // Check if both configurations have the same number of model records if (config1.model_list.length !== config2.model_list.length) { From 00a82d387493c5473e4c462438ccc15edc844d5e Mon Sep 17 00:00:00 2001 From: Akaash Parthasarathy Date: Fri, 24 Oct 2025 15:36:21 -0400 Subject: [PATCH 2/3] Update READMEs to describe cross-origin storage --- examples/README.md | 3 +-- examples/cache-usage/README.md | 8 ++++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/README.md b/examples/README.md index 0d7dad42..766e3fc7 100644 --- a/examples/README.md +++ b/examples/README.md @@ -46,8 +46,7 @@ These examples demonstrate various capabilities via WebLLM's OpenAI-like API. #### Others - [logit-processor](logit-processor): while `logit_bias` is supported, we additionally support stateful logit processing where users can specify their own rules. We also expose low-level API `forwardTokensAndSample()`. -- [cache-usage](cache-usage): demonstrates how WebLLM supports both the [Cache API](https://developer.mozilla.org/en-US/docs/Web/API/Cache) and [IndexedDB cache](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API), and - users can pick with `appConfig.useIndexedDBCache`. Also demonstrates various cache utils such as checking +- [cache-usage](cache-usage): demonstrates how WebLLM supports multiple cache backends. Choose between the [Cache API](https://developer.mozilla.org/en-US/docs/Web/API/Cache), [IndexedDB cache](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API), or the experimental Chrome [Cross-Origin Storage](https://github.com/explainers-by-googlers/cross-origin-storage) extension via `appConfig.cacheBackend`. Also demonstrates various cache utils such as checking whether a model is cached, deleting a model's weights from cache, deleting a model library wasm from cache, etc. - [simple-chat-upload](simple-chat-upload): demonstrates how to upload local models to WebLLM instead of downloading via a URL link diff --git a/examples/cache-usage/README.md b/examples/cache-usage/README.md index dab6d623..51205acc 100644 --- a/examples/cache-usage/README.md +++ b/examples/cache-usage/README.md @@ -1,9 +1,13 @@ # WebLLM Cache Usage -WebLLM supports both the Cache API and IndexedDB, which you can specify via `AppConfig.useIndexedDBCache`. -This folder provides an example on how Cache and IndexedDB Cache are used in WebLLM. We also +WebLLM supports multiple persistent cache backends. You can pick the classic Cache API, IndexedDB, or the experimental Chrome [Cross-Origin Storage](https://github.com/explainers-by-googlers/cross-origin-storage) extension by +setting `AppConfig.cacheBackend` to `"cache"`, `"indexeddb"`, or `"cross-origin"`. (`AppConfig.useIndexedDBCache` +is still honored for backward compatibility.) +This folder provides an example on how different caches are used in WebLLM. We also demonstrate the utility cache functions such as deleting models, checking if models are in cache, etc. +> **Note:** The cross-origin backend requires Chrome's cross-origin storage experiment or the community browser extension to be installed and granted access to the domains that host your model artifacts (e.g. huggingface.co). + For more information about the two caches, see: https://developer.mozilla.org/en-US/docs/Web/API/Storage_API/Storage_quotas_and_eviction_criteria#what_technologies_store_data_in_the_browser. To inspect the downloaded artifacts in your browser, open up developer console, go to application, From c08d7beccc90f28ae6874e0ec0aa3b003b637f18 Mon Sep 17 00:00:00 2001 From: Akaash Parthasarathy Date: Thu, 27 Nov 2025 14:33:27 -0500 Subject: [PATCH 3/3] Remove backward compatibility with useIndexedDBCache flag --- examples/cache-usage/README.md | 3 +- examples/cache-usage/package.json | 4 +- src/cache_util.ts | 32 +++++++--- src/config.ts | 9 ++- src/cross_origin_storage.ts | 62 ++++++++++++++++---- src/cross_origin_storage_cache.ts | 3 +- src/utils.ts | 5 +- tests/scripts/sanity_checks/sanity_checks.ts | 2 +- 8 files changed, 84 insertions(+), 36 deletions(-) diff --git a/examples/cache-usage/README.md b/examples/cache-usage/README.md index 51205acc..7db09833 100644 --- a/examples/cache-usage/README.md +++ b/examples/cache-usage/README.md @@ -1,8 +1,7 @@ # WebLLM Cache Usage WebLLM supports multiple persistent cache backends. You can pick the classic Cache API, IndexedDB, or the experimental Chrome [Cross-Origin Storage](https://github.com/explainers-by-googlers/cross-origin-storage) extension by -setting `AppConfig.cacheBackend` to `"cache"`, `"indexeddb"`, or `"cross-origin"`. (`AppConfig.useIndexedDBCache` -is still honored for backward compatibility.) +setting `AppConfig.cacheBackend` to `"cache"`, `"indexeddb"`, or `"cross-origin"`. This folder provides an example on how different caches are used in WebLLM. We also demonstrate the utility cache functions such as deleting models, checking if models are in cache, etc. diff --git a/examples/cache-usage/package.json b/examples/cache-usage/package.json index 17e4fdc1..a91d3312 100644 --- a/examples/cache-usage/package.json +++ b/examples/cache-usage/package.json @@ -3,12 +3,12 @@ "version": "0.1.0", "private": true, "scripts": { - "start": "parcel src/cache_usage.html --port 8888", + "start": "parcel src/cache_usage.html --port 8889", "build": "parcel build src/cache_usage.html --dist-dir lib" }, "devDependencies": { "buffer": "^5.7.1", - "parcel": "^2.8.3", + "parcel": "2.8.3", "process": "^0.11.10", "tslib": "^2.3.1", "typescript": "^4.9.5", diff --git a/src/cache_util.ts b/src/cache_util.ts index 027e3fbf..97aa89c6 100644 --- a/src/cache_util.ts +++ b/src/cache_util.ts @@ -15,6 +15,26 @@ import CrossOriginStorageCache from "./cross_origin_storage_cache"; type CacheScope = "webllm/model" | "webllm/config" | "webllm/wasm"; let crossOriginUnavailableLogged = false; +let crossOriginAvailabilityWait: Promise | null = null; + +function scheduleCrossOriginFallbackWarning( + logger: (msg: string) => void, +): void { + if (crossOriginUnavailableLogged || crossOriginAvailabilityWait) { + return; + } + crossOriginAvailabilityWait = (async () => { + const availableSoon = await CrossOriginStorage.waitForAvailability(); + crossOriginAvailabilityWait = null; + if (availableSoon || crossOriginUnavailableLogged) { + return; + } + logger( + "Cross-origin storage backend is not yet available; temporarily falling back to the Cache API.", + ); + crossOriginUnavailableLogged = true; + })(); +} function shouldUseCrossOrigin(appConfig: AppConfig): boolean { return ( @@ -33,13 +53,7 @@ export function getArtifactCache( if (CrossOriginStorage.isAvailable()) { return new CrossOriginStorageCache(scope); } - // Fallback to Cache API - if (!crossOriginUnavailableLogged) { - logger( - "Cross-origin storage backend requested but unavailable; falling back to Cache API.", - ); - crossOriginUnavailableLogged = true; - } + scheduleCrossOriginFallbackWarning(logger); } if (backend === "indexeddb") { return new tvmjs.ArtifactIndexedDBCache(scope); @@ -81,6 +95,10 @@ async function deleteTensorCacheEntries( try { manifest = await cache.fetchWithCache(jsonUrl, "json"); } catch (err) { + console.warn( + `Failed to load tensor cache manifest at ${jsonUrl}; skipping deletion.`, + err, + ); return; } const records = manifest?.records ?? []; diff --git a/src/config.ts b/src/config.ts index 99d4c70a..ad69b13e 100644 --- a/src/config.ts +++ b/src/config.ts @@ -269,8 +269,8 @@ export interface ModelRecord { * passed to the load. * * @param model_list: models to be used. - * @param useIndexedDBCache: if true, will use IndexedDBCache to cache models and other artifacts. - * If false or unspecified, will use the Cache API. For more information of the two, see: + * @param cacheBackend: the backend to use for caching models and other artifacts. + * If unspecified, will use the Cache API. For more information, see: * https://developer.mozilla.org/en-US/docs/Web/API/Storage_API/Storage_quotas_and_eviction_criteria#what_technologies_store_data_in_the_browser * * @note Note that the Cache API is more well-tested in WebLLM as of now. @@ -279,7 +279,6 @@ export type CacheBackend = "cache" | "indexeddb" | "cross-origin"; export interface AppConfig { model_list: Array; - useIndexedDBCache?: boolean; cacheBackend?: CacheBackend; } @@ -287,7 +286,7 @@ export function getCacheBackend(appConfig: AppConfig): CacheBackend { if (appConfig.cacheBackend !== undefined) { return appConfig.cacheBackend; } - return appConfig.useIndexedDBCache ? "indexeddb" : "cache"; + return "cache"; } /** @@ -319,7 +318,7 @@ export const functionCallingModelIds = [ * current WebLLM npm version. */ export const prebuiltAppConfig: AppConfig = { - useIndexedDBCache: false, + cacheBackend: "cache", model_list: [ // Llama-3.2 { diff --git a/src/cross_origin_storage.ts b/src/cross_origin_storage.ts index 9d4c2659..1e99b5d3 100644 --- a/src/cross_origin_storage.ts +++ b/src/cross_origin_storage.ts @@ -1,5 +1,17 @@ const HASH_ALGORITHM = "SHA-256"; const HASH_MATCH_REGEX = /[A-Fa-f0-9]{64}/; +const AVAILABILITY_POLL_INTERVAL_MS = 100; +const DEFAULT_AVAILABILITY_TIMEOUT_MS = 3000; +const HASH_CACHE_SYMBOL = Symbol.for("mlc.crossOriginStorage.hashCache"); + +const globalScope = globalThis as Record; +if (!globalScope[HASH_CACHE_SYMBOL]) { + globalScope[HASH_CACHE_SYMBOL] = new Map(); +} +const GLOBAL_HASH_CACHE = globalScope[HASH_CACHE_SYMBOL] as Map< + string, + CrossOriginHashDescriptor +>; export interface CrossOriginHashDescriptor { algorithm: string; @@ -16,7 +28,6 @@ interface CrossOriginStorageAPI { descriptors: CrossOriginHashDescriptor[], options?: { create?: boolean }, ): Promise; - removeFileHandles?(descriptors: CrossOriginHashDescriptor[]): Promise; } type RequestLike = string | URL | Request | { url?: string }; @@ -25,13 +36,16 @@ declare global { interface Navigator { crossOriginStorage?: CrossOriginStorageAPI; } + interface WorkerNavigator { + crossOriginStorage?: CrossOriginStorageAPI; + } } export default class CrossOriginStorage { private hashCache: Map; constructor() { - this.hashCache = new Map(); + this.hashCache = GLOBAL_HASH_CACHE; } static isAvailable(): boolean { @@ -42,6 +56,35 @@ export default class CrossOriginStorage { ); } + static async waitForAvailability( + timeoutMs: number = DEFAULT_AVAILABILITY_TIMEOUT_MS, + ): Promise { + if (CrossOriginStorage.isAvailable()) { + return true; + } + if (typeof navigator === "undefined") { + return false; + } + if (typeof setTimeout === "undefined") { + return false; + } + return new Promise((resolve) => { + const deadline = Date.now() + timeoutMs; + const tick = () => { + if (CrossOriginStorage.isAvailable()) { + resolve(true); + return; + } + if (Date.now() >= deadline) { + resolve(false); + return; + } + setTimeout(tick, AVAILABILITY_POLL_INTERVAL_MS); + }; + setTimeout(tick, AVAILABILITY_POLL_INTERVAL_MS); + }); + } + async match(request: RequestLike): Promise { const url = this.normalizeRequest(request); const hash = await this.resolveHashDescriptor(url); @@ -84,17 +127,10 @@ export default class CrossOriginStorage { this.hashCache.set(url, hash); } + // eslint-disable-next-line @typescript-eslint/no-unused-vars async delete(request: RequestLike): Promise { - const url = this.normalizeRequest(request); - const hash = await this.resolveHashDescriptor(url); - if (!hash) { - return; - } - const api = this.getApi(); - if (api && typeof api.removeFileHandles === "function") { - await api.removeFileHandles([hash]); - } - this.hashCache.delete(url); + // Currently no delete API provided by Cross-Origin Storage Extension + return; } private getApi(): CrossOriginStorageAPI | undefined { @@ -145,7 +181,7 @@ export default class CrossOriginStorage { if (metadataHash) { return metadataHash; } - if (/\/resolve\/main\//.test(url)) { + if (/\/resolve\//.test(url)) { const pointerHash = await this.extractHashFromPointer(url); if (pointerHash) { return pointerHash; diff --git a/src/cross_origin_storage_cache.ts b/src/cross_origin_storage_cache.ts index 4bfab8f3..211f0424 100644 --- a/src/cross_origin_storage_cache.ts +++ b/src/cross_origin_storage_cache.ts @@ -67,8 +67,7 @@ export class CrossOriginStorageCache implements tvmjs.ArtifactCacheTemplate { } async deleteInCache(_url: string): Promise { - // no delete API currently provided by Cross-Origin Storage - return; + await this.storage.delete(_url); } private async responseToStoreType( diff --git a/src/utils.ts b/src/utils.ts index 2c91b1e8..17a86925 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -76,10 +76,7 @@ export function areAppConfigsEqual( return config1 === config2; } - // Check if both configurations have the same IndexedDB cache usage - if (config1.useIndexedDBCache !== config2.useIndexedDBCache) { - return false; - } + // Check if both configurations have the same cache backend if (config1.cacheBackend !== config2.cacheBackend) { return false; } diff --git a/tests/scripts/sanity_checks/sanity_checks.ts b/tests/scripts/sanity_checks/sanity_checks.ts index da842353..2f96e051 100644 --- a/tests/scripts/sanity_checks/sanity_checks.ts +++ b/tests/scripts/sanity_checks/sanity_checks.ts @@ -157,7 +157,7 @@ async function testLogprobs(modelId: string, appConfig: webllm.AppConfig) { async function main() { const modelId = "Qwen3-0.6B-q0f32-MLC"; const appConfig = webllm.prebuiltAppConfig; - appConfig.useIndexedDBCache = true; + appConfig.cacheBackend = "indexeddb"; setLabel("gpu-test-label", "Running tests..."); let passed = 0, total = 0;