diff --git a/src/backends/onnx.js b/src/backends/onnx.js
index a64f9d160..b7d4ec7bc 100644
--- a/src/backends/onnx.js
+++ b/src/backends/onnx.js
@@ -56,12 +56,15 @@ let defaultDevices;
 let ONNX;
 const ORT_SYMBOL = Symbol.for('onnxruntime');
 
+/** @type {"custom"|"node"|"web"} */
+let ort;
 if (ORT_SYMBOL in globalThis) {
     // If the JS runtime exposes their own ONNX runtime, use it
     ONNX = globalThis[ORT_SYMBOL];
+    ort = 'custom';
 
-} else if (apis.IS_NODE_ENV) {
-    ONNX = ONNX_NODE.default ?? ONNX_NODE;
+} else if (apis.IS_NODE_ENV && (ONNX = ONNX_NODE.default ?? ONNX_NODE)?.InferenceSession) {
+    ort = 'node';
 
     // Updated as of ONNX Runtime 1.20.1
     // The following table lists the supported versions of ONNX Runtime Node.js binding provided with pre-built binaries.
@@ -87,6 +90,7 @@ if (ORT_SYMBOL in globalThis) {
     defaultDevices = ['cpu'];
 } else {
     ONNX = ONNX_WEB;
+    ort = 'web';
 
     if (apis.IS_WEBNN_AVAILABLE) {
         // TODO: Only push supported providers (depending on available hardware)
@@ -169,6 +173,14 @@ export function isONNXTensor(x) {
     return x instanceof ONNX.Tensor;
 }
 
+/**
+ * The type of ONNX runtime being used.
+ * - 'node' for `onnxruntime-node`
+ * - 'web' for `onnxruntime-web`
+ * - 'custom' for a custom ONNX runtime
+ */
+export const runtime = ort;
+
 /** @type {import('onnxruntime-common').Env} */
 // @ts-ignore
 const ONNX_ENV = ONNX?.env;
diff --git a/src/env.js b/src/env.js
index fd3924213..39693edd5 100644
--- a/src/env.js
+++ b/src/env.js
@@ -142,7 +142,7 @@ export const env = {
     remoteHost: 'https://huggingface.co/',
     remotePathTemplate: '{model}/resolve/{revision}/',
 
-    allowLocalModels: !(IS_BROWSER_ENV || IS_WEBWORKER_ENV),
+    allowLocalModels: !(IS_BROWSER_ENV || IS_WEBWORKER_ENV || IS_DENO_RUNTIME),
     localModelPath: localModelPath,
     useFS: IS_FS_AVAILABLE,
 
diff --git a/src/models.js b/src/models.js
index f05a58665..cff3ee8b6 100644
--- a/src/models.js
+++ b/src/models.js
@@ -48,6 +48,7 @@ import {
     createInferenceSession,
     isONNXTensor,
     isONNXProxy,
+    runtime,
 } from './backends/onnx.js';
 import {
     DATA_TYPES,
@@ -172,7 +173,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
 
     // If the device is not specified, we use the default (supported) execution providers.
     const selectedDevice = /** @type {import("./utils/devices.js").DeviceType} */(
-        device ?? (apis.IS_NODE_ENV ? 'cpu' : 'wasm')
+        device ?? (runtime === "web" ? 'wasm' : 'cpu')
    );
     const executionProviders = deviceToExecutionProviders(selectedDevice);
 
diff --git a/webpack.config.js b/webpack.config.js
index 53b5ccd9d..52d2c8c1d 100644
--- a/webpack.config.js
+++ b/webpack.config.js
@@ -54,7 +54,21 @@ class PostBuildPlugin {
         {
             const src = path.join(__dirname, 'node_modules/onnxruntime-web/dist', ORT_JSEP_FILE);
             const dest = path.join(dist, ORT_JSEP_FILE);
-            fs.copyFileSync(src, dest);
+
+            // Transformers.js uses both onnxruntime-web and onnxruntime-node in the same package,
+            // and the runtime we use depends on the environment (onnxruntime-web for web, onnxruntime-node for Node.js).
+            // This means that we don't currently support using the WASM backend in Node.js, so we disable this behaviour in the JSEP file.
+            const content = fs.readFileSync(src, 'utf8');
+            const updatedContent = content
+                .replace(
+                    `"object"==typeof process&&"object"==typeof process.versions&&"string"==typeof process.versions.node&&"renderer"!=process.type`,
+                    "false",
+                )
+                .replace(
+                    `typeof globalThis.process?.versions?.node == 'string'`,
+                    "false",
+                );
+            fs.writeFileSync(dest, updatedContent);
         }
     });
 }
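
Note for reviewers: the new `ORT_SYMBOL` branch means a host can supply its own ONNX runtime before the library loads, and the new `runtime` export reports which branch won. A minimal sketch of that flow, assuming onnxruntime-web as a stand-in custom runtime and a deep import path; the diff only shows `runtime` exported from src/backends/onnx.js, not re-exported from the package root:

    // sketch-custom-runtime.mjs: illustrative only.
    import * as customOrt from 'onnxruntime-web';

    // Register a runtime under the well-known symbol *before* Transformers.js
    // is imported, so the `ORT_SYMBOL in globalThis` check sees it and the
    // detection code sets ort = 'custom'.
    globalThis[Symbol.for('onnxruntime')] = customOrt;

    // Dynamic import guarantees this module evaluates after the assignment above.
    const { runtime } = await import('@huggingface/transformers/src/backends/onnx.js');
    console.log(runtime); // expected: 'custom'

The revised Node branch also guards on `(ONNX = ONNX_NODE.default ?? ONNX_NODE)?.InferenceSession`, so a broken or stubbed onnxruntime-node binding now falls through to onnxruntime-web instead of being selected with an unusable module.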
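
The `getSession` change makes the default device follow the selected runtime rather than `apis.IS_NODE_ENV`, so a custom runtime now defaults to 'cpu' even outside Node.js. A standalone sketch of the new decision; the function name is local to this example:

    /**
     * Mirrors the updated default in getSession(): only the 'web' runtime
     * falls back to 'wasm'; 'node' and 'custom' fall back to 'cpu'.
     * @param {"custom"|"node"|"web"} runtime
     * @param {string} [device] explicit user choice, which always wins
     */
    function defaultDevice(runtime, device) {
        return device ?? (runtime === 'web' ? 'wasm' : 'cpu');
    }

    console.log(defaultDevice('web'));           // 'wasm'
    console.log(defaultDevice('node'));          // 'cpu'
    console.log(defaultDevice('custom'));        // 'cpu' (was 'wasm' outside Node before this change)
    console.log(defaultDevice('web', 'webgpu')); // 'webgpu'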
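
One note on the PostBuildPlugin change: `String.prototype.replace` with a string pattern is a silent no-op when the pattern is absent, so if onnxruntime-web's minified output drifts between versions, Node detection would quietly come back in the shipped JSEP file. A hedged sketch of a guard (not part of this diff) that would fail the build loudly instead:

    // Hypothetical hardening: verify each pattern exists before patching,
    // so an onnxruntime-web upgrade that changes the minified text breaks
    // the build instead of silently skipping the patch.
    function replaceOrThrow(content, pattern, replacement) {
        if (!content.includes(pattern)) {
            throw new Error(`JSEP patch: pattern not found: ${pattern.slice(0, 60)}...`);
        }
        return content.replace(pattern, replacement);
    }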