Skip to content

Commit 9e74fc3

Browse files
Remove vslice as it doesn't work as it should.
Update documentation. Update tests.
1 parent be6a733 commit 9e74fc3

File tree

2 files changed

+26
-78
lines changed

2 files changed

+26
-78
lines changed

src/utils/tensor.js

Lines changed: 22 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -348,60 +348,31 @@ export class Tensor {
348348
return new Tensor(this.type, this.data.slice(), this.dims.slice());
349349
}
350350

351-
/**
352-
* Performs a vertical slice operation on a 2D Tensor.
353-
* @param {number|number[]|number[][]} slices - The slice specification:
354-
* - If a number is given, then a single column is selected.
355-
* - If an array of two numbers is given, then a range of columns [start, end (exclusive)] is selected.
356-
* - If an array of arrays is given, then those specific columns are selected.
357-
* @returns {Tensor} A new Tensor containing the selected columns.
358-
* @throws {Error} If the slice input is invalid.
359-
*/
360-
vslice(slices) {
361-
const rowDim = this.dims[0];
362-
const colDim = this.dims[1];
363-
364-
// Handle different slice cases (single column, range, or list of specific columns)
365-
let selectedCols = [];
366-
if (typeof slices === 'number') {
367-
// Single column slice
368-
selectedCols = [slices];
369-
} else if (Array.isArray(slices) && slices.length === 2 && !Array.isArray(slices[0]) && !Array.isArray(slices[1])) {
370-
// Range slice [start, end]
371-
const [start, end] = slices;
372-
selectedCols = Array.from({ length: end - start }, (_, i) => i + start);
373-
} else if (Array.isArray(slices) && Array.isArray(slices[0])) {
374-
// Specific column list [[col1], [col2]]
375-
selectedCols = slices.flat();
376-
} else {
377-
throw new Error('Invalid slice input');
378-
}
379-
380-
// Determine new dimensions: rows remain the same, columns are based on selection
381-
const newTensorDims = [rowDim, selectedCols.length];
382-
const newBufferSize = newTensorDims[0] * newTensorDims[1];
383-
// Allocate memory
384-
// @ts-ignore
385-
const data = new this.data.constructor(newBufferSize);
386-
387-
// Fill the new data array by selecting the correct columns
388-
for (let row = 0; row < rowDim; ++row) {
389-
for (let i = 0; i < selectedCols.length; ++i) {
390-
const col = selectedCols[i];
391-
const targetIndex = row * newTensorDims[1] + i;
392-
const originalIndex = row * colDim + col;
393-
data[targetIndex] = this.data[originalIndex];
394-
}
395-
}
396-
397-
return new Tensor(this.type, data, newTensorDims);
398-
}
399-
400351
/**
401352
* Performs a slice operation on the Tensor along specified dimensions.
353+
*
354+
* Consider a Tensor that has a dimension of [4, 7]:
355+
* ```
356+
* [ 1, 2, 3, 4, 5, 6, 7]
357+
* [ 8, 9, 10, 11, 12, 13, 14]
358+
* [15, 16, 17, 18, 19, 20, 21]
359+
* [22, 23, 24, 25, 26, 27, 28]
360+
* ```
361+
* We can slice against the two dims of row and column, for instance in this
362+
* case we can start at the second element, and return to the second last,
363+
* like this:
364+
* ```
365+
* tensor.slice([1, -1], [1, -1]);
366+
* ```
367+
* which would return:
368+
* ```
369+
* [ 9, 10, 11, 12, 13 ]
370+
* [ 16, 17, 18, 19, 20 ]
371+
* ```
372+
*
402373
* @param {...(number|number[]|null)} slices - The slice specifications for each dimension.
403-
* - If a number is given, then a single column is selected.
404-
* - If an array of two numbers is given, then a range of columns [start, end (exclusive)] is selected.
374+
* - If a number is given, then a single element is selected.
375+
* - If an array of two numbers is given, then a range of elements [start, end (exclusive)] is selected.
405376
* - If null is given, then the entire dimension is selected.
406377
* @returns {Tensor} A new Tensor containing the selected elements.
407378
* @throws {Error} If the slice input is invalid.

tests/utils/tensor.test.js

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -69,35 +69,12 @@ describe("Tensor operations", () => {
6969
compare(t2, target);
7070
});
7171

72-
it("should return a given column dim", async () => {
73-
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
74-
const t2 = t1.vslice(1);
75-
const target = new Tensor("float32", [2, 4, 6], [3, 1]);
76-
77-
compare(t2, target);
78-
});
79-
80-
it("should return a range of cols", async () => {
81-
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [3, 4]);
82-
// The end index is not included.
83-
const t2 = t1.vslice([1, 3]);
84-
const target = new Tensor("float32", [2, 3, 6, 7, 10, 11], [3, 2]);
85-
86-
compare(t2, target);
87-
});
88-
89-
it("should return a every third row", async () => {
72+
it("should return a crop", async () => {
9073
// Create 21 nodes.
91-
const t1 = new Tensor("float32", Array.from({ length: 21 }, (v, i) => v = ++i), [3, 7]);
92-
93-
// Extract every third column.
94-
const indices = Array.from({ length: t1.dims[1] }, (_, i) => i)
95-
.filter(i => i % 3 === 0)
96-
// Make sure to wrap each in an array since an array creates a new range.
97-
.map(v => [v]);
98-
const t2 = t1.vslice(indices);
74+
const t1 = new Tensor("float32", Array.from({ length: 28 }, (v, i) => v = ++i), [4, 7]);
75+
const t2 = t1.slice([1, -1], [1, -1]);
9976

100-
const target = new Tensor("float32", [1, 4, 7, 8, 11, 14, 15, 18, 21], [3, 3]);
77+
const target = new Tensor("float32", [9, 10, 11, 12, 13, 16, 17, 18, 19, 20], [2, 5]);
10178

10279
compare(t2, target);
10380
});

0 commit comments

Comments
 (0)