Skip to content

Commit 4d76e66

Browse files
committed
Update custom config instead of checking each property
1 parent e58fafe commit 4d76e66

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

src/models.js

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,7 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
158158
* @private
159159
*/
160160
async function getSession(pretrained_model_name_or_path, fileName, options) {
161-
const custom_config = options.config?.['transformers.js_config'] ?? {};
162-
const device_config = custom_config.device_config ?? {};
161+
let custom_config = options.config?.['transformers.js_config'] ?? {};
163162

164163
let device = options.device ?? custom_config.device;
165164
if (device && typeof device !== 'string') {
@@ -176,14 +175,20 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
176175
device ?? (apis.IS_NODE_ENV ? 'cpu' : 'wasm')
177176
);
178177

179-
// Get device-specific config if available
180-
const deviceSpecificConfig = device_config[selectedDevice] ?? {};
181-
182178
const executionProviders = deviceToExecutionProviders(selectedDevice);
183179

180+
// Update custom config with the selected device's config, if it exists
181+
const device_config = custom_config.device_config ?? {};
182+
if (device_config.hasOwnProperty(selectedDevice)) {
183+
custom_config = {
184+
...custom_config,
185+
...device_config[selectedDevice],
186+
};
187+
}
188+
184189
// If options.dtype is specified, we use it to choose the suffix for the model file.
185-
// Otherwise, try device-specific config, then fall back to transformers.js_config
186-
let dtype = options.dtype ?? deviceSpecificConfig.dtype ?? custom_config.dtype;
190+
// Otherwise, we use the default dtype for the device.
191+
let dtype = options.dtype ?? custom_config.dtype;
187192
if (typeof dtype !== 'string') {
188193
if (dtype && dtype.hasOwnProperty(fileName)) {
189194
dtype = dtype[fileName];
@@ -195,7 +200,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
195200

196201
if (dtype === DATA_TYPES.auto) {
197202
// Try to choose the auto dtype based on the device-specific config first, then fall back to transformers.js_config
198-
let config_dtype = deviceSpecificConfig.dtype ?? custom_config.dtype;
203+
let config_dtype = custom_config.dtype;
199204
if (typeof config_dtype !== 'string') {
200205
config_dtype = config_dtype?.[fileName];
201206
}
@@ -217,10 +222,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
217222
throw new Error(`The device (${selectedDevice}) does not support fp16.`);
218223
}
219224

220-
// Check device-specific kv_cache_dtype first, then fall back to transformers.js_config
221-
const kv_cache_dtype_config = deviceSpecificConfig.kv_cache_dtype ?? custom_config.kv_cache_dtype;
222-
223225
// Only valid for models with a decoder
226+
const kv_cache_dtype_config = custom_config.kv_cache_dtype;
224227
const kv_cache_dtype = kv_cache_dtype_config
225228
? (typeof kv_cache_dtype_config === 'string'
226229
? kv_cache_dtype_config
@@ -246,24 +249,21 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
246249
// Overwrite `executionProviders` if not specified
247250
session_options.executionProviders ??= executionProviders;
248251

249-
// First check device-specific free_dimension_overrides, then fall back to transformers.js_config
250-
const free_dimension_overrides = deviceSpecificConfig.free_dimension_overrides ?? custom_config.free_dimension_overrides;
252+
const free_dimension_overrides = custom_config.free_dimension_overrides;
251253

252254
if (free_dimension_overrides) {
253255
session_options.freeDimensionOverrides ??= free_dimension_overrides;
254256
} else if (selectedDevice.startsWith('webnn') && !session_options.freeDimensionOverrides) {
255257
console.warn(
256-
'WebNN does not currently support dynamic shapes and requires `free_dimension_overrides` to be set in config.json as a field within "device_config[selectedDevice]" or "transformers.js_config"' +
257-
'When `free_dimension_overrides` is not set, you may experience significant performance degradation.'
258+
`WebNN does not currently support dynamic shapes and requires 'free_dimension_overrides' to be set in config.json, preferably as a field within config["transformers.js_config"]["device_config"]["${selectedDevice}"]. ` +
259+
`When 'free_dimension_overrides' is not set, you may experience significant performance degradation.`
258260
);
259261
}
260262

261263
const bufferOrPathPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options, apis.IS_NODE_ENV);
262264

263-
// handle onnx external data files - check device-specific config first, then fall back to transformers.js_config
264-
const use_external_data_format = options.use_external_data_format ??
265-
deviceSpecificConfig.use_external_data_format ??
266-
custom_config.use_external_data_format;
265+
// Handle onnx external data files
266+
const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format;
267267

268268
/** @type {Promise<string|{path: string, data: Uint8Array}>[]} */
269269
let externalDataPromises = [];

0 commit comments

Comments
 (0)