Skip to content

Commit 6b51147

Browse files
ibelem and xenova authored
Support device-level configuration across all devices (#1276)
* [WebNN] Only allow free_dimension_override on a device level
* add device_config
* Use Omit to define device config to prevent duplication
* Update custom config instead of checking each property
* Cleanup
* Add back comment

---------

Co-authored-by: Joshua Lochner <[email protected]>
1 parent 28ca8a8 commit 6b51147

File tree

2 files changed

+27
-9
lines changed

2 files changed

+27
-9
lines changed

src/configs.js

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ export class AutoConfig {
404404
/**
405405
* Transformers.js-specific configuration, possibly present in config.json under the key `transformers.js_config`.
406406
* @typedef {Object} TransformersJSConfig
407+
* @property {Record<import('./utils/devices.js').DeviceType, DeviceConfig>} [device_config] Device-specific configurations.
407408
* @property {import('./utils/tensor.js').DataType|Record<import('./utils/dtypes.js').DataType, import('./utils/tensor.js').DataType>} [kv_cache_dtype] The data type of the key-value cache.
408409
* @property {Record<string, number>} [free_dimension_overrides] Override the free dimensions of the model.
409410
* See https://onnxruntime.ai/docs/tutorials/web/env-flags-and-session-options.html#freedimensionoverrides
@@ -412,3 +413,8 @@ export class AutoConfig {
412413
* @property {import('./utils/dtypes.js').DataType|Record<string, import('./utils/dtypes.js').DataType>} [dtype] The default data type to use for the model.
413414
* @property {import('./utils/hub.js').ExternalData|Record<string, import('./utils/hub.js').ExternalData>} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size).
414415
*/
416+
417+
/**
418+
* Device-specific configuration options.
419+
* @typedef {Omit<TransformersJSConfig, "device" | "device_config">} DeviceConfig
420+
*/

src/models.js

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,8 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
158158
* @private
159159
*/
160160
async function getSession(pretrained_model_name_or_path, fileName, options) {
161-
const custom_config = options.config?.['transformers.js_config'] ?? {};
161+
let custom_config = options.config?.['transformers.js_config'] ?? {};
162+
162163
let device = options.device ?? custom_config.device;
163164
if (device && typeof device !== 'string') {
164165
if (device.hasOwnProperty(fileName)) {
@@ -173,8 +174,18 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
173174
const selectedDevice = /** @type {import("./utils/devices.js").DeviceType} */(
174175
device ?? (apis.IS_NODE_ENV ? 'cpu' : 'wasm')
175176
);
177+
176178
const executionProviders = deviceToExecutionProviders(selectedDevice);
177179

180+
// Update custom config with the selected device's config, if it exists
181+
const device_config = custom_config.device_config ?? {};
182+
if (device_config.hasOwnProperty(selectedDevice)) {
183+
custom_config = {
184+
...custom_config,
185+
...device_config[selectedDevice],
186+
};
187+
}
188+
178189
// If options.dtype is specified, we use it to choose the suffix for the model file.
179190
// Otherwise, we use the default dtype for the device.
180191
let dtype = options.dtype ?? custom_config.dtype;
@@ -195,7 +206,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
195206
}
196207

197208
if (config_dtype && config_dtype !== DATA_TYPES.auto && DATA_TYPES.hasOwnProperty(config_dtype)) {
198-
// Defined by the custom config, and is not "auto"
209+
// Defined by the config, and is not "auto"
199210
dtype = config_dtype;
200211
} else {
201212
// Choose default dtype based on device, falling back to fp32
@@ -212,10 +223,11 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
212223
}
213224

214225
// Only valid for models with a decoder
215-
const kv_cache_dtype = custom_config.kv_cache_dtype
216-
? (typeof custom_config.kv_cache_dtype === 'string'
217-
? custom_config.kv_cache_dtype
218-
: custom_config.kv_cache_dtype[selectedDtype] ?? 'float32')
226+
const kv_cache_dtype_config = custom_config.kv_cache_dtype;
227+
const kv_cache_dtype = kv_cache_dtype_config
228+
? (typeof kv_cache_dtype_config === 'string'
229+
? kv_cache_dtype_config
230+
: kv_cache_dtype_config[selectedDtype] ?? 'float32')
219231
: undefined;
220232

221233
if (kv_cache_dtype && !['float32', 'float16'].includes(kv_cache_dtype)) {
@@ -243,15 +255,15 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
243255
session_options.freeDimensionOverrides ??= free_dimension_overrides;
244256
} else if (selectedDevice.startsWith('webnn') && !session_options.freeDimensionOverrides) {
245257
console.warn(
246-
'WebNN does not currently support dynamic shapes and requires `free_dimension_overrides` to be set in config.json as a field within "transformers.js_config". ' +
247-
'When `free_dimension_overrides` is not set, you may experience significant performance degradation.'
258+
`WebNN does not currently support dynamic shapes and requires 'free_dimension_overrides' to be set in config.json, preferably as a field within config["transformers.js_config"]["device_config"]["${selectedDevice}"]. ` +
259+
`When 'free_dimension_overrides' is not set, you may experience significant performance degradation.`
248260
);
249261
}
250262

251263
const return_path = apis.IS_NODE_ENV && env.useFSCache;
252264
const bufferOrPathPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options, return_path);
253265

254-
// handle onnx external data files
266+
// Handle onnx external data files
255267
const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format;
256268
/** @type {Promise<string|{path: string, data: Uint8Array}>[]} */
257269
let externalDataPromises = [];

0 commit comments

Comments
 (0)