Skip to content

Commit 2acd97f

Browse files
committed
Changed JsTensor data to return memory view
1 parent 7da23d0 commit 2acd97f

File tree

2 files changed

+14
-15
lines changed

2 files changed

+14
-15
lines changed

extension/wasm/test/unittests.js

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,21 @@ beforeAll((done) => {
1717
describe("Tensor", () => {
1818
test("ones", () => {
1919
const tensor = et.Tensor.ones([2, 2]);
20-
expect(tensor.data).toEqual([1, 1, 1, 1]);
20+
expect(tensor.data).toEqual(new Float32Array([1, 1, 1, 1]));
2121
expect(tensor.sizes).toEqual([2, 2]);
2222
tensor.delete();
2323
});
2424

2525
test("zeros", () => {
2626
const tensor = et.Tensor.zeros([2, 2]);
27-
expect(tensor.data).toEqual([0, 0, 0, 0]);
27+
expect(tensor.data).toEqual(new Float32Array([0, 0, 0, 0]));
2828
expect(tensor.sizes).toEqual([2, 2]);
2929
tensor.delete();
3030
});
3131

3232
test("fromArray", () => {
3333
const tensor = et.Tensor.fromArray([2, 2], [1, 2, 3, 4]);
34-
expect(tensor.data).toEqual([1, 2, 3, 4]);
34+
expect(tensor.data).toEqual(new Float32Array([1, 2, 3, 4]));
3535
expect(tensor.sizes).toEqual([2, 2]);
3636
tensor.delete();
3737
});
@@ -42,7 +42,7 @@ describe("Tensor", () => {
4242

4343
test("full", () => {
4444
const tensor = et.Tensor.full([2, 2], 3);
45-
expect(tensor.data).toEqual([3, 3, 3, 3]);
45+
expect(tensor.data).toEqual(new Float32Array([3, 3, 3, 3]));
4646
expect(tensor.sizes).toEqual([2, 2]);
4747
tensor.delete();
4848
});
@@ -56,7 +56,7 @@ describe("Tensor", () => {
5656

5757
test("long tensor", () => {
5858
const tensor = et.Tensor.ones([2, 2], et.ScalarType.Long);
59-
expect(tensor.data).toEqual([1n, 1n, 1n, 1n]);
59+
expect(tensor.data).toEqual(new BigInt64Array([1n, 1n, 1n, 1n]));
6060
expect(tensor.sizes).toEqual([2, 2]);
6161
// ScalarType can only be checked by strict equality.
6262
expect(tensor.scalarType).toBe(et.ScalarType.Long);
@@ -66,7 +66,7 @@ describe("Tensor", () => {
6666
test("infer long tensor", () => {
6767
// Number cannot be converted to Long, so we use BigInt instead.
6868
const tensor = et.Tensor.fromArray([2, 2], [1n, 2n, 3n, 4n]);
69-
expect(tensor.data).toEqual([1n, 2n, 3n, 4n]);
69+
expect(tensor.data).toEqual(new BigInt64Array([1n, 2n, 3n, 4n]));
7070
expect(tensor.sizes).toEqual([2, 2]);
7171
// ScalarType can only be checked by strict equality.
7272
expect(tensor.scalarType).toBe(et.ScalarType.Long);
@@ -206,7 +206,7 @@ describe("Module", () => {
206206
const output = module.execute("forward", inputs);
207207

208208
expect(output.length).toEqual(1);
209-
expect(output[0].data).toEqual([2]);
209+
expect(output[0].data).toEqual(new Float32Array([2]));
210210
expect(output[0].sizes).toEqual([1]);
211211

212212
inputs.forEach((input) => input.delete());
@@ -220,7 +220,7 @@ describe("Module", () => {
220220
const output = module.execute("forward", inputs);
221221

222222
expect(output.length).toEqual(1);
223-
expect(output[0].data).toEqual([3, 3, 3, 3]);
223+
expect(output[0].data).toEqual(new Float32Array([3, 3, 3, 3]));
224224
expect(output[0].sizes).toEqual([2, 2]);
225225

226226
inputs.forEach((input) => input.delete());
@@ -234,7 +234,7 @@ describe("Module", () => {
234234
const output = module.forward(inputs);
235235

236236
expect(output.length).toEqual(1);
237-
expect(output[0].data).toEqual([3, 3, 3, 3]);
237+
expect(output[0].data).toEqual(new Float32Array([3, 3, 3, 3]));
238238
expect(output[0].sizes).toEqual([2, 2]);
239239

240240
inputs.forEach((input) => input.delete());
@@ -284,14 +284,14 @@ describe("Module", () => {
284284
const output = module.forward(inputs);
285285

286286
expect(output.length).toEqual(1);
287-
expect(output[0].data).toEqual([3, 3, 3, 3]);
287+
expect(output[0].data).toEqual(new Float32Array([3, 3, 3, 3]));
288288
expect(output[0].sizes).toEqual([2, 2]);
289289

290290
const inputs2 = [output[0], output[0], output[0]];
291291
const output2 = module.forward(inputs2);
292292

293293
expect(output2.length).toEqual(1);
294-
expect(output2[0].data).toEqual([21, 21, 21, 21]);
294+
expect(output2[0].data).toEqual(new Float32Array([21, 21, 21, 21]));
295295
expect(output2[0].sizes).toEqual([2, 2]);
296296

297297
inputs.forEach((input) => input.delete());

extension/wasm/wasm_bindings.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,12 @@ class JsTensor {
102102
ScalarType get_scalar_type() const {
103103
return tensor_->scalar_type();
104104
}
105-
val_array<val> get_data() const {
105+
val get_data() const {
106106
switch (get_scalar_type()) {
107107
#define JS_CASE_TENSOR_TO_VAL_TYPE(T, NAME) \
108108
case ScalarType::NAME: \
109-
return val::array( \
110-
get_tensor().data_ptr<T>(), \
111-
get_tensor().data_ptr<T>() + get_tensor().numel());
109+
return val( \
110+
typed_memory_view(get_tensor().numel(), get_tensor().data_ptr<T>()));
112111
JS_FORALL_SUPPORTED_TENSOR_TYPES(JS_CASE_TENSOR_TO_VAL_TYPE)
113112
default:
114113
THROW_JS_ERROR(

0 commit comments

Comments
 (0)