@@ -158,8 +158,7 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
158158 * @private
159159 */
160160async function getSession ( pretrained_model_name_or_path , fileName , options ) {
161- const custom_config = options . config ?. [ 'transformers.js_config' ] ?? { } ;
162- const device_config = custom_config . device_config ?? { } ;
161+ let custom_config = options . config ?. [ 'transformers.js_config' ] ?? { } ;
163162
164163 let device = options . device ?? custom_config . device ;
165164 if ( device && typeof device !== 'string' ) {
@@ -176,14 +175,20 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
176175 device ?? ( apis . IS_NODE_ENV ? 'cpu' : 'wasm' )
177176 ) ;
178177
179- // Get device-specific config if available
180- const deviceSpecificConfig = device_config [ selectedDevice ] ?? { } ;
181-
182178 const executionProviders = deviceToExecutionProviders ( selectedDevice ) ;
183179
180+ // Update custom config with the selected device's config, if it exists
181+ const device_config = custom_config . device_config ?? { } ;
182+ if ( device_config . hasOwnProperty ( selectedDevice ) ) {
183+ custom_config = {
184+ ...custom_config ,
185+ ...device_config [ selectedDevice ] ,
186+ } ;
187+ }
188+
184189 // If options.dtype is specified, we use it to choose the suffix for the model file.
185- // Otherwise, try device-specific config, then fall back to transformers.js_config
186- let dtype = options . dtype ?? deviceSpecificConfig . dtype ?? custom_config . dtype ;
190+ // Otherwise, we use the default dtype for the device.
191+ let dtype = options . dtype ?? custom_config . dtype ;
187192 if ( typeof dtype !== 'string' ) {
188193 if ( dtype && dtype . hasOwnProperty ( fileName ) ) {
189194 dtype = dtype [ fileName ] ;
@@ -195,7 +200,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
195200
196201 if ( dtype === DATA_TYPES . auto ) {
197202 // Try to choose the auto dtype based on the device-specific config first, then fall back to transformers.js_config
198- let config_dtype = deviceSpecificConfig . dtype ?? custom_config . dtype ;
203+ let config_dtype = custom_config . dtype ;
199204 if ( typeof config_dtype !== 'string' ) {
200205 config_dtype = config_dtype ?. [ fileName ] ;
201206 }
@@ -217,10 +222,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
217222 throw new Error ( `The device (${ selectedDevice } ) does not support fp16.` ) ;
218223 }
219224
220- // Check device-specific kv_cache_dtype first, then fall back to transformers.js_config
221- const kv_cache_dtype_config = deviceSpecificConfig . kv_cache_dtype ?? custom_config . kv_cache_dtype ;
222-
223225 // Only valid for models with a decoder
226+ const kv_cache_dtype_config = custom_config . kv_cache_dtype ;
224227 const kv_cache_dtype = kv_cache_dtype_config
225228 ? ( typeof kv_cache_dtype_config === 'string'
226229 ? kv_cache_dtype_config
@@ -246,24 +249,21 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
246249 // Overwrite `executionProviders` if not specified
247250 session_options . executionProviders ??= executionProviders ;
248251
249- // First check device-specific free_dimension_overrides, then fall back to transformers.js_config
250- const free_dimension_overrides = deviceSpecificConfig . free_dimension_overrides ?? custom_config . free_dimension_overrides ;
252+ const free_dimension_overrides = custom_config . free_dimension_overrides ;
251253
252254 if ( free_dimension_overrides ) {
253255 session_options . freeDimensionOverrides ??= free_dimension_overrides ;
254256 } else if ( selectedDevice . startsWith ( 'webnn' ) && ! session_options . freeDimensionOverrides ) {
255257 console . warn (
256- ' WebNN does not currently support dynamic shapes and requires ` free_dimension_overrides` to be set in config.json as a field within "device_config[selectedDevice]" or " transformers.js_config"' +
257- ' When ` free_dimension_overrides` is not set, you may experience significant performance degradation.'
258+ ` WebNN does not currently support dynamic shapes and requires ' free_dimension_overrides' to be set in config.json, preferably as a field within config[" transformers.js_config"]["device_config"][" ${ selectedDevice } "]. ` +
259+ ` When ' free_dimension_overrides' is not set, you may experience significant performance degradation.`
258260 ) ;
259261 }
260262
261263 const bufferOrPathPromise = getModelFile ( pretrained_model_name_or_path , modelFileName , true , options , apis . IS_NODE_ENV ) ;
262264
263- // handle onnx external data files - check device-specific config first, then fall back to transformers.js_config
264- const use_external_data_format = options . use_external_data_format ??
265- deviceSpecificConfig . use_external_data_format ??
266- custom_config . use_external_data_format ;
265+ // Handle onnx external data files
266+ const use_external_data_format = options . use_external_data_format ?? custom_config . use_external_data_format ;
267267
268268 /** @type {Promise<string|{path: string, data: Uint8Array}>[] } */
269269 let externalDataPromises = [ ] ;
0 commit comments