@@ -59,17 +59,15 @@ describe("Tensor", () => {
5959
6060 test ( "scalar type" , ( ) => {
6161 const tensor = et . Tensor . ones ( [ 2 , 2 ] ) ;
62- // ScalarType can only be checked by strict equality.
63- expect ( tensor . scalarType ) . toBe ( et . ScalarType . Float ) ;
62+ expect ( tensor . scalarType ) . toEqual ( et . ScalarType . Float ) ;
6463 tensor . delete ( ) ;
6564 } ) ;
6665
6766 test ( "long tensor" , ( ) => {
6867 const tensor = et . Tensor . ones ( [ 2 , 2 ] , et . ScalarType . Long ) ;
6968 expect ( tensor . data ) . toEqual ( new BigInt64Array ( [ 1n , 1n , 1n , 1n ] ) ) ;
7069 expect ( tensor . sizes ) . toEqual ( [ 2 , 2 ] ) ;
71- // ScalarType can only be checked by strict equality.
72- expect ( tensor . scalarType ) . toBe ( et . ScalarType . Long ) ;
70+ expect ( tensor . scalarType ) . toEqual ( et . ScalarType . Long ) ;
7371 tensor . delete ( ) ;
7472 } ) ;
7573
@@ -78,8 +76,7 @@ describe("Tensor", () => {
7876 const tensor = et . Tensor . fromArray ( [ 2 , 2 ] , [ 1n , 2n , 3n , 4n ] ) ;
7977 expect ( tensor . data ) . toEqual ( new BigInt64Array ( [ 1n , 2n , 3n , 4n ] ) ) ;
8078 expect ( tensor . sizes ) . toEqual ( [ 2 , 2 ] ) ;
81- // ScalarType can only be checked by strict equality.
82- expect ( tensor . scalarType ) . toBe ( et . ScalarType . Long ) ;
79+ expect ( tensor . scalarType ) . toEqual ( et . ScalarType . Long ) ;
8380 tensor . delete ( ) ;
8481 } ) ;
8582} ) ;
@@ -124,17 +121,15 @@ describe("Module", () => {
124121 const module = et . Module . load ( "add_mul.pte" ) ;
125122 const methodMeta = module . getMethodMeta ( "forward" ) ;
126123 expect ( methodMeta . inputTags . length ) . toEqual ( 3 ) ;
127- // Tags can only be checked by strict equality.
128- methodMeta . inputTags . forEach ( ( tag ) => expect ( tag ) . toBe ( et . Tag . Tensor ) ) ;
124+ expect ( methodMeta . inputTags ) . toEqual ( [ et . Tag . Tensor , et . Tag . Tensor , et . Tag . Tensor ] ) ;
129125 module . delete ( ) ;
130126 } ) ;
131127
132128 test ( "outputs are tensors" , ( ) => {
133129 const module = et . Module . load ( "add_mul.pte" ) ;
134130 const methodMeta = module . getMethodMeta ( "forward" ) ;
135131 expect ( methodMeta . outputTags . length ) . toEqual ( 1 ) ;
136- // Tags can only be checked by strict equality.
137- expect ( methodMeta . outputTags [ 0 ] ) . toBe ( et . Tag . Tensor ) ;
132+ expect ( methodMeta . outputTags ) . toEqual ( [ et . Tag . Tensor ] ) ;
138133 module . delete ( ) ;
139134 } ) ;
140135
@@ -183,8 +178,7 @@ describe("Module", () => {
183178 const module = et . Module . load ( "add_mul.pte" ) ;
184179 const methodMeta = module . getMethodMeta ( "forward" ) ;
185180 methodMeta . inputTensorMeta . forEach ( ( tensorInfo ) => {
186- // ScalarType can only be checked by strict equality.
187- expect ( tensorInfo . scalarType ) . toBe ( et . ScalarType . Float ) ;
181+ expect ( tensorInfo . scalarType ) . toEqual ( et . ScalarType . Float ) ;
188182 } ) ;
189183 module . delete ( ) ;
190184 } ) ;
@@ -311,3 +305,11 @@ describe("Module", () => {
311305 } ) ;
312306 } ) ;
313307} ) ;
308+
309+ describe ( "sanity" , ( ) => {
310+ // Emscripten enums are equal by default for some reason.
311+ test ( "different enums are not equal" , ( ) => {
312+ expect ( et . ScalarType . Float ) . not . toEqual ( et . ScalarType . Long ) ;
313+ expect ( et . Tag . Int ) . not . toEqual ( et . Tag . Double ) ;
314+ } ) ;
315+ } ) ;
0 commit comments