@@ -16,49 +16,49 @@ beforeAll((done) => {
1616
1717describe ( "Tensor" , ( ) => {
1818 test ( "ones" , ( ) => {
19- const tensor = et . FloatTensor . ones ( [ 2 , 2 ] ) ;
20- expect ( tensor . getData ( ) ) . toEqual ( [ 1 , 1 , 1 , 1 ] ) ;
21- expect ( tensor . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
19+ const tensor = et . Tensor . ones ( [ 2 , 2 ] ) ;
20+ expect ( tensor . data ) . toEqual ( [ 1 , 1 , 1 , 1 ] ) ;
21+ expect ( tensor . sizes ) . toEqual ( [ 2 , 2 ] ) ;
2222 tensor . delete ( ) ;
2323 } ) ;
2424
2525 test ( "zeros" , ( ) => {
26- const tensor = et . FloatTensor . zeros ( [ 2 , 2 ] ) ;
27- expect ( tensor . getData ( ) ) . toEqual ( [ 0 , 0 , 0 , 0 ] ) ;
28- expect ( tensor . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
26+ const tensor = et . Tensor . zeros ( [ 2 , 2 ] ) ;
27+ expect ( tensor . data ) . toEqual ( [ 0 , 0 , 0 , 0 ] ) ;
28+ expect ( tensor . sizes ) . toEqual ( [ 2 , 2 ] ) ;
2929 tensor . delete ( ) ;
3030 } ) ;
3131
3232 test ( "fromArray" , ( ) => {
33- const tensor = et . FloatTensor . fromArray ( [ 1 , 2 , 3 , 4 ] , [ 2 , 2 ] ) ;
34- expect ( tensor . getData ( ) ) . toEqual ( [ 1 , 2 , 3 , 4 ] ) ;
35- expect ( tensor . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
33+ const tensor = et . Tensor . fromArray ( [ 2 , 2 ] , [ 1 , 2 , 3 , 4 ] ) ;
34+ expect ( tensor . data ) . toEqual ( [ 1 , 2 , 3 , 4 ] ) ;
35+ expect ( tensor . sizes ) . toEqual ( [ 2 , 2 ] ) ;
3636 tensor . delete ( ) ;
3737 } ) ;
3838
3939 test ( "fromArray wrong size" , ( ) => {
40- expect ( ( ) => et . FloatTensor . fromArray ( [ 1 , 2 , 3 , 4 ] , [ 3 , 2 ] ) ) . toThrow ( ) ;
40+ expect ( ( ) => et . Tensor . fromArray ( [ 3 , 2 ] , [ 1 , 2 , 3 , 4 ] ) ) . toThrow ( ) ;
4141 } ) ;
4242
4343 test ( "full" , ( ) => {
44- const tensor = et . FloatTensor . full ( [ 2 , 2 ] , 3 ) ;
45- expect ( tensor . getData ( ) ) . toEqual ( [ 3 , 3 , 3 , 3 ] ) ;
46- expect ( tensor . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
44+ const tensor = et . Tensor . full ( [ 2 , 2 ] , 3 ) ;
45+ expect ( tensor . data ) . toEqual ( [ 3 , 3 , 3 , 3 ] ) ;
46+ expect ( tensor . sizes ) . toEqual ( [ 2 , 2 ] ) ;
4747 tensor . delete ( ) ;
4848 } ) ;
4949
5050 test ( "scalar type" , ( ) => {
51- const tensor = et . FloatTensor . ones ( [ 2 , 2 ] ) ;
51+ const tensor = et . Tensor . ones ( [ 2 , 2 ] ) ;
5252 // ScalarType can only be checked by strict equality.
5353 expect ( tensor . scalarType ) . toBe ( et . ScalarType . Float ) ;
5454 tensor . delete ( ) ;
5555 } ) ;
5656
5757 test ( "long tensor" , ( ) => {
5858 // Number cannot be converted to Long, so we use BigInt instead.
59- const tensor = et . LongTensor . fromArray ( [ 1n , 2n , 3n , 4n ] , [ 2 , 2 ] ) ;
60- expect ( tensor . getData ( ) ) . toEqual ( [ 1n , 2n , 3n , 4n ] ) ;
61- expect ( tensor . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
59+ const tensor = et . Tensor . fromArray ( [ 2 , 2 ] , [ 1n , 2n , 3n , 4n ] , et . ScalarType . Long ) ;
60+ expect ( tensor . data ) . toEqual ( [ 1n , 2n , 3n , 4n ] ) ;
61+ expect ( tensor . sizes ) . toEqual ( [ 2 , 2 ] ) ;
6262 // ScalarType can only be checked by strict equality.
6363 expect ( tensor . scalarType ) . toBe ( et . ScalarType . Long ) ;
6464 tensor . delete ( ) ;
@@ -185,12 +185,12 @@ describe("Module", () => {
185185 describe ( "execute" , ( ) => {
186186 test ( "add normally" , ( ) => {
187187 const module = et . Module . load ( "add.pte" ) ;
188- const inputs = [ et . FloatTensor . ones ( [ 1 ] ) , et . FloatTensor . ones ( [ 1 ] ) ] ;
188+ const inputs = [ et . Tensor . ones ( [ 1 ] ) , et . Tensor . ones ( [ 1 ] ) ] ;
189189 const output = module . execute ( "forward" , inputs ) ;
190190
191191 expect ( output . length ) . toEqual ( 1 ) ;
192- expect ( output [ 0 ] . getData ( ) ) . toEqual ( [ 2 ] ) ;
193- expect ( output [ 0 ] . getSizes ( ) ) . toEqual ( [ 1 ] ) ;
192+ expect ( output [ 0 ] . data ) . toEqual ( [ 2 ] ) ;
193+ expect ( output [ 0 ] . sizes ) . toEqual ( [ 1 ] ) ;
194194
195195 inputs . forEach ( ( input ) => input . delete ( ) ) ;
196196 output . forEach ( ( output ) => output . delete ( ) ) ;
@@ -199,12 +199,12 @@ describe("Module", () => {
199199
200200 test ( "add_mul normally" , ( ) => {
201201 const module = et . Module . load ( "add_mul.pte" ) ;
202- const inputs = [ et . FloatTensor . ones ( [ 2 , 2 ] ) , et . FloatTensor . ones ( [ 2 , 2 ] ) , et . FloatTensor . ones ( [ 2 , 2 ] ) ] ;
202+ const inputs = [ et . Tensor . ones ( [ 2 , 2 ] ) , et . Tensor . ones ( [ 2 , 2 ] ) , et . Tensor . ones ( [ 2 , 2 ] ) ] ;
203203 const output = module . execute ( "forward" , inputs ) ;
204204
205205 expect ( output . length ) . toEqual ( 1 ) ;
206- expect ( output [ 0 ] . getData ( ) ) . toEqual ( [ 3 , 3 , 3 , 3 ] ) ;
207- expect ( output [ 0 ] . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
206+ expect ( output [ 0 ] . data ) . toEqual ( [ 3 , 3 , 3 , 3 ] ) ;
207+ expect ( output [ 0 ] . sizes ) . toEqual ( [ 2 , 2 ] ) ;
208208
209209 inputs . forEach ( ( input ) => input . delete ( ) ) ;
210210 output . forEach ( ( output ) => output . delete ( ) ) ;
@@ -213,12 +213,12 @@ describe("Module", () => {
213213
214214 test ( "forward directly" , ( ) => {
215215 const module = et . Module . load ( "add_mul.pte" ) ;
216- const inputs = [ et . FloatTensor . ones ( [ 2 , 2 ] ) , et . FloatTensor . ones ( [ 2 , 2 ] ) , et . FloatTensor . ones ( [ 2 , 2 ] ) ] ;
216+ const inputs = [ et . Tensor . ones ( [ 2 , 2 ] ) , et . Tensor . ones ( [ 2 , 2 ] ) , et . Tensor . ones ( [ 2 , 2 ] ) ] ;
217217 const output = module . forward ( inputs ) ;
218218
219219 expect ( output . length ) . toEqual ( 1 ) ;
220- expect ( output [ 0 ] . getData ( ) ) . toEqual ( [ 3 , 3 , 3 , 3 ] ) ;
221- expect ( output [ 0 ] . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
220+ expect ( output [ 0 ] . data ) . toEqual ( [ 3 , 3 , 3 , 3 ] ) ;
221+ expect ( output [ 0 ] . sizes ) . toEqual ( [ 2 , 2 ] ) ;
222222
223223 inputs . forEach ( ( input ) => input . delete ( ) ) ;
224224 output . forEach ( ( output ) => output . delete ( ) ) ;
@@ -227,7 +227,7 @@ describe("Module", () => {
227227
228228 test ( "wrong number of inputs" , ( ) => {
229229 const module = et . Module . load ( "add_mul.pte" ) ;
230- const inputs = [ et . FloatTensor . ones ( [ 2 , 2 ] ) , et . FloatTensor . ones ( [ 2 , 2 ] ) ] ;
230+ const inputs = [ et . Tensor . ones ( [ 2 , 2 ] ) , et . Tensor . ones ( [ 2 , 2 ] ) ] ;
231231 expect ( ( ) => module . execute ( "forward" , inputs ) ) . toThrow ( ) ;
232232
233233 inputs . forEach ( ( input ) => input . delete ( ) ) ;
@@ -236,7 +236,7 @@ describe("Module", () => {
236236
237237 test ( "wrong input size" , ( ) => {
238238 const module = et . Module . load ( "add.pte" ) ;
239- const inputs = [ et . FloatTensor . ones ( [ 2 , 1 ] ) , et . FloatTensor . ones ( [ 2 , 1 ] ) ] ;
239+ const inputs = [ et . Tensor . ones ( [ 2 , 1 ] ) , et . Tensor . ones ( [ 2 , 1 ] ) ] ;
240240 expect ( ( ) => module . execute ( "forward" , inputs ) ) . toThrow ( ) ;
241241
242242 inputs . forEach ( ( input ) => input . delete ( ) ) ;
@@ -245,7 +245,7 @@ describe("Module", () => {
245245
246246 test ( "wrong input type" , ( ) => {
247247 const module = et . Module . load ( "add.pte" ) ;
248- const inputs = [ et . FloatTensor . ones ( [ 1 ] ) , et . LongTensor . ones ( [ 1 ] ) ] ;
248+ const inputs = [ et . Tensor . ones ( [ 1 ] ) , et . Tensor . ones ( [ 1 ] , et . ScalarType . Long ) ] ;
249249 expect ( ( ) => module . execute ( "forward" , inputs ) ) . toThrow ( ) ;
250250
251251 inputs . forEach ( ( input ) => input . delete ( ) ) ;
@@ -254,7 +254,7 @@ describe("Module", () => {
254254
255255 test ( "method does not exist" , ( ) => {
256256 const module = et . Module . load ( "add.pte" ) ;
257- const inputs = [ et . FloatTensor . ones ( [ 1 ] ) , et . FloatTensor . ones ( [ 1 ] ) ] ;
257+ const inputs = [ et . Tensor . ones ( [ 1 ] ) , et . Tensor . ones ( [ 1 ] ) ] ;
258258 expect ( ( ) => module . execute ( "does_not_exist" , inputs ) ) . toThrow ( ) ;
259259
260260 inputs . forEach ( ( input ) => input . delete ( ) ) ;
@@ -263,19 +263,19 @@ describe("Module", () => {
263263
264264 test ( "output tensor can be reused" , ( ) => {
265265 const module = et . Module . load ( "add_mul.pte" ) ;
266- const inputs = [ et . FloatTensor . ones ( [ 2 , 2 ] ) , et . FloatTensor . ones ( [ 2 , 2 ] ) , et . FloatTensor . ones ( [ 2 , 2 ] ) ] ;
266+ const inputs = [ et . Tensor . ones ( [ 2 , 2 ] ) , et . Tensor . ones ( [ 2 , 2 ] ) , et . Tensor . ones ( [ 2 , 2 ] ) ] ;
267267 const output = module . forward ( inputs ) ;
268268
269269 expect ( output . length ) . toEqual ( 1 ) ;
270- expect ( output [ 0 ] . getData ( ) ) . toEqual ( [ 3 , 3 , 3 , 3 ] ) ;
271- expect ( output [ 0 ] . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
270+ expect ( output [ 0 ] . data ) . toEqual ( [ 3 , 3 , 3 , 3 ] ) ;
271+ expect ( output [ 0 ] . sizes ) . toEqual ( [ 2 , 2 ] ) ;
272272
273273 const inputs2 = [ output [ 0 ] , output [ 0 ] , output [ 0 ] ] ;
274274 const output2 = module . forward ( inputs2 ) ;
275275
276276 expect ( output2 . length ) . toEqual ( 1 ) ;
277- expect ( output2 [ 0 ] . getData ( ) ) . toEqual ( [ 21 , 21 , 21 , 21 ] ) ;
278- expect ( output2 [ 0 ] . getSizes ( ) ) . toEqual ( [ 2 , 2 ] ) ;
277+ expect ( output2 [ 0 ] . data ) . toEqual ( [ 21 , 21 , 21 , 21 ] ) ;
278+ expect ( output2 [ 0 ] . sizes ) . toEqual ( [ 2 , 2 ] ) ;
279279
280280 inputs . forEach ( ( input ) => input . delete ( ) ) ;
281281 output . forEach ( ( output ) => output . delete ( ) ) ;
0 commit comments