Skip to content

Commit cb3190b

Browse files
committed
Add a few more tensor slice unit tests
1 parent 0fcc97f commit cb3190b

File tree

1 file changed

+59
-3
lines changed

1 file changed

+59
-3
lines changed

tests/utils/tensor.test.js

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,13 @@ describe("Tensor operations", () => {
5959
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
6060
const t2 = t1.slice(1);
6161
const target = new Tensor("float32", [3, 4], [2]);
62-
6362
compare(t2, target);
6463
});
6564

6665
it("should return a range of rows", () => {
6766
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
6867
const t2 = t1.slice([1, 3]);
6968
const target = new Tensor("float32", [3, 4, 5, 6], [2, 2]);
70-
7169
compare(t2, target);
7270
});
7371

@@ -78,9 +76,67 @@ describe("Tensor operations", () => {
7876
[4, 7],
7977
);
8078
const t2 = t1.slice([1, -1], [1, -1]);
81-
8279
const target = new Tensor("float32", [9, 10, 11, 12, 13, 16, 17, 18, 19, 20], [2, 5]);
80+
compare(t2, target);
81+
});
82+
83+
it("should return the whole tensor when all indices are null/unset", () => {
84+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
85+
const t2 = t1.slice();
86+
compare(t2, t1);
87+
});
88+
89+
it("should return the whole dimension when index is null", () => {
90+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
91+
const t2 = t1.slice(null);
92+
compare(t2, t1);
93+
});
94+
95+
it("should slice from index to end when [start, null] is used", () => {
96+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
97+
const t2 = t1.slice([1, null]);
98+
const target = new Tensor("float32", [3, 4, 5, 6], [2, 2]);
99+
compare(t2, target);
100+
});
101+
102+
it("should slice from beginning to index when [null, end] is used", () => {
103+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
104+
const t2 = t1.slice([null, 2]);
105+
const target = new Tensor("float32", [1, 2, 3, 4], [2, 2]);
106+
compare(t2, target);
107+
});
108+
109+
it("should handle [null, null] as full slice", () => {
110+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
111+
const t2 = t1.slice([null, null]);
112+
compare(t2, t1);
113+
});
114+
115+
it("should select a single element when a number is used in slice", () => {
116+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
117+
const t2 = t1.slice(2, 1);
118+
const target = new Tensor("float32", [6], []);
119+
compare(t2, target);
120+
});
83121

122+
it("should select a single row when a number is used in slice", () => {
123+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
124+
const t2 = t1.slice(0);
125+
const target = new Tensor("float32", [1, 2], [2]);
126+
compare(t2, target);
127+
});
128+
129+
it("should select a single column when a number is used in slice", () => {
130+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
131+
const t2 = t1.slice(null, 1);
132+
const target = new Tensor("float32", [2, 4, 6], [3]);
133+
compare(t2, target);
134+
});
135+
136+
it("should handle negative indices in slice", () => {
137+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
138+
const t2 = t1.slice(-1);
139+
const target = new Tensor("float32", [5, 6], [2]);
84140
compare(t2, target);
85141
});
86142
});

0 commit comments

Comments
 (0)