diff --git a/src/configs.js b/src/configs.js index dccf6add5..c0ee9243d 100644 --- a/src/configs.js +++ b/src/configs.js @@ -404,6 +404,7 @@ export class AutoConfig { /** * Transformers.js-specific configuration, possibly present in config.json under the key `transformers.js_config`. * @typedef {Object} TransformersJSConfig + * @property {Record<import('./utils/devices.js').DeviceType, DeviceConfig>} [device_config] Device-specific configurations. * @property {import('./utils/tensor.js').DataType|Record<import('./utils/dtypes.js').DataType, import('./utils/tensor.js').DataType>} [kv_cache_dtype] The data type of the key-value cache. * @property {Record<string, number>} [free_dimension_overrides] Override the free dimensions of the model. * See https://onnxruntime.ai/docs/tutorials/web/env-flags-and-session-options.html#freedimensionoverrides @@ -412,3 +413,8 @@ export class AutoConfig { * @property {import('./utils/dtypes.js').DataType|Record<string, import('./utils/dtypes.js').DataType>} [dtype] The default data type to use for the model. * @property {import('./utils/hub.js').ExternalData|Record<string, import('./utils/hub.js').ExternalData>} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size). */ + +/** + * Device-specific configuration options. + * @typedef {Omit<TransformersJSConfig, 'device' | 'device_config'>} DeviceConfig + */ diff --git a/src/models.js b/src/models.js index 3c447d779..659c2f1e1 100644 --- a/src/models.js +++ b/src/models.js @@ -158,7 +158,8 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map(); * @private */ async function getSession(pretrained_model_name_or_path, fileName, options) { - const custom_config = options.config?.['transformers.js_config'] ?? {}; + let custom_config = options.config?.['transformers.js_config'] ?? {}; + let device = options.device ?? custom_config.device; if (device && typeof device !== 'string') { if (device.hasOwnProperty(fileName)) { @@ -173,8 +174,18 @@ async function getSession(pretrained_model_name_or_path, fileName, options) { const selectedDevice = /** @type {import("./utils/devices.js").DeviceType} */( device ?? (apis.IS_NODE_ENV ? 
'cpu' : 'wasm') ); + const executionProviders = deviceToExecutionProviders(selectedDevice); + // Update custom config with the selected device's config, if it exists + const device_config = custom_config.device_config ?? {}; + if (device_config.hasOwnProperty(selectedDevice)) { + custom_config = { + ...custom_config, + ...device_config[selectedDevice], + }; + } + // If options.dtype is specified, we use it to choose the suffix for the model file. // Otherwise, we use the default dtype for the device. let dtype = options.dtype ?? custom_config.dtype; @@ -191,11 +202,11 @@ async function getSession(pretrained_model_name_or_path, fileName, options) { // Try to choose the auto dtype based on the custom config let config_dtype = custom_config.dtype; if (typeof config_dtype !== 'string') { - config_dtype = config_dtype[fileName]; + config_dtype = config_dtype?.[fileName]; } if (config_dtype && config_dtype !== DATA_TYPES.auto && DATA_TYPES.hasOwnProperty(config_dtype)) { - // Defined by the custom config, and is not "auto" + // Defined by the config, and is not "auto" dtype = config_dtype; } else { // Choose default dtype based on device, falling back to fp32 @@ -212,10 +223,11 @@ async function getSession(pretrained_model_name_or_path, fileName, options) { } // Only valid for models with a decoder - const kv_cache_dtype = custom_config.kv_cache_dtype - ? (typeof custom_config.kv_cache_dtype === 'string' - ? custom_config.kv_cache_dtype - : custom_config.kv_cache_dtype[selectedDtype] ?? 'float32') + const kv_cache_dtype_config = custom_config.kv_cache_dtype; + const kv_cache_dtype = kv_cache_dtype_config + ? (typeof kv_cache_dtype_config === 'string' + ? kv_cache_dtype_config + : kv_cache_dtype_config[selectedDtype] ?? 
'float32') : undefined; if (kv_cache_dtype && !['float32', 'float16'].includes(kv_cache_dtype)) { @@ -243,14 +255,14 @@ async function getSession(pretrained_model_name_or_path, fileName, options) { session_options.freeDimensionOverrides ??= free_dimension_overrides; } else if (selectedDevice.startsWith('webnn') && !session_options.freeDimensionOverrides) { console.warn( - 'WebNN does not currently support dynamic shapes and requires `free_dimension_overrides` to be set in config.json as a field within "transformers.js_config". ' + - 'When `free_dimension_overrides` is not set, you may experience significant performance degradation.' + `WebNN does not currently support dynamic shapes and requires 'free_dimension_overrides' to be set in config.json, preferably as a field within config["transformers.js_config"]["device_config"]["${selectedDevice}"]. ` + + `When 'free_dimension_overrides' is not set, you may experience significant performance degradation.` ); } const bufferOrPathPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options, apis.IS_NODE_ENV); - // handle onnx external data files + // Handle onnx external data files const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format; /** @type {Promise[]} */ let externalDataPromises = [];