diff --git a/src/backends/onnx.js b/src/backends/onnx.js
index 38cd71337..1b485c185 100644
--- a/src/backends/onnx.js
+++ b/src/backends/onnx.js
@@ -141,19 +141,19 @@ let wasmInitPromise = null;
 
 /**
  * Create an ONNX inference session.
- * @param {Uint8Array} buffer The ONNX model buffer.
+ * @param {Uint8Array|string} buffer_or_path The ONNX model buffer or path.
  * @param {import('onnxruntime-common').InferenceSession.SessionOptions} session_options ONNX inference session options.
  * @param {Object} session_config ONNX inference session configuration.
  * @returns {Promise} The ONNX inference session.
  */
-export async function createInferenceSession(buffer, session_options, session_config) {
+export async function createInferenceSession(buffer_or_path, session_options, session_config) {
     if (wasmInitPromise) {
         // A previous session has already initialized the WASM runtime
         // so we wait for it to resolve before creating this new session.
         await wasmInitPromise;
     }
 
-    const sessionPromise = InferenceSession.create(buffer, session_options);
+    const sessionPromise = InferenceSession.create(buffer_or_path, session_options);
     wasmInitPromise ??= sessionPromise;
     const session = await sessionPromise;
     session.config = session_config;
diff --git a/src/configs.js b/src/configs.js
index 94f0b31f0..19c7050c9 100644
--- a/src/configs.js
+++ b/src/configs.js
@@ -407,5 +407,5 @@ export class AutoConfig {
  * for more information.
  * @property {import('./utils/devices.js').DeviceType} [device] The default device to use for the model.
  * @property {import('./utils/dtypes.js').DataType|Record} [dtype] The default data type to use for the model.
- * @property {boolean|Record} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size).
+ * @property {import('./utils/hub.js').ExternalData|Record} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size).
  */
diff --git a/src/models.js b/src/models.js
index 976d1c000..cf0bb8083 100644
--- a/src/models.js
+++ b/src/models.js
@@ -68,6 +68,7 @@ import {
 import {
     getModelFile,
     getModelJSON,
+    MAX_EXTERNAL_DATA_CHUNKS,
 } from './utils/hub.js';
 
 import {
@@ -152,7 +153,7 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
  * @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
  * @param {string} fileName The name of the model file.
  * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
- * @returns {Promise<{buffer: Uint8Array, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
+ * @returns {Promise<{buffer_or_path: Uint8Array|string, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
  * @private
  */
 async function getSession(pretrained_model_name_or_path, fileName, options) {
@@ -227,7 +228,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
 
     // Construct the model file name
     const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
-    const modelFileName = `${options.subfolder ?? ''}/${fileName}${suffix}.onnx`;
+    const baseName = `${fileName}${suffix}.onnx`;
+    const modelFileName = `${options.subfolder ?? ''}/${baseName}`;
 
     const session_options = { ...options.session_options };
@@ -245,29 +247,38 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
         );
     }
 
-    const bufferPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options);
+    const bufferOrPathPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options, apis.IS_NODE_ENV);
 
     // handle onnx external data files
     const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format;
-    /** @type {Promise<{path: string, data: Uint8Array}>[]} */
+    /** @type {Promise[]} */
     let externalDataPromises = [];
-    if (use_external_data_format && (
-        use_external_data_format === true ||
-        (
-            typeof use_external_data_format === 'object' &&
-            use_external_data_format.hasOwnProperty(fileName) &&
-            use_external_data_format[fileName] === true
-        )
-    )) {
-        if (apis.IS_NODE_ENV) {
-            throw new Error('External data format is not yet supported in Node.js');
+    if (use_external_data_format) {
+        let external_data_format;
+        if (typeof use_external_data_format === 'object') {
+            if (use_external_data_format.hasOwnProperty(baseName)) {
+                external_data_format = use_external_data_format[baseName];
+            } else if (use_external_data_format.hasOwnProperty(fileName)) {
+                external_data_format = use_external_data_format[fileName];
+            } else {
+                external_data_format = false;
+            }
+        } else {
+            external_data_format = use_external_data_format;
+        }
+
+        const num_chunks = +external_data_format; // (false=0, true=1, number remains the same)
+        if (num_chunks > MAX_EXTERNAL_DATA_CHUNKS) {
+            throw new Error(`The number of external data chunks (${num_chunks}) exceeds the maximum allowed value (${MAX_EXTERNAL_DATA_CHUNKS}).`);
+        }
+        for (let i = 0; i < num_chunks; ++i) {
+            const path = `${baseName}_data${i === 0 ? '' : '_' + i}`;
+            const fullPath = `${options.subfolder ?? ''}/${path}`;
+            externalDataPromises.push(new Promise(async (resolve, reject) => {
+                const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options, apis.IS_NODE_ENV);
+                resolve(data instanceof Uint8Array ? { path, data } : path);
+            }));
         }
-        const path = `${fileName}${suffix}.onnx_data`;
-        const fullPath = `${options.subfolder ?? ''}/${path}`;
-        externalDataPromises.push(new Promise(async (resolve, reject) => {
-            const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options);
-            resolve({ path, data })
-        }));
 
     } else if (session_options.externalData !== undefined) {
         externalDataPromises = session_options.externalData.map(async (ext) => {
@@ -284,7 +295,10 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
     }
 
     if (externalDataPromises.length > 0) {
-        session_options.externalData = await Promise.all(externalDataPromises);
+        const externalData = await Promise.all(externalDataPromises);
+        if (!apis.IS_NODE_ENV) {
+            session_options.externalData = externalData;
+        }
     }
 
     if (selectedDevice === 'webgpu') {
@@ -302,9 +316,9 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
         }
     }
 
-    const buffer = await bufferPromise;
+    const buffer_or_path = await bufferOrPathPromise;
 
-    return { buffer, session_options, session_config };
+    return { buffer_or_path, session_options, session_config };
 }
 
 /**
@@ -319,8 +333,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
 async function constructSessions(pretrained_model_name_or_path, names, options) {
     return Object.fromEntries(await Promise.all(
         Object.keys(names).map(async (name) => {
-            const { buffer, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
-            const session = await createInferenceSession(buffer, session_options, session_config);
+            const { buffer_or_path, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
+            const session = await createInferenceSession(buffer_or_path, session_options, session_config);
             return [name, session];
         })
     ));
diff --git a/src/pipelines.js b/src/pipelines.js
index 649b00a49..6901e1f77 100644
--- a/src/pipelines.js
+++ b/src/pipelines.js
@@ -3299,6 +3299,8 @@ export async function pipeline(
         revision = 'main',
         device = null,
         dtype = null,
+        subfolder = 'onnx',
+        use_external_data_format = null,
         model_file_name = null,
         session_options = {},
     } = {}
@@ -3329,6 +3331,8 @@ export async function pipeline(
         revision,
         device,
         dtype,
+        subfolder,
+        use_external_data_format,
         model_file_name,
         session_options,
     }
diff --git a/src/utils/hub.js b/src/utils/hub.js
index 17ee4c1b1..fe44bd614 100755
--- a/src/utils/hub.js
+++ b/src/utils/hub.js
@@ -8,9 +8,16 @@
 import fs from 'fs';
 import path from 'path';
 
-import { env } from '../env.js';
+import { apis, env } from '../env.js';
 import { dispatchCallback } from './core.js';
 
+/**
+ * @typedef {boolean|number} ExternalData Whether to load the model using the external data format (used for models >= 2GB in size).
+ * If `true`, the model will be loaded using the external data format.
+ * If a number, this many chunks will be loaded using the external data format (of the form: "model.onnx_data[_{chunk_number}]").
+ */
+export const MAX_EXTERNAL_DATA_CHUNKS = 100;
+
 /**
  * @typedef {Object} PretrainedOptions Options for loading a pretrained model.
  * @property {import('./core.js').ProgressCallback} [progress_callback=null] If specified, this function will be called during model construction, to provide the user with progress updates.
@@ -31,7 +38,7 @@ import { dispatchCallback } from './core.js';
  * @property {string} [model_file_name=null] If specified, load the model with this name (excluding the .onnx suffix). Currently only valid for encoder- or decoder-only models.
  * @property {import("./devices.js").DeviceType|Record} [device=null] The device to run the model on. If not specified, the device will be chosen from the environment settings.
  * @property {import("./dtypes.js").DataType|Record} [dtype=null] The data type to use for the model. If not specified, the data type will be chosen from the environment settings.
- * @property {boolean|Record} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size).
+ * @property {ExternalData|Record} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size).
  * @property {import('onnxruntime-common').InferenceSession.SessionOptions} [session_options] (Optional) User-specified session options passed to the runtime. If not provided, suitable defaults will be chosen.
  */
 
@@ -57,7 +64,7 @@ class FileResponse {
 
     /**
      * Creates a new `FileResponse` object.
-     * @param {string|URL} filePath
+     * @param {string} filePath
      */
     constructor(filePath) {
         this.filePath = filePath;
@@ -73,13 +80,15 @@ class FileResponse {
 
             this.updateContentType();
 
-            let self = this;
+            const stream = fs.createReadStream(filePath);
             this.body = new ReadableStream({
                 start(controller) {
-                    self.arrayBuffer().then(buffer => {
-                        controller.enqueue(new Uint8Array(buffer));
-                        controller.close();
-                    })
+                    stream.on('data', (chunk) => controller.enqueue(chunk));
+                    stream.on('end', () => controller.close());
+                    stream.on('error', (err) => controller.error(err));
+                },
+                cancel() {
+                    stream.destroy();
                 }
             });
         } else {
@@ -190,7 +199,7 @@ function isValidUrl(string, protocols = null, validHosts = null) {
 export async function getFile(urlOrPath) {
 
     if (env.useFS && !isValidUrl(urlOrPath, ['http:', 'https:', 'blob:'])) {
-        return new FileResponse(urlOrPath);
+        return new FileResponse(urlOrPath.toString());
 
     } else if (typeof process !== 'undefined' && process?.release?.name === 'node') {
         const IS_CI = !!process.env?.TESTING_REMOTELY;
@@ -281,20 +290,52 @@ class FileCache {
 
     /**
     * Adds the given response to the cache.
     * @param {string} request
-    * @param {Response|FileResponse} response
+    * @param {Response} response
+    * @param {(data: {progress: number, loaded: number, total: number}) => void} [progress_callback] Optional.
+    * The function to call with progress updates
+    * @returns {Promise}
     */
-    async put(request, response) {
-        const buffer = Buffer.from(await response.arrayBuffer());
-
-        let outputPath = path.join(this.path, request);
+    async put(request, response, progress_callback = undefined) {
+        let filePath = path.join(this.path, request);
 
         try {
-            await fs.promises.mkdir(path.dirname(outputPath), { recursive: true });
-            await fs.promises.writeFile(outputPath, buffer);
+            const contentLength = response.headers.get('Content-Length');
+            const total = parseInt(contentLength ?? '0');
+            let loaded = 0;
+
+            await fs.promises.mkdir(path.dirname(filePath), { recursive: true });
+            const fileStream = fs.createWriteStream(filePath);
+            const reader = response.body.getReader();
+
+            while (true) {
+                const { done, value } = await reader.read();
+                if (done) {
+                    break;
+                }
+
+                await new Promise((resolve, reject) => {
+                    fileStream.write(value, (err) => {
+                        if (err) {
+                            reject(err);
+                            return;
+                        }
+                        resolve();
+                    });
+                });
+
+                loaded += value.length;
+                const progress = total ? (loaded / total) * 100 : 0;
+
+                progress_callback?.({ progress, loaded, total });
+            }
 
-        } catch (err) {
-            console.warn('An error occurred while writing the file to cache:', err)
+            fileStream.close();
+        } catch (error) {
+            // Clean up the file if an error occurred during download
+            try {
+                await fs.promises.unlink(filePath);
+            } catch { }
+            throw error;
         }
     }
@@ -325,21 +366,21 @@ async function tryCache(cache, ...names) {
 /**
- *
  * Retrieves a file from either a remote URL using the Fetch API or from the local file system using the FileSystem API.
  * If the filesystem is available and `env.useCache = true`, the file will be downloaded and cached.
- * 
+ *
  * @param {string} path_or_repo_id This can be either:
  * - a string, the *model id* of a model repo on huggingface.co.
 * - a path to a *directory* potentially containing the file.
 * @param {string} filename The name of the file to locate in `path_or_repo`.
 * @param {boolean} [fatal=true] Whether to throw an error if the file is not found.
 * @param {PretrainedOptions} [options] An object containing optional parameters.
- * 
+ * @param {boolean} [return_path=false] Whether to return the path of the file instead of the file content.
+ *
 * @throws Will throw an error if the file is not found and `fatal` is true.
- * @returns {Promise} A Promise that resolves with the file content as a buffer.
+ * @returns {Promise} A Promise that resolves with the file content as a Uint8Array if `return_path` is false, or the file path as a string if `return_path` is true.
 */
-export async function getModelFile(path_or_repo_id, filename, fatal = true, options = {}) {
+export async function getModelFile(path_or_repo_id, filename, fatal = true, options = {}, return_path = false) {
 
     if (!env.allowLocalModels) {
         // User has disabled local models, so we just make sure other settings are correct.
@@ -403,8 +444,9 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
     const revision = options.revision ?? 'main';
 
     let requestURL = pathJoin(path_or_repo_id, filename);
-    let localPath = pathJoin(env.localModelPath, requestURL);
+    let cachePath = pathJoin(env.localModelPath, requestURL);
 
+    let localPath = requestURL;
     let remoteURL = pathJoin(
         env.remoteHost,
         env.remotePathTemplate
@@ -433,7 +475,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
         // 1. We first try to get from cache using the local path. In some environments (like deno),
        //    non-URL cache keys are not allowed. In these cases, `response` will be undefined.
         // 2. If no response is found, we try to get from cache using the remote URL or file system cache.
-        response = await tryCache(cache, localPath, proposedCacheKey);
+        response = await tryCache(cache, cachePath, proposedCacheKey);
     }
 
     const cacheHit = response !== undefined;
@@ -455,9 +497,9 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
                 console.warn(`Unable to load from local path "${localPath}": "${e}"`);
             }
         } else if (options.local_files_only) {
-            throw new Error(`\`local_files_only=true\`, but attempted to load a remote file from: ${requestURL}.`);
+            throw new Error(`\`local_files_only=true\`, but attempted to load a remote file from: ${localPath}.`);
         } else if (!env.allowRemoteModels) {
-            throw new Error(`\`env.allowRemoteModels=false\`, but attempted to load a remote file from: ${requestURL}.`);
+            throw new Error(`\`env.allowRemoteModels=false\`, but attempted to load a remote file from: ${localPath}.`);
         }
     }

@@ -504,41 +546,45 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
         file: filename
     })
 
-    /** @type {Uint8Array} */
-    let buffer;
-
-    if (!options.progress_callback) {
-        // If no progress callback is specified, we can use the `.arrayBuffer()`
-        // method to read the response.
-        buffer = new Uint8Array(await response.arrayBuffer());
-
-    } else if (
-        cacheHit // The item is being read from the cache
-        &&
-        typeof navigator !== 'undefined' && /firefox/i.test(navigator.userAgent) // We are in Firefox
-    ) {
-        // Due to bug in Firefox, we cannot display progress when loading from cache.
-        // Fortunately, since this should be instantaneous, this should not impact users too much.
-        buffer = new Uint8Array(await response.arrayBuffer());
-
-        // For completeness, we still fire the final progress callback
-        dispatchCallback(options.progress_callback, {
-            status: 'progress',
-            name: path_or_repo_id,
-            file: filename,
-            progress: 100,
-            loaded: buffer.length,
-            total: buffer.length,
-        })
-    } else {
-        buffer = await readResponse(response, data => {
+    let result;
+    if (!(apis.IS_NODE_ENV && return_path)) {
+        /** @type {Uint8Array} */
+        let buffer;
+
+        if (!options.progress_callback) {
+            // If no progress callback is specified, we can use the `.arrayBuffer()`
+            // method to read the response.
+            buffer = new Uint8Array(await response.arrayBuffer());
+
+        } else if (
+            cacheHit // The item is being read from the cache
+            &&
+            typeof navigator !== 'undefined' && /firefox/i.test(navigator.userAgent) // We are in Firefox
+        ) {
+            // Due to bug in Firefox, we cannot display progress when loading from cache.
+            // Fortunately, since this should be instantaneous, this should not impact users too much.
+            buffer = new Uint8Array(await response.arrayBuffer());
+
+            // For completeness, we still fire the final progress callback
             dispatchCallback(options.progress_callback, {
                 status: 'progress',
                 name: path_or_repo_id,
                 file: filename,
-                ...data,
+                progress: 100,
+                loaded: buffer.length,
+                total: buffer.length,
             })
-    })
+        } else {
+            buffer = await readResponse(response, data => {
+                dispatchCallback(options.progress_callback, {
+                    status: 'progress',
+                    name: path_or_repo_id,
+                    file: filename,
+                    ...data,
+                })
+            })
+        }
+        result = buffer;
     }

     if (
@@ -549,25 +595,43 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
         // Check again whether request is in cache. If not, we add the response to the cache
         (await cache.match(cacheKey) === undefined)
     ) {
-        // NOTE: We use `new Response(buffer, ...)` instead of `response.clone()` to handle LFS files
-        await cache.put(cacheKey, new Response(buffer, {
-            headers: response.headers
-        }))
-            .catch(err => {
-                // Do not crash if unable to add to cache (e.g., QuotaExceededError).
-                // Rather, log a warning and proceed with execution.
-                console.warn(`Unable to add response to browser cache: ${err}.`);
-            });
-
+        if (!result) {
+            // We haven't yet read the response body, so we need to do so now.
+            await cache.put(cacheKey, /** @type {Response} */(response), options.progress_callback);
+        } else {
+            // NOTE: We use `new Response(buffer, ...)` instead of `response.clone()` to handle LFS files
+            await cache.put(cacheKey, new Response(result, {
+                headers: response.headers
+            }))
+                .catch(err => {
+                    // Do not crash if unable to add to cache (e.g., QuotaExceededError).
+                    // Rather, log a warning and proceed with execution.
+                    console.warn(`Unable to add response to browser cache: ${err}.`);
+                });
+        }
     }
-
     dispatchCallback(options.progress_callback, {
         status: 'done',
         name: path_or_repo_id,
         file: filename
     });

-    return buffer;
+    if (result) {
+        if (return_path) {
+            throw new Error("Cannot return path in a browser environment.")
+        }
+        return result;
+    }
+    if (response instanceof FileResponse) {
+        return response.filePath;
+    }
+
+    const path = await cache.match(cacheKey);
+    if (path instanceof FileResponse) {
+        return path.filePath;
+    }
+    throw new Error("Unable to return path for response.");
 }
 
 /**
@@ -581,14 +645,14 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
 * @throws Will throw an error if the file is not found and `fatal` is true.
 */
 export async function getModelJSON(modelPath, fileName, fatal = true, options = {}) {
-    let buffer = await getModelFile(modelPath, fileName, fatal, options);
+    const buffer = await getModelFile(modelPath, fileName, fatal, options, false);
     if (buffer === null) {
         // Return empty object
         return {}
     }

-    let decoder = new TextDecoder('utf-8');
-    let jsonData = decoder.decode(buffer);
+    const decoder = new TextDecoder('utf-8');
+    const jsonData = decoder.decode(/** @type {Uint8Array} */(buffer));

     return JSON.parse(jsonData);
 }
@@ -614,30 +678,26 @@ async function readResponse(response, progress_callback) {
         const { done, value } = await reader.read();
         if (done) return;

-        let newLoaded = loaded + value.length;
+        const newLoaded = loaded + value.length;

         if (newLoaded > total) {
             total = newLoaded;

             // Adding the new data will overflow buffer.
             // In this case, we extend the buffer
-            let newBuffer = new Uint8Array(total);
+            const newBuffer = new Uint8Array(total);

             // copy contents
             newBuffer.set(buffer);
             buffer = newBuffer;
         }
-        buffer.set(value, loaded)
+        buffer.set(value, loaded);
         loaded = newLoaded;

         const progress = (loaded / total) * 100;

         // Call your function here
-        progress_callback({
-            progress: progress,
-            loaded: loaded,
-            total: total,
-        })
+        progress_callback({ progress, loaded, total });

         return read();
     }
diff --git a/tests/models.test.js b/tests/models.test.js
index ec52fc49d..b5ddfee1c 100644
--- a/tests/models.test.js
+++ b/tests/models.test.js
@@ -2,7 +2,7 @@
 * Test that models loaded outside of the `pipeline` function work correctly (e.g., `AutoModel.from_pretrained(...)`);
 */

-import { AutoTokenizer, AutoModel, BertModel, GPT2Model, T5ForConditionalGeneration, BertTokenizer, GPT2Tokenizer, T5Tokenizer } from "../src/transformers.js";
+import { AutoTokenizer, AutoProcessor, BertForMaskedLM, GPT2LMHeadModel, T5ForConditionalGeneration, BertTokenizer, GPT2Tokenizer, T5Tokenizer, LlamaTokenizer, LlamaForCausalLM, WhisperForConditionalGeneration, WhisperProcessor, AutoModelForMaskedLM, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq } from "../src/transformers.js";

 import { init, MAX_TEST_EXECUTION_TIME, DEFAULT_MODEL_OPTIONS } from "./init.js";
 import { compare, collect_and_execute_tests } from "./test_utils.js";
@@ -12,44 +12,53 @@ init();
 describe("Loading different architecture types", () => {
   // List all models which will be tested
   const models_to_test = [
-    // [name, modelClass, tokenizerClass]
-    ["hf-internal-testing/tiny-random-BertForMaskedLM", BertModel, BertTokenizer], // Encoder-only
-    ["hf-internal-testing/tiny-random-GPT2LMHeadModel", GPT2Model, GPT2Tokenizer], // Decoder-only
-    ["hf-internal-testing/tiny-random-T5ForConditionalGeneration", T5ForConditionalGeneration, T5Tokenizer], // Encoder-decoder
+    // [name, [AutoModelClass, ModelClass], [AutoProcessorClass, ProcessorClass], [modelOptions?], [modality?]]
+    ["hf-internal-testing/tiny-random-BertForMaskedLM", [AutoModelForMaskedLM, BertForMaskedLM], [AutoTokenizer, BertTokenizer]], // Encoder-only
+    ["hf-internal-testing/tiny-random-GPT2LMHeadModel", [AutoModelForCausalLM, GPT2LMHeadModel], [AutoTokenizer, GPT2Tokenizer]], // Decoder-only
+    ["hf-internal-testing/tiny-random-T5ForConditionalGeneration", [AutoModelForSeq2SeqLM, T5ForConditionalGeneration], [AutoTokenizer, T5Tokenizer]], // Encoder-decoder
+    ["onnx-internal-testing/tiny-random-LlamaForCausalLM-ONNX_external", [AutoModelForCausalLM, LlamaForCausalLM], [AutoTokenizer, LlamaTokenizer]], // Decoder-only w/ external data
+    ["onnx-internal-testing/tiny-random-WhisperForConditionalGeneration-ONNX_external", [AutoModelForSpeechSeq2Seq, WhisperForConditionalGeneration], [AutoProcessor, WhisperProcessor], {}], // Encoder-decoder-only w/ external data
   ];

   const texts = ["Once upon a time", "I like to eat apples"];

-  for (const [model_id, modelClass, tokenizerClass] of models_to_test) {
+  for (const [model_id, models, processors, modelOptions] of models_to_test) {
     // Test that both the auto model and the specific model work
-    const tokenizers = [AutoTokenizer, tokenizerClass];
-    const models = [AutoModel, modelClass];
-
-    for (let i = 0; i < tokenizers.length; ++i) {
-      const tokenizerClassToTest = tokenizers[i];
+    for (let i = 0; i < processors.length; ++i) {
+      const processorClassToTest = processors[i];
       const modelClassToTest = models[i];

       it(
         `${model_id} (${modelClassToTest.name})`,
         async () => {
-          // Load model and tokenizer
-          const tokenizer = await tokenizerClassToTest.from_pretrained(model_id);
-          const model = await modelClassToTest.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
+          // Load model and processor
+          const processor = await processorClassToTest.from_pretrained(model_id);
+          const model = await modelClassToTest.from_pretrained(model_id, modelOptions ?? DEFAULT_MODEL_OPTIONS);

           const tests = [
             texts[0], // single
             texts, // batched
           ];
+
+          const { model_type } = model.config;
+          const tokenizer = model_type === "whisper" ? processor.tokenizer : processor;
+          const feature_extractor = model_type === "whisper" ? processor.feature_extractor : null;
+
           for (const test of tests) {
             const inputs = await tokenizer(test, { truncation: true, padding: true });
             if (model.config.is_encoder_decoder) {
               inputs.decoder_input_ids = inputs.input_ids;
             }
+            if (feature_extractor) {
+              Object.assign(inputs, await feature_extractor(new Float32Array(16000)));
+            }
+
             const output = await model(inputs);

             if (output.logits) {
               // Ensure correct shapes
-              const expected_shape = [...inputs.input_ids.dims, model.config.vocab_size];
+              const input_ids = inputs.input_ids ?? inputs.decoder_input_ids;
+              const expected_shape = [...input_ids.dims, model.config.vocab_size];
               const actual_shape = output.logits.dims;
               compare(expected_shape, actual_shape);
             } else if (output.last_hidden_state) {
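Usage sketch (not part of the patch): the snippet below illustrates how the `use_external_data_format` option threaded through this diff would typically be used. The repository id is taken from the test suite above; the `@huggingface/transformers` import path and the `true` setting (single "model.onnx_data" chunk) are assumptions based on the typedef added in src/utils/hub.js, not something this diff prescribes.

// Load an ONNX model whose weights are stored in external data files (Node.js or browser).
// `use_external_data_format: true` loads "<model>.onnx_data"; a number N would load
// "<model>.onnx_data", "<model>.onnx_data_1", ..., "<model>.onnx_data_{N-1}".
import { AutoTokenizer, AutoModelForCausalLM } from "@huggingface/transformers";

const model_id = "onnx-internal-testing/tiny-random-LlamaForCausalLM-ONNX_external";
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
const model = await AutoModelForCausalLM.from_pretrained(model_id, {
    use_external_data_format: true,
});

const inputs = await tokenizer("Once upon a time");
const output = await model(inputs);
console.log(output.logits.dims); // [batch_size, sequence_length, vocab_size]

With the src/pipelines.js change above, the same `use_external_data_format` (and `subfolder`) options can also be passed directly to `pipeline()`, which forwards them in its pretrained options.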