Skip to content

Commit a96fa5a

Browse files
committed
Optimize tensor.slice()
The performance of executing `tensor.slice()` is super poor, especially for the 'logits' tensor with large dimensions. ``` const logits = outputs.logits.slice(null, -1, null);` ``` This is because currently implementation of the `slice` method manually iterates through each element and calculate indices which is a big time consuming if the tensor shape is large. For cases like `slice(null, -1, null)`, where the slicing operation is contiguous along certain dimensions, which can be optimized by bulk copy by using `TypeArray.subarray()` and `TypeArray.set()`.
1 parent 82206db commit a96fa5a

File tree

1 file changed

+38
-7
lines changed

1 file changed

+38
-7
lines changed

src/utils/tensor.js

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -443,15 +443,46 @@ export class Tensor {
443443
// Precompute strides
444444
const stride = this.stride();
445445

446-
for (let i = 0; i < newBufferSize; ++i) {
447-
let originalIndex = 0;
448-
for (let j = newDims.length - 1, num = i; j >= 0; --j) {
449-
const size = newDims[j];
450-
originalIndex += ((num % size) + newOffsets[j][0]) * stride[j];
451-
num = Math.floor(num / size);
446+
// Detect if the slice is contiguous
447+
let isContiguous = true;
448+
for (let i = 1; i < newDims.length; ++i) {
449+
if (newOffsets[i][0] !== 0 || newOffsets[i][1] !== this.dims[i]) {
450+
isContiguous = false;
451+
break;
452452
}
453-
data[i] = this_data[originalIndex];
454453
}
454+
455+
if (isContiguous) {
456+
// Perform bulk copy for contiguous slices to improve performance
457+
const start = newOffsets[0][0] * stride[0];
458+
const end = newOffsets[0][1] * stride[0];
459+
460+
if (ArrayBuffer.isView(this_data)) {
461+
// If this.data is a TypedArray, use subarray
462+
// @ts-ignore
463+
data.set(this_data.subarray(start, end));
464+
} else if (Array.isArray(this_data)) {
465+
// If this.data is a plain array, use slice
466+
const slicedData = this_data.slice(start, end);
467+
for (let i = 0; i < slicedData.length; i++) {
468+
data[i] = slicedData[i];
469+
}
470+
} else {
471+
throw new Error("Unsupported data type for slicing");
472+
}
473+
} else {
474+
// Fallback to manual copying for non-contiguous slices
475+
for (let i = 0; i < newBufferSize; ++i) {
476+
let originalIndex = 0;
477+
for (let j = newDims.length - 1, num = i; j >= 0; --j) {
478+
const size = newDims[j];
479+
originalIndex += ((num % size) + newOffsets[j][0]) * stride[j];
480+
num = Math.floor(num / size);
481+
}
482+
data[i] = this_data[originalIndex];
483+
}
484+
}
485+
455486
return new Tensor(this.type, data, newTensorDims);
456487
}
457488

0 commit comments

Comments
 (0)