Skip to content

Commit aa48efb

Browse files
committed
Set enum names for convenience
1 parent f4e7144 commit aa48efb

File tree

2 files changed

+23
-12
lines changed

2 files changed

+23
-12
lines changed

extension/wasm/test/unittests.js

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,15 @@ describe("Tensor", () => {
5959

6060
test("scalar type", () => {
6161
const tensor = et.Tensor.ones([2, 2]);
62-
// ScalarType can only be checked by strict equality.
63-
expect(tensor.scalarType).toBe(et.ScalarType.Float);
62+
expect(tensor.scalarType).toEqual(et.ScalarType.Float);
6463
tensor.delete();
6564
});
6665

6766
test("long tensor", () => {
6867
const tensor = et.Tensor.ones([2, 2], et.ScalarType.Long);
6968
expect(tensor.data).toEqual(new BigInt64Array([1n, 1n, 1n, 1n]));
7069
expect(tensor.sizes).toEqual([2, 2]);
71-
// ScalarType can only be checked by strict equality.
72-
expect(tensor.scalarType).toBe(et.ScalarType.Long);
70+
expect(tensor.scalarType).toEqual(et.ScalarType.Long);
7371
tensor.delete();
7472
});
7573

@@ -78,8 +76,7 @@ describe("Tensor", () => {
7876
const tensor = et.Tensor.fromArray([2, 2], [1n, 2n, 3n, 4n]);
7977
expect(tensor.data).toEqual(new BigInt64Array([1n, 2n, 3n, 4n]));
8078
expect(tensor.sizes).toEqual([2, 2]);
81-
// ScalarType can only be checked by strict equality.
82-
expect(tensor.scalarType).toBe(et.ScalarType.Long);
79+
expect(tensor.scalarType).toEqual(et.ScalarType.Long);
8380
tensor.delete();
8481
});
8582
});
@@ -124,17 +121,15 @@ describe("Module", () => {
124121
const module = et.Module.load("add_mul.pte");
125122
const methodMeta = module.getMethodMeta("forward");
126123
expect(methodMeta.inputTags.length).toEqual(3);
127-
// Tags can only be checked by strict equality.
128-
methodMeta.inputTags.forEach((tag) => expect(tag).toBe(et.Tag.Tensor));
124+
expect(methodMeta.inputTags).toEqual([et.Tag.Tensor, et.Tag.Tensor, et.Tag.Tensor]);
129125
module.delete();
130126
});
131127

132128
test("outputs are tensors", () => {
133129
const module = et.Module.load("add_mul.pte");
134130
const methodMeta = module.getMethodMeta("forward");
135131
expect(methodMeta.outputTags.length).toEqual(1);
136-
// Tags can only be checked by strict equality.
137-
expect(methodMeta.outputTags[0]).toBe(et.Tag.Tensor);
132+
expect(methodMeta.outputTags).toEqual([et.Tag.Tensor]);
138133
module.delete();
139134
});
140135

@@ -183,8 +178,7 @@ describe("Module", () => {
183178
const module = et.Module.load("add_mul.pte");
184179
const methodMeta = module.getMethodMeta("forward");
185180
methodMeta.inputTensorMeta.forEach((tensorInfo) => {
186-
// ScalarType can only be checked by strict equality.
187-
expect(tensorInfo.scalarType).toBe(et.ScalarType.Float);
181+
expect(tensorInfo.scalarType).toEqual(et.ScalarType.Float);
188182
});
189183
module.delete();
190184
});
@@ -311,3 +305,11 @@ describe("Module", () => {
311305
});
312306
});
313307
});
308+
309+
describe("sanity", () => {
310+
// Emscripten enums are equal by default for some reason.
311+
test("different enums are not equal", () => {
312+
expect(et.ScalarType.Float).not.toEqual(et.ScalarType.Long);
313+
expect(et.Tag.Int).not.toEqual(et.Tag.Double);
314+
});
315+
});

extension/wasm/wasm_bindings.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,15 @@ EMSCRIPTEN_BINDINGS(WasmBindings) {
539539
&JsMethodMeta::memory_planned_buffer_sizes)
540540
.field("backends", &JsMethodMeta::backends)
541541
.field("numInstructions", &JsMethodMeta::num_instructions);
542+
543+
// For some reason Embind doesn't make it easy to get the names of enums.
544+
// Additionally, different enums of the same type are considered to be equal.
545+
// Assigning the name field fixes both of these issues.
546+
#define JS_ASSIGN_SCALAR_TYPE_NAME(T, NAME) \
547+
EM_ASM(Module.ScalarType.NAME.name = #NAME);
548+
JS_FORALL_SUPPORTED_TENSOR_TYPES(JS_ASSIGN_SCALAR_TYPE_NAME)
549+
#define JS_ASSIGN_TAG_NAME(NAME) EM_ASM(Module.Tag.NAME.name = #NAME);
550+
EXECUTORCH_FORALL_TAGS(JS_ASSIGN_TAG_NAME)
542551
}
543552

544553
} // namespace wasm

0 commit comments

Comments
 (0)