Skip to content

Commit 3aab729

Browse files
committed
auto dtype selection
1 parent f95475f commit 3aab729

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

src/models.js

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,17 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
182182
}
183183
}
184184

185+
if (dtype === 'auto') {
186+
const config_dtype = custom_config.dtype?.[fileName];
187+
if (config_dtype === 'auto') {
188+
// Choose default dtype based on device, falling back to fp32
189+
dtype = DEFAULT_DEVICE_DTYPE_MAPPING[selectedDevice] ?? DATA_TYPES.fp32;
190+
} else {
191+
// Defined by the custom config, and is not "auto"
192+
dtype = config_dtype;
193+
}
194+
}
195+
185196
const selectedDtype = /** @type {import("./utils/dtypes.js").DataType} */(dtype);
186197

187198
if (!DEFAULT_DTYPE_SUFFIX_MAPPING.hasOwnProperty(selectedDtype)) {

src/utils/dtypes.js

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ export const isWebGpuFp16Supported = (function () {
3131
})();
3232

3333
export const DATA_TYPES = Object.freeze({
34+
auto: 'auto', // Auto-detect based on environment
3435
fp32: 'fp32',
3536
fp16: 'fp16',
3637
q8: 'q8',
@@ -47,7 +48,7 @@ export const DEFAULT_DEVICE_DTYPE_MAPPING = Object.freeze({
4748
[DEVICE_TYPES.wasm]: DATA_TYPES.q8,
4849
});
4950

50-
/** @type {Record<DataType, string>} */
51+
/** @type {Record<Exclude<DataType, "auto">, string>} */
5152
export const DEFAULT_DTYPE_SUFFIX_MAPPING = Object.freeze({
5253
[DATA_TYPES.fp32]: '',
5354
[DATA_TYPES.fp16]: '_fp16',

0 commit comments

Comments
 (0)