Skip to content

Commit c2f2bd4

Browse files
committed
Support setting external data format in config.json
1 parent 1816e67 commit c2f2bd4

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/models.js

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
228228

229229
// Construct the model file name
230230
const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
231-
const modelFileName = `${options.subfolder ?? ''}/${fileName}${suffix}.onnx`;
231+
const baseName = `${fileName}${suffix}.onnx`;
232+
const modelFileName = `${options.subfolder ?? ''}/${baseName}`;
232233

233234
const session_options = { ...options.session_options };
234235

@@ -255,7 +256,9 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
255256
if (use_external_data_format) {
256257
let external_data_format;
257258
if (typeof use_external_data_format === 'object') {
258-
if (use_external_data_format.hasOwnProperty(fileName)) {
259+
if (use_external_data_format.hasOwnProperty(baseName)) {
260+
external_data_format = use_external_data_format[baseName];
261+
} else if (use_external_data_format.hasOwnProperty(fileName)) {
259262
external_data_format = use_external_data_format[fileName];
260263
} else {
261264
external_data_format = false;
@@ -269,7 +272,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
269272
throw new Error(`The number of external data chunks (${num_chunks}) exceeds the maximum allowed value (${MAX_EXTERNAL_DATA_CHUNKS}).`);
270273
}
271274
for (let i = 0; i < num_chunks; ++i) {
272-
const path = `${fileName}${suffix}.onnx_data${i === 0 ? '' : '_' + i}`;
275+
const path = `${baseName}_data${i === 0 ? '' : '_' + i}`;
273276
const fullPath = `${options.subfolder ?? ''}/${path}`;
274277
externalDataPromises.push(new Promise(async (resolve, reject) => {
275278
const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options, apis.IS_NODE_ENV);
@@ -313,7 +316,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
313316
}
314317
}
315318

316-
let buffer_or_path = await bufferOrPathPromise;
319+
const buffer_or_path = await bufferOrPathPromise;
317320

318321
return { buffer_or_path, session_options, session_config };
319322
}

0 commit comments

Comments
 (0)