1+ /*
2+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3+ * All rights reserved.
4+ *
5+ * This source code is licensed under the BSD-style license found in the
6+ * LICENSE file in the root directory of this source tree.
7+ */
18
29let et ;
310beforeAll ( ( done ) => {
@@ -39,6 +46,23 @@ describe("Tensor", () => {
3946 expect ( tensor . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
4047 tensor . delete ( ) ;
4148 } ) ;
49+
50+ test ( "scalar type" , ( ) => {
51+ const tensor = et . FloatTensor . ones ( [ 2 , 2 ] ) ;
52+ // ScalarType can only be checked by strict equality.
53+ expect ( tensor . scalarType ) . toBe ( et . ScalarType . Float ) ;
54+ tensor . delete ( ) ;
55+ } ) ;
56+
57+ test ( "long tensor" , ( ) => {
58+ // Number cannot be converted to Long, so we use BigInt instead.
59+ const tensor = et . LongTensor . fromArray ( [ 1n , 2n , 3n , 4n ] , [ 2 , 2 ] ) ;
60+ expect ( tensor . getData ( ) ) . toEqual ( [ 1n , 2n , 3n , 4n ] ) ;
61+ expect ( tensor . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
62+ // ScalarType can only be checked by strict equality.
63+ expect ( tensor . scalarType ) . toBe ( et . ScalarType . Long ) ;
64+ tensor . delete ( ) ;
65+ } ) ;
4266} ) ;
4367
4468describe ( "Module" , ( ) => {
@@ -66,15 +90,31 @@ describe("Module", () => {
6690 const module = et . Module . load ( "add_mul.pte" ) ;
6791 const methodMeta = module . getMethodMeta ( "forward" ) ;
6892 expect ( methodMeta . name ) . toEqual ( "forward" ) ;
69- methodMeta . delete ( ) ;
7093 module . delete ( ) ;
7194 } ) ;
7295
73- test ( "numInputs is 3 " , ( ) => {
96+ test ( "inputs are tensors " , ( ) => {
7497 const module = et . Module . load ( "add_mul.pte" ) ;
7598 const methodMeta = module . getMethodMeta ( "forward" ) ;
76- expect ( methodMeta . numInputs ) . toEqual ( 3 ) ;
77- methodMeta . delete ( ) ;
99+ expect ( methodMeta . inputTags . length ) . toEqual ( 3 ) ;
100+ // Tags can only be checked by strict equality.
101+ methodMeta . inputTags . forEach ( ( tag ) => expect ( tag ) . toBe ( et . Tag . Tensor ) ) ;
102+ module . delete ( ) ;
103+ } ) ;
104+
105+ test ( "outputs are tensors" , ( ) => {
106+ const module = et . Module . load ( "add_mul.pte" ) ;
107+ const methodMeta = module . getMethodMeta ( "forward" ) ;
108+ expect ( methodMeta . outputTags . length ) . toEqual ( 1 ) ;
109+ // Tags can only be checked by strict equality.
110+ expect ( methodMeta . outputTags [ 0 ] ) . toBe ( et . Tag . Tensor ) ;
111+ module . delete ( ) ;
112+ } ) ;
113+
114+ test ( "num instructions is 2" , ( ) => {
115+ const module = et . Module . load ( "add_mul.pte" ) ;
116+ const methodMeta = module . getMethodMeta ( "forward" ) ;
117+ expect ( methodMeta . numInstructions ) . toEqual ( 2 ) ;
78118 module . delete ( ) ;
79119 } ) ;
80120
@@ -85,23 +125,58 @@ describe("Module", () => {
85125 } ) ;
86126
87127 describe ( "TensorInfo" , ( ) => {
88- test ( "sizes is 2x2" , ( ) => {
128+ test ( "input sizes is 2x2" , ( ) => {
89129 const module = et . Module . load ( "add_mul.pte" ) ;
90130 const methodMeta = module . getMethodMeta ( "forward" ) ;
91- for ( var i = 0 ; i < methodMeta . numInputs ; i ++ ) {
92- const tensorInfo = methodMeta . inputTensorMeta ( i ) ;
131+ expect ( methodMeta . inputTensorMeta . length ) . toEqual ( 3 ) ;
132+ methodMeta . inputTensorMeta . forEach ( ( tensorInfo ) => {
93133 expect ( tensorInfo . sizes ) . toEqual ( [ 2 , 2 ] ) ;
94- tensorInfo . delete ( ) ;
95- }
96- methodMeta . delete ( ) ;
134+ } ) ;
135+ module . delete ( ) ;
136+ } ) ;
137+
138+ test ( "output sizes is 2x2" , ( ) => {
139+ const module = et . Module . load ( "add_mul.pte" ) ;
140+ const methodMeta = module . getMethodMeta ( "forward" ) ;
141+ expect ( methodMeta . outputTensorMeta . length ) . toEqual ( 1 ) ;
142+ expect ( methodMeta . outputTensorMeta [ 0 ] . sizes ) . toEqual ( [ 2 , 2 ] ) ;
143+ module . delete ( ) ;
144+ } ) ;
145+
146+ test ( "dim order is contiguous" , ( ) => {
147+ const module = et . Module . load ( "add_mul.pte" ) ;
148+ const methodMeta = module . getMethodMeta ( "forward" ) ;
149+ methodMeta . inputTensorMeta . forEach ( ( tensorInfo ) => {
150+ expect ( tensorInfo . dimOrder ) . toEqual ( [ 0 , 1 ] ) ;
151+ } ) ;
152+ module . delete ( ) ;
153+ } ) ;
154+
155+ test ( "scalar type is float" , ( ) => {
156+ const module = et . Module . load ( "add_mul.pte" ) ;
157+ const methodMeta = module . getMethodMeta ( "forward" ) ;
158+ methodMeta . inputTensorMeta . forEach ( ( tensorInfo ) => {
159+ // ScalarType can only be checked by strict equality.
160+ expect ( tensorInfo . scalarType ) . toBe ( et . ScalarType . Float ) ;
161+ } ) ;
162+ module . delete ( ) ;
163+ } ) ;
164+
165+ test ( "memory planned" , ( ) => {
166+ const module = et . Module . load ( "add_mul.pte" ) ;
167+ const methodMeta = module . getMethodMeta ( "forward" ) ;
168+ methodMeta . inputTensorMeta . forEach ( ( tensorInfo ) => {
169+ expect ( tensorInfo . isMemoryPlanned ) . toBe ( true ) ;
170+ } ) ;
97171 module . delete ( ) ;
98172 } ) ;
99173
100- test ( "out of range " , ( ) => {
174+ test ( "nbytes is 16 " , ( ) => {
101175 const module = et . Module . load ( "add_mul.pte" ) ;
102176 const methodMeta = module . getMethodMeta ( "forward" ) ;
103- expect ( ( ) => methodMeta . inputTensorMeta ( 3 ) ) . toThrow ( ) ;
104- methodMeta . delete ( ) ;
177+ methodMeta . inputTensorMeta . forEach ( ( tensorInfo ) => {
178+ expect ( tensorInfo . nbytes ) . toEqual ( 16 ) ;
179+ } ) ;
105180 module . delete ( ) ;
106181 } ) ;
107182 } ) ;
@@ -170,7 +245,7 @@ describe("Module", () => {
170245
171246 test ( "wrong input type" , ( ) => {
172247 const module = et . Module . load ( "add.pte" ) ;
173- const inputs = [ et . FloatTensor . ones ( [ 1 ] ) , et . IntTensor . ones ( [ 1 ] ) ] ;
248+ const inputs = [ et . FloatTensor . ones ( [ 1 ] ) , et . LongTensor . ones ( [ 1 ] ) ] ;
174249 expect ( ( ) => module . execute ( "forward" , inputs ) ) . toThrow ( ) ;
175250
176251 inputs . forEach ( ( input ) => input . delete ( ) ) ;
0 commit comments