Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/backends/onnx.js
Original file line number Diff line number Diff line change
Expand Up @@ -141,19 +141,19 @@ let wasmInitPromise = null;

/**
* Create an ONNX inference session.
* @param {Uint8Array} buffer The ONNX model buffer.
* @param {Uint8Array|string} buffer_or_path The ONNX model buffer or path.
* @param {import('onnxruntime-common').InferenceSession.SessionOptions} session_options ONNX inference session options.
* @param {Object} session_config ONNX inference session configuration.
* @returns {Promise<import('onnxruntime-common').InferenceSession & { config: Object}>} The ONNX inference session.
*/
export async function createInferenceSession(buffer, session_options, session_config) {
export async function createInferenceSession(buffer_or_path, session_options, session_config) {
if (wasmInitPromise) {
// A previous session has already initialized the WASM runtime
// so we wait for it to resolve before creating this new session.
await wasmInitPromise;
}

const sessionPromise = InferenceSession.create(buffer, session_options);
const sessionPromise = InferenceSession.create(buffer_or_path, session_options);
wasmInitPromise ??= sessionPromise;
const session = await sessionPromise;
session.config = session_config;
Expand Down
2 changes: 1 addition & 1 deletion src/configs.js
Original file line number Diff line number Diff line change
Expand Up @@ -407,5 +407,5 @@ export class AutoConfig {
* for more information.
* @property {import('./utils/devices.js').DeviceType} [device] The default device to use for the model.
* @property {import('./utils/dtypes.js').DataType|Record<string, import('./utils/dtypes.js').DataType>} [dtype] The default data type to use for the model.
* @property {boolean|Record<string, boolean>} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size).
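* @property {import('./utils/hub.js').ExternalData|Record<string, import('./utils/hub.js').ExternalData>} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size). A numeric value is interpreted as the number of external data chunks to load (`false` = 0, `true` = 1). When an object is given, its keys are model file names and each value applies to that file only.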
*/
64 changes: 39 additions & 25 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ import {
import {
getModelFile,
getModelJSON,
MAX_EXTERNAL_DATA_CHUNKS,
} from './utils/hub.js';

import {
Expand Down Expand Up @@ -152,7 +153,7 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
* @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
* @param {string} fileName The name of the model file.
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
* @returns {Promise<{buffer: Uint8Array, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
* @returns {Promise<{buffer_or_path: Uint8Array|string, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
* @private
*/
async function getSession(pretrained_model_name_or_path, fileName, options) {
Expand Down Expand Up @@ -227,7 +228,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {

// Construct the model file name
const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
const modelFileName = `${options.subfolder ?? ''}/${fileName}${suffix}.onnx`;
const baseName = `${fileName}${suffix}.onnx`;
const modelFileName = `${options.subfolder ?? ''}/${baseName}`;

const session_options = { ...options.session_options };

Expand All @@ -245,29 +247,38 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
);
}

const bufferPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options);
const bufferOrPathPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options, apis.IS_NODE_ENV);

// handle onnx external data files
const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format;
/** @type {Promise<{path: string, data: Uint8Array}>[]} */
/** @type {Promise<string|{path: string, data: Uint8Array}>[]} */
let externalDataPromises = [];
if (use_external_data_format && (
use_external_data_format === true ||
(
typeof use_external_data_format === 'object' &&
use_external_data_format.hasOwnProperty(fileName) &&
use_external_data_format[fileName] === true
)
)) {
if (apis.IS_NODE_ENV) {
throw new Error('External data format is not yet supported in Node.js');
if (use_external_data_format) {
let external_data_format;
if (typeof use_external_data_format === 'object') {
if (use_external_data_format.hasOwnProperty(baseName)) {
external_data_format = use_external_data_format[baseName];
} else if (use_external_data_format.hasOwnProperty(fileName)) {
external_data_format = use_external_data_format[fileName];
} else {
external_data_format = false;
}
} else {
external_data_format = use_external_data_format;
}

const num_chunks = +external_data_format; // (false=0, true=1, number remains the same)
if (num_chunks > MAX_EXTERNAL_DATA_CHUNKS) {
throw new Error(`The number of external data chunks (${num_chunks}) exceeds the maximum allowed value (${MAX_EXTERNAL_DATA_CHUNKS}).`);
}
for (let i = 0; i < num_chunks; ++i) {
    // Chunk 0 is stored as `<model>.onnx_data`; every later chunk i is
    // stored as `<model>.onnx_data_<i>`.
    const path = `${baseName}_data${i === 0 ? '' : '_' + i}`;
    const fullPath = `${options.subfolder ?? ''}/${path}`;

    // Use an async IIFE rather than `new Promise(async (resolve, reject) => ...)`.
    // With an async executor, a failure inside `getModelFile` rejects only the
    // executor's own (discarded) promise, so the pushed promise would never
    // settle and the later `Promise.all` over `externalDataPromises` would hang.
    // Here any failure rejects the pushed promise directly.
    externalDataPromises.push((async () => {
        const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options, apis.IS_NODE_ENV);
        // A Uint8Array result is paired with its relative path; any other
        // result is represented by the relative path alone.
        return data instanceof Uint8Array ? { path, data } : path;
    })());
}
const path = `${fileName}${suffix}.onnx_data`;
const fullPath = `${options.subfolder ?? ''}/${path}`;
externalDataPromises.push(new Promise(async (resolve, reject) => {
const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options);
resolve({ path, data })
}));

} else if (session_options.externalData !== undefined) {
externalDataPromises = session_options.externalData.map(async (ext) => {
Expand All @@ -284,7 +295,10 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
}

if (externalDataPromises.length > 0) {
session_options.externalData = await Promise.all(externalDataPromises);
const externalData = await Promise.all(externalDataPromises);
if (!apis.IS_NODE_ENV) {
session_options.externalData = externalData;
}
}

if (selectedDevice === 'webgpu') {
Expand All @@ -302,9 +316,9 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
}
}

const buffer = await bufferPromise;
const buffer_or_path = await bufferOrPathPromise;

return { buffer, session_options, session_config };
return { buffer_or_path, session_options, session_config };
}

/**
Expand All @@ -319,8 +333,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
async function constructSessions(pretrained_model_name_or_path, names, options) {
return Object.fromEntries(await Promise.all(
Object.keys(names).map(async (name) => {
const { buffer, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
const session = await createInferenceSession(buffer, session_options, session_config);
const { buffer_or_path, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
const session = await createInferenceSession(buffer_or_path, session_options, session_config);
return [name, session];
})
));
Expand Down
4 changes: 4 additions & 0 deletions src/pipelines.js
Original file line number Diff line number Diff line change
Expand Up @@ -3299,6 +3299,8 @@ export async function pipeline(
revision = 'main',
device = null,
dtype = null,
subfolder = 'onnx',
use_external_data_format = null,
model_file_name = null,
session_options = {},
} = {}
Expand Down Expand Up @@ -3329,6 +3331,8 @@ export async function pipeline(
revision,
device,
dtype,
subfolder,
use_external_data_format,
model_file_name,
session_options,
}
Expand Down
Loading