@@ -16,49 +16,58 @@ beforeAll((done) => {
1616
1717describe ( "Tensor" , ( ) => {
1818 test ( "ones" , ( ) => {
19- const tensor = et . FloatTensor . ones ( [ 2 , 2 ] ) ;
20- expect ( tensor . getData ( ) ) . toEqual ( [ 1 , 1 , 1 , 1 ] ) ;
21- expect ( tensor . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
19+ const tensor = et . Tensor . ones ( [ 2 , 2 ] ) ;
20+ expect ( tensor . data ) . toEqual ( [ 1 , 1 , 1 , 1 ] ) ;
21+ expect ( tensor . sizes ) . toEqual ( [ 2 , 2 ] ) ;
2222 tensor . delete ( ) ;
2323 } ) ;
2424
2525 test ( "zeros" , ( ) => {
26- const tensor = et . FloatTensor . zeros ( [ 2 , 2 ] ) ;
27- expect ( tensor . getData ( ) ) . toEqual ( [ 0 , 0 , 0 , 0 ] ) ;
28- expect ( tensor . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
26+ const tensor = et . Tensor . zeros ( [ 2 , 2 ] ) ;
27+ expect ( tensor . data ) . toEqual ( [ 0 , 0 , 0 , 0 ] ) ;
28+ expect ( tensor . sizes ) . toEqual ( [ 2 , 2 ] ) ;
2929 tensor . delete ( ) ;
3030 } ) ;
3131
3232 test ( "fromArray" , ( ) => {
33- const tensor = et . FloatTensor . fromArray ( [ 1 , 2 , 3 , 4 ] , [ 2 , 2 ] ) ;
34- expect ( tensor . getData ( ) ) . toEqual ( [ 1 , 2 , 3 , 4 ] ) ;
35- expect ( tensor . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
33+ const tensor = et . Tensor . fromArray ( [ 2 , 2 ] , [ 1 , 2 , 3 , 4 ] ) ;
34+ expect ( tensor . data ) . toEqual ( [ 1 , 2 , 3 , 4 ] ) ;
35+ expect ( tensor . sizes ) . toEqual ( [ 2 , 2 ] ) ;
3636 tensor . delete ( ) ;
3737 } ) ;
3838
3939 test ( "fromArray wrong size" , ( ) => {
40- expect ( ( ) => et . FloatTensor . fromArray ( [ 1 , 2 , 3 , 4 ] , [ 3 , 2 ] ) ) . toThrow ( ) ;
40+ expect ( ( ) => et . Tensor . fromArray ( [ 3 , 2 ] , [ 1 , 2 , 3 , 4 ] ) ) . toThrow ( ) ;
4141 } ) ;
4242
4343 test ( "full" , ( ) => {
44- const tensor = et . FloatTensor . full ( [ 2 , 2 ] , 3 ) ;
45- expect ( tensor . getData ( ) ) . toEqual ( [ 3 , 3 , 3 , 3 ] ) ;
46- expect ( tensor . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
44+ const tensor = et . Tensor . full ( [ 2 , 2 ] , 3 ) ;
45+ expect ( tensor . data ) . toEqual ( [ 3 , 3 , 3 , 3 ] ) ;
46+ expect ( tensor . sizes ) . toEqual ( [ 2 , 2 ] ) ;
4747 tensor . delete ( ) ;
4848 } ) ;
4949
5050 test ( "scalar type" , ( ) => {
51- const tensor = et . FloatTensor . ones ( [ 2 , 2 ] ) ;
51+ const tensor = et . Tensor . ones ( [ 2 , 2 ] ) ;
5252 // ScalarType can only be checked by strict equality.
5353 expect ( tensor . scalarType ) . toBe ( et . ScalarType . Float ) ;
5454 tensor . delete ( ) ;
5555 } ) ;
5656
5757 test ( "long tensor" , ( ) => {
58+ const tensor = et . Tensor . ones ( [ 2 , 2 ] , et . ScalarType . Long ) ;
59+ expect ( tensor . data ) . toEqual ( [ 1n , 1n , 1n , 1n ] ) ;
60+ expect ( tensor . sizes ) . toEqual ( [ 2 , 2 ] ) ;
61+ // ScalarType can only be checked by strict equality.
62+ expect ( tensor . scalarType ) . toBe ( et . ScalarType . Long ) ;
63+ tensor . delete ( ) ;
64+ } ) ;
65+
66+ test ( "infer long tensor" , ( ) => {
5867 // Number cannot be converted to Long, so we use BigInt instead.
59- const tensor = et . LongTensor . fromArray ( [ 1n , 2n , 3n , 4n ] , [ 2 , 2 ] ) ;
60- expect ( tensor . getData ( ) ) . toEqual ( [ 1n , 2n , 3n , 4n ] ) ;
61- expect ( tensor . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
68+ const tensor = et . Tensor . fromArray ( [ 2 , 2 ] , [ 1n , 2n , 3n , 4n ] ) ;
69+ expect ( tensor . data ) . toEqual ( [ 1n , 2n , 3n , 4n ] ) ;
70+ expect ( tensor . sizes ) . toEqual ( [ 2 , 2 ] ) ;
6271 // ScalarType can only be checked by strict equality.
6372 expect ( tensor . scalarType ) . toBe ( et . ScalarType . Long ) ;
6473 tensor . delete ( ) ;
@@ -185,12 +194,12 @@ describe("Module", () => {
185194 describe ( "execute" , ( ) => {
186195 test ( "add normally" , ( ) => {
187196 const module = et . Module . load ( "add.pte" ) ;
188- const inputs = [ et . FloatTensor . ones ( [ 1 ] ) , et . FloatTensor . ones ( [ 1 ] ) ] ;
197+ const inputs = [ et . Tensor . ones ( [ 1 ] ) , et . Tensor . ones ( [ 1 ] ) ] ;
189198 const output = module . execute ( "forward" , inputs ) ;
190199
191200 expect ( output . length ) . toEqual ( 1 ) ;
192- expect ( output [ 0 ] . getData ( ) ) . toEqual ( [ 2 ] ) ;
193- expect ( output [ 0 ] . getSizes ( ) ) . toEqual ( [ 1 ] ) ;
201+ expect ( output [ 0 ] . data ) . toEqual ( [ 2 ] ) ;
202+ expect ( output [ 0 ] . sizes ) . toEqual ( [ 1 ] ) ;
194203
195204 inputs . forEach ( ( input ) => input . delete ( ) ) ;
196205 output . forEach ( ( output ) => output . delete ( ) ) ;
@@ -199,12 +208,12 @@ describe("Module", () => {
199208
200209 test ( "add_mul normally" , ( ) => {
201210 const module = et . Module . load ( "add_mul.pte" ) ;
202- const inputs = [ et . FloatTensor . ones ( [ 2 , 2 ] ) , et . FloatTensor . ones ( [ 2 , 2 ] ) , et . FloatTensor . ones ( [ 2 , 2 ] ) ] ;
211+ const inputs = [ et . Tensor . ones ( [ 2 , 2 ] ) , et . Tensor . ones ( [ 2 , 2 ] ) , et . Tensor . ones ( [ 2 , 2 ] ) ] ;
203212 const output = module . execute ( "forward" , inputs ) ;
204213
205214 expect ( output . length ) . toEqual ( 1 ) ;
206- expect ( output [ 0 ] . getData ( ) ) . toEqual ( [ 3 , 3 , 3 , 3 ] ) ;
207- expect ( output [ 0 ] . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
215+ expect ( output [ 0 ] . data ) . toEqual ( [ 3 , 3 , 3 , 3 ] ) ;
216+ expect ( output [ 0 ] . sizes ) . toEqual ( [ 2 , 2 ] ) ;
208217
209218 inputs . forEach ( ( input ) => input . delete ( ) ) ;
210219 output . forEach ( ( output ) => output . delete ( ) ) ;
@@ -213,12 +222,12 @@ describe("Module", () => {
213222
214223 test ( "forward directly" , ( ) => {
215224 const module = et . Module . load ( "add_mul.pte" ) ;
216- const inputs = [ et . FloatTensor . ones ( [ 2 , 2 ] ) , et . FloatTensor . ones ( [ 2 , 2 ] ) , et . FloatTensor . ones ( [ 2 , 2 ] ) ] ;
225+ const inputs = [ et . Tensor . ones ( [ 2 , 2 ] ) , et . Tensor . ones ( [ 2 , 2 ] ) , et . Tensor . ones ( [ 2 , 2 ] ) ] ;
217226 const output = module . forward ( inputs ) ;
218227
219228 expect ( output . length ) . toEqual ( 1 ) ;
220- expect ( output [ 0 ] . getData ( ) ) . toEqual ( [ 3 , 3 , 3 , 3 ] ) ;
221- expect ( output [ 0 ] . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
229+ expect ( output [ 0 ] . data ) . toEqual ( [ 3 , 3 , 3 , 3 ] ) ;
230+ expect ( output [ 0 ] . sizes ) . toEqual ( [ 2 , 2 ] ) ;
222231
223232 inputs . forEach ( ( input ) => input . delete ( ) ) ;
224233 output . forEach ( ( output ) => output . delete ( ) ) ;
@@ -227,7 +236,7 @@ describe("Module", () => {
227236
228237 test ( "wrong number of inputs" , ( ) => {
229238 const module = et . Module . load ( "add_mul.pte" ) ;
230- const inputs = [ et . FloatTensor . ones ( [ 2 , 2 ] ) , et . FloatTensor . ones ( [ 2 , 2 ] ) ] ;
239+ const inputs = [ et . Tensor . ones ( [ 2 , 2 ] ) , et . Tensor . ones ( [ 2 , 2 ] ) ] ;
231240 expect ( ( ) => module . execute ( "forward" , inputs ) ) . toThrow ( ) ;
232241
233242 inputs . forEach ( ( input ) => input . delete ( ) ) ;
@@ -236,7 +245,7 @@ describe("Module", () => {
236245
237246 test ( "wrong input size" , ( ) => {
238247 const module = et . Module . load ( "add.pte" ) ;
239- const inputs = [ et . FloatTensor . ones ( [ 2 , 1 ] ) , et . FloatTensor . ones ( [ 2 , 1 ] ) ] ;
248+ const inputs = [ et . Tensor . ones ( [ 2 , 1 ] ) , et . Tensor . ones ( [ 2 , 1 ] ) ] ;
240249 expect ( ( ) => module . execute ( "forward" , inputs ) ) . toThrow ( ) ;
241250
242251 inputs . forEach ( ( input ) => input . delete ( ) ) ;
@@ -245,7 +254,7 @@ describe("Module", () => {
245254
246255 test ( "wrong input type" , ( ) => {
247256 const module = et . Module . load ( "add.pte" ) ;
248- const inputs = [ et . FloatTensor . ones ( [ 1 ] ) , et . LongTensor . ones ( [ 1 ] ) ] ;
257+ const inputs = [ et . Tensor . ones ( [ 1 ] ) , et . Tensor . ones ( [ 1 ] , et . ScalarType . Long ) ] ;
249258 expect ( ( ) => module . execute ( "forward" , inputs ) ) . toThrow ( ) ;
250259
251260 inputs . forEach ( ( input ) => input . delete ( ) ) ;
@@ -254,7 +263,7 @@ describe("Module", () => {
254263
255264 test ( "method does not exist" , ( ) => {
256265 const module = et . Module . load ( "add.pte" ) ;
257- const inputs = [ et . FloatTensor . ones ( [ 1 ] ) , et . FloatTensor . ones ( [ 1 ] ) ] ;
266+ const inputs = [ et . Tensor . ones ( [ 1 ] ) , et . Tensor . ones ( [ 1 ] ) ] ;
258267 expect ( ( ) => module . execute ( "does_not_exist" , inputs ) ) . toThrow ( ) ;
259268
260269 inputs . forEach ( ( input ) => input . delete ( ) ) ;
@@ -263,19 +272,19 @@ describe("Module", () => {
263272
264273 test ( "output tensor can be reused" , ( ) => {
265274 const module = et . Module . load ( "add_mul.pte" ) ;
266- const inputs = [ et . FloatTensor . ones ( [ 2 , 2 ] ) , et . FloatTensor . ones ( [ 2 , 2 ] ) , et . FloatTensor . ones ( [ 2 , 2 ] ) ] ;
275+ const inputs = [ et . Tensor . ones ( [ 2 , 2 ] ) , et . Tensor . ones ( [ 2 , 2 ] ) , et . Tensor . ones ( [ 2 , 2 ] ) ] ;
267276 const output = module . forward ( inputs ) ;
268277
269278 expect ( output . length ) . toEqual ( 1 ) ;
270- expect ( output [ 0 ] . getData ( ) ) . toEqual ( [ 3 , 3 , 3 , 3 ] ) ;
271- expect ( output [ 0 ] . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
279+ expect ( output [ 0 ] . data ) . toEqual ( [ 3 , 3 , 3 , 3 ] ) ;
280+ expect ( output [ 0 ] . sizes ) . toEqual ( [ 2 , 2 ] ) ;
272281
273282 const inputs2 = [ output [ 0 ] , output [ 0 ] , output [ 0 ] ] ;
274283 const output2 = module . forward ( inputs2 ) ;
275284
276285 expect ( output2 . length ) . toEqual ( 1 ) ;
277- expect ( output2 [ 0 ] . getData ( ) ) . toEqual ( [ 21 , 21 , 21 , 21 ] ) ;
278- expect ( output2 [ 0 ] . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
286+ expect ( output2 [ 0 ] . data ) . toEqual ( [ 21 , 21 , 21 , 21 ] ) ;
287+ expect ( output2 [ 0 ] . sizes ) . toEqual ( [ 2 , 2 ] ) ;
279288
280289 inputs . forEach ( ( input ) => input . delete ( ) ) ;
281290 output . forEach ( ( output ) => output . delete ( ) ) ;
0 commit comments