Skip to content

Commit bd2449e

Browse files
Honry and xenova authored
Optimize tensor.slice() (#1381)
* Optimize tensor.slice() The performance of executing `tensor.slice()` is very poor, especially for the 'logits' tensor with large dimensions. ``` const logits = outputs.logits.slice(null, -1, null); ``` This is because the current implementation of the `slice` method manually iterates through each element and calculates its index, which is very time-consuming when the tensor shape is large. Cases like `slice(null, -1, null)`, where the slicing operation is contiguous along certain dimensions, can be optimized with a bulk copy using `TypedArray.subarray()` and `TypedArray.set()`. * nit * Add a few more tensor slice unit tests --------- Co-authored-by: Joshua Lochner <[email protected]>
1 parent 82206db commit bd2449e

File tree

2 files changed

+97
-10
lines changed

2 files changed

+97
-10
lines changed

src/utils/tensor.js

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -443,15 +443,46 @@ export class Tensor {
443443
// Precompute strides
444444
const stride = this.stride();
445445

446-
for (let i = 0; i < newBufferSize; ++i) {
447-
let originalIndex = 0;
448-
for (let j = newDims.length - 1, num = i; j >= 0; --j) {
449-
const size = newDims[j];
450-
originalIndex += ((num % size) + newOffsets[j][0]) * stride[j];
451-
num = Math.floor(num / size);
446+
// Detect if the slice is contiguous
447+
let isContiguous = true;
448+
for (let i = 1; i < newDims.length; ++i) {
449+
if (newOffsets[i][0] !== 0 || newOffsets[i][1] !== this.dims[i]) {
450+
isContiguous = false;
451+
break;
452452
}
453-
data[i] = this_data[originalIndex];
454453
}
454+
455+
if (isContiguous) {
456+
// Perform bulk copy for contiguous slices to improve performance
457+
const start = newOffsets[0][0] * stride[0];
458+
const end = newOffsets[0][1] * stride[0];
459+
460+
if (ArrayBuffer.isView(this_data)) {
461+
// If this.data is a TypedArray, use subarray
462+
// @ts-ignore
463+
data.set(this_data.subarray(start, end));
464+
} else if (Array.isArray(this_data)) {
465+
// If this.data is a plain array, use slice
466+
const slicedData = this_data.slice(start, end);
467+
for (let i = 0; i < slicedData.length; ++i) {
468+
data[i] = slicedData[i];
469+
}
470+
} else {
471+
throw new Error("Unsupported data type for slicing");
472+
}
473+
} else {
474+
// Fallback to manual copying for non-contiguous slices
475+
for (let i = 0; i < newBufferSize; ++i) {
476+
let originalIndex = 0;
477+
for (let j = newDims.length - 1, num = i; j >= 0; --j) {
478+
const size = newDims[j];
479+
originalIndex += ((num % size) + newOffsets[j][0]) * stride[j];
480+
num = Math.floor(num / size);
481+
}
482+
data[i] = this_data[originalIndex];
483+
}
484+
}
485+
455486
return new Tensor(this.type, data, newTensorDims);
456487
}
457488

tests/utils/tensor.test.js

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,13 @@ describe("Tensor operations", () => {
5959
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
6060
const t2 = t1.slice(1);
6161
const target = new Tensor("float32", [3, 4], [2]);
62-
6362
compare(t2, target);
6463
});
6564

6665
it("should return a range of rows", () => {
6766
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
6867
const t2 = t1.slice([1, 3]);
6968
const target = new Tensor("float32", [3, 4, 5, 6], [2, 2]);
70-
7169
compare(t2, target);
7270
});
7371

@@ -78,9 +76,67 @@ describe("Tensor operations", () => {
7876
[4, 7],
7977
);
8078
const t2 = t1.slice([1, -1], [1, -1]);
81-
8279
const target = new Tensor("float32", [9, 10, 11, 12, 13, 16, 17, 18, 19, 20], [2, 5]);
80+
compare(t2, target);
81+
});
82+
83+
it("should return the whole tensor when all indices are null/unset", () => {
84+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
85+
const t2 = t1.slice();
86+
compare(t2, t1);
87+
});
88+
89+
it("should return the whole dimension when index is null", () => {
90+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
91+
const t2 = t1.slice(null);
92+
compare(t2, t1);
93+
});
94+
95+
it("should slice from index to end when [start, null] is used", () => {
96+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
97+
const t2 = t1.slice([1, null]);
98+
const target = new Tensor("float32", [3, 4, 5, 6], [2, 2]);
99+
compare(t2, target);
100+
});
101+
102+
it("should slice from beginning to index when [null, end] is used", () => {
103+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
104+
const t2 = t1.slice([null, 2]);
105+
const target = new Tensor("float32", [1, 2, 3, 4], [2, 2]);
106+
compare(t2, target);
107+
});
108+
109+
it("should handle [null, null] as full slice", () => {
110+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
111+
const t2 = t1.slice([null, null]);
112+
compare(t2, t1);
113+
});
114+
115+
it("should select a single element when a number is used in slice", () => {
116+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
117+
const t2 = t1.slice(2, 1);
118+
const target = new Tensor("float32", [6], []);
119+
compare(t2, target);
120+
});
83121

122+
it("should select a single row when a number is used in slice", () => {
123+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
124+
const t2 = t1.slice(0);
125+
const target = new Tensor("float32", [1, 2], [2]);
126+
compare(t2, target);
127+
});
128+
129+
it("should select a single column when a number is used in slice", () => {
130+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
131+
const t2 = t1.slice(null, 1);
132+
const target = new Tensor("float32", [2, 4, 6], [3]);
133+
compare(t2, target);
134+
});
135+
136+
it("should handle negative indices in slice", () => {
137+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
138+
const t2 = t1.slice(-1);
139+
const target = new Tensor("float32", [5, 6], [2]);
84140
compare(t2, target);
85141
});
86142
});

0 commit comments

Comments
 (0)