Skip to content

Commit bbc46df

Browse files
committed
Cleaned up tensor classes
1 parent 160f15f commit bbc46df

File tree

2 files changed

+144
-148
lines changed

2 files changed

+144
-148
lines changed

extension/wasm/test/executorch_wasm.test.js

Lines changed: 44 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,49 +16,58 @@ beforeAll((done) => {
1616

1717
describe("Tensor", () => {
1818
test("ones", () => {
19-
const tensor = et.FloatTensor.ones([2, 2]);
20-
expect(tensor.getData()).toEqual([1, 1, 1, 1]);
21-
expect(tensor.getSizes()).toEqual([2, 2]);
19+
const tensor = et.Tensor.ones([2, 2]);
20+
expect(tensor.data).toEqual([1, 1, 1, 1]);
21+
expect(tensor.sizes).toEqual([2, 2]);
2222
tensor.delete();
2323
});
2424

2525
test("zeros", () => {
26-
const tensor = et.FloatTensor.zeros([2, 2]);
27-
expect(tensor.getData()).toEqual([0, 0, 0, 0]);
28-
expect(tensor.getSizes()).toEqual([2, 2]);
26+
const tensor = et.Tensor.zeros([2, 2]);
27+
expect(tensor.data).toEqual([0, 0, 0, 0]);
28+
expect(tensor.sizes).toEqual([2, 2]);
2929
tensor.delete();
3030
});
3131

3232
test("fromArray", () => {
33-
const tensor = et.FloatTensor.fromArray([1, 2, 3, 4], [2, 2]);
34-
expect(tensor.getData()).toEqual([1, 2, 3, 4]);
35-
expect(tensor.getSizes()).toEqual([2, 2]);
33+
const tensor = et.Tensor.fromArray([2, 2], [1, 2, 3, 4]);
34+
expect(tensor.data).toEqual([1, 2, 3, 4]);
35+
expect(tensor.sizes).toEqual([2, 2]);
3636
tensor.delete();
3737
});
3838

3939
test("fromArray wrong size", () => {
40-
expect(() => et.FloatTensor.fromArray([1, 2, 3, 4], [3, 2])).toThrow();
40+
expect(() => et.Tensor.fromArray([3, 2], [1, 2, 3, 4])).toThrow();
4141
});
4242

4343
test("full", () => {
44-
const tensor = et.FloatTensor.full([2, 2], 3);
45-
expect(tensor.getData()).toEqual([3, 3, 3, 3]);
46-
expect(tensor.getSizes()).toEqual([2, 2]);
44+
const tensor = et.Tensor.full([2, 2], 3);
45+
expect(tensor.data).toEqual([3, 3, 3, 3]);
46+
expect(tensor.sizes).toEqual([2, 2]);
4747
tensor.delete();
4848
});
4949

5050
test("scalar type", () => {
51-
const tensor = et.FloatTensor.ones([2, 2]);
51+
const tensor = et.Tensor.ones([2, 2]);
5252
// ScalarType can only be checked by strict equality.
5353
expect(tensor.scalarType).toBe(et.ScalarType.Float);
5454
tensor.delete();
5555
});
5656

5757
test("long tensor", () => {
58+
const tensor = et.Tensor.ones([2, 2], et.ScalarType.Long);
59+
expect(tensor.data).toEqual([1n, 1n, 1n, 1n]);
60+
expect(tensor.sizes).toEqual([2, 2]);
61+
// ScalarType can only be checked by strict equality.
62+
expect(tensor.scalarType).toBe(et.ScalarType.Long);
63+
tensor.delete();
64+
});
65+
66+
test("infer long tensor", () => {
5867
// Number cannot be converted to Long, so we use BigInt instead.
59-
const tensor = et.LongTensor.fromArray([1n, 2n, 3n, 4n], [2, 2]);
60-
expect(tensor.getData()).toEqual([1n, 2n, 3n, 4n]);
61-
expect(tensor.getSizes()).toEqual([2, 2]);
68+
const tensor = et.Tensor.fromArray([2, 2], [1n, 2n, 3n, 4n]);
69+
expect(tensor.data).toEqual([1n, 2n, 3n, 4n]);
70+
expect(tensor.sizes).toEqual([2, 2]);
6271
// ScalarType can only be checked by strict equality.
6372
expect(tensor.scalarType).toBe(et.ScalarType.Long);
6473
tensor.delete();
@@ -185,12 +194,12 @@ describe("Module", () => {
185194
describe("execute", () => {
186195
test("add normally", () => {
187196
const module = et.Module.load("add.pte");
188-
const inputs = [et.FloatTensor.ones([1]), et.FloatTensor.ones([1])];
197+
const inputs = [et.Tensor.ones([1]), et.Tensor.ones([1])];
189198
const output = module.execute("forward", inputs);
190199

191200
expect(output.length).toEqual(1);
192-
expect(output[0].getData()).toEqual([2]);
193-
expect(output[0].getSizes()).toEqual([1]);
201+
expect(output[0].data).toEqual([2]);
202+
expect(output[0].sizes).toEqual([1]);
194203

195204
inputs.forEach((input) => input.delete());
196205
output.forEach((output) => output.delete());
@@ -199,12 +208,12 @@ describe("Module", () => {
199208

200209
test("add_mul normally", () => {
201210
const module = et.Module.load("add_mul.pte");
202-
const inputs = [et.FloatTensor.ones([2, 2]), et.FloatTensor.ones([2, 2]), et.FloatTensor.ones([2, 2])];
211+
const inputs = [et.Tensor.ones([2, 2]), et.Tensor.ones([2, 2]), et.Tensor.ones([2, 2])];
203212
const output = module.execute("forward", inputs);
204213

205214
expect(output.length).toEqual(1);
206-
expect(output[0].getData()).toEqual([3, 3, 3, 3]);
207-
expect(output[0].getSizes()).toEqual([2, 2]);
215+
expect(output[0].data).toEqual([3, 3, 3, 3]);
216+
expect(output[0].sizes).toEqual([2, 2]);
208217

209218
inputs.forEach((input) => input.delete());
210219
output.forEach((output) => output.delete());
@@ -213,12 +222,12 @@ describe("Module", () => {
213222

214223
test("forward directly", () => {
215224
const module = et.Module.load("add_mul.pte");
216-
const inputs = [et.FloatTensor.ones([2, 2]), et.FloatTensor.ones([2, 2]), et.FloatTensor.ones([2, 2])];
225+
const inputs = [et.Tensor.ones([2, 2]), et.Tensor.ones([2, 2]), et.Tensor.ones([2, 2])];
217226
const output = module.forward(inputs);
218227

219228
expect(output.length).toEqual(1);
220-
expect(output[0].getData()).toEqual([3, 3, 3, 3]);
221-
expect(output[0].getSizes()).toEqual([2, 2]);
229+
expect(output[0].data).toEqual([3, 3, 3, 3]);
230+
expect(output[0].sizes).toEqual([2, 2]);
222231

223232
inputs.forEach((input) => input.delete());
224233
output.forEach((output) => output.delete());
@@ -227,7 +236,7 @@ describe("Module", () => {
227236

228237
test("wrong number of inputs", () => {
229238
const module = et.Module.load("add_mul.pte");
230-
const inputs = [et.FloatTensor.ones([2, 2]), et.FloatTensor.ones([2, 2])];
239+
const inputs = [et.Tensor.ones([2, 2]), et.Tensor.ones([2, 2])];
231240
expect(() => module.execute("forward", inputs)).toThrow();
232241

233242
inputs.forEach((input) => input.delete());
@@ -236,7 +245,7 @@ describe("Module", () => {
236245

237246
test("wrong input size", () => {
238247
const module = et.Module.load("add.pte");
239-
const inputs = [et.FloatTensor.ones([2, 1]), et.FloatTensor.ones([2, 1])];
248+
const inputs = [et.Tensor.ones([2, 1]), et.Tensor.ones([2, 1])];
240249
expect(() => module.execute("forward", inputs)).toThrow();
241250

242251
inputs.forEach((input) => input.delete());
@@ -245,7 +254,7 @@ describe("Module", () => {
245254

246255
test("wrong input type", () => {
247256
const module = et.Module.load("add.pte");
248-
const inputs = [et.FloatTensor.ones([1]), et.LongTensor.ones([1])];
257+
const inputs = [et.Tensor.ones([1]), et.Tensor.ones([1], et.ScalarType.Long)];
249258
expect(() => module.execute("forward", inputs)).toThrow();
250259

251260
inputs.forEach((input) => input.delete());
@@ -254,7 +263,7 @@ describe("Module", () => {
254263

255264
test("method does not exist", () => {
256265
const module = et.Module.load("add.pte");
257-
const inputs = [et.FloatTensor.ones([1]), et.FloatTensor.ones([1])];
266+
const inputs = [et.Tensor.ones([1]), et.Tensor.ones([1])];
258267
expect(() => module.execute("does_not_exist", inputs)).toThrow();
259268

260269
inputs.forEach((input) => input.delete());
@@ -263,19 +272,19 @@ describe("Module", () => {
263272

264273
test("output tensor can be reused", () => {
265274
const module = et.Module.load("add_mul.pte");
266-
const inputs = [et.FloatTensor.ones([2, 2]), et.FloatTensor.ones([2, 2]), et.FloatTensor.ones([2, 2])];
275+
const inputs = [et.Tensor.ones([2, 2]), et.Tensor.ones([2, 2]), et.Tensor.ones([2, 2])];
267276
const output = module.forward(inputs);
268277

269278
expect(output.length).toEqual(1);
270-
expect(output[0].getData()).toEqual([3, 3, 3, 3]);
271-
expect(output[0].getSizes()).toEqual([2, 2]);
279+
expect(output[0].data).toEqual([3, 3, 3, 3]);
280+
expect(output[0].sizes).toEqual([2, 2]);
272281

273282
const inputs2 = [output[0], output[0], output[0]];
274283
const output2 = module.forward(inputs2);
275284

276285
expect(output2.length).toEqual(1);
277-
expect(output2[0].getData()).toEqual([21, 21, 21, 21]);
278-
expect(output2[0].getSizes()).toEqual([2, 2]);
286+
expect(output2[0].data).toEqual([21, 21, 21, 21]);
287+
expect(output2[0].sizes).toEqual([2, 2]);
279288

280289
inputs.forEach((input) => input.delete());
281290
output.forEach((output) => output.delete());

0 commit comments

Comments
 (0)