Skip to content

Commit fbc745b

Browse files
authored
Use Float16Array instead of Uint16Array for kvcache when available (#1208)
1 parent 8bef102 commit fbc745b

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

src/models.js

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -108,6 +108,7 @@ import {
108108
stack,
109109
std_mean,
110110
Tensor,
111+
DataTypeMap,
111112
} from './utils/tensor.js';
112113
import { RawImage } from './utils/image.js';
113114

@@ -1847,7 +1848,7 @@ export class PreTrainedModel extends Callable {
18471848
} else {
18481849
const session = this.sessions['decoder_model_merged'] ?? this.sessions['model'];
18491850
const dtype = session?.config?.kv_cache_dtype ?? 'float32';
1850-
const empty = (dtype === 'float16') ? new Uint16Array() : [];
1851+
const empty = (dtype === 'float16') ? new DataTypeMap.float16() : [];
18511852

18521853
const batch_size = (decoderFeeds[this.main_input_name] ?? decoderFeeds.attention_mask)?.dims?.[0] ?? 1;
18531854
const shapes = getKeyValueShapes(this.config, { batch_size });

src/utils/tensor.js

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,11 @@ import {
2020

2121
import { TensorOpRegistry } from '../ops/registry.js';
2222

23-
const DataTypeMap = Object.freeze({
23+
export const DataTypeMap = Object.freeze({
2424
float32: Float32Array,
25-
float16: Uint16Array,
25+
// @ts-expect-error ts(2552) Limited availability of Float16Array across browsers:
26+
// https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Float16Array
27+
float16: typeof Float16Array !== "undefined" ? Float16Array: Uint16Array,
2628
float64: Float64Array,
2729
string: Array, // string[]
2830
int8: Int8Array,

tests/init.js

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ export function init() {
2222
let registerBackend = ONNX_COMMON.registerBackend;
2323

2424
// Define the constructors to monkey-patch
25-
const TYPED_ARRAYS_CONSTRUCTOR_NAMES = ["Int8Array", "Int16Array", "Int32Array", "BigInt64Array", "Uint8Array", "Uint8ClampedArray", "Uint16Array", "Uint32Array", "BigUint64Array", "Float32Array", "Float64Array"];
25+
const TYPED_ARRAYS_CONSTRUCTOR_NAMES = ["Int8Array", "Int16Array", "Int32Array", "BigInt64Array", "Uint8Array", "Uint8ClampedArray", "Uint16Array", "Uint32Array", "BigUint64Array", "Float16Array", "Float32Array", "Float64Array"];
2626

2727
// Keep a reference to the original initialization method
2828
const originalMethod = onnxruntimeBackend.init;
@@ -36,6 +36,7 @@ export function init() {
3636
for (const ctorName of TYPED_ARRAYS_CONSTRUCTOR_NAMES) {
3737
// Get the constructor from the current context
3838
const ctor = globalThis[ctorName];
39+
if (ctor === undefined) continue; // If unavailable, skip the patching
3940

4041
// Get the corresponding test function from the `util` module
4142
const value = types[`is${ctorName}`].bind(types);

0 commit comments

Comments
 (0)