@@ -59,15 +59,13 @@ describe("Tensor operations", () => {
5959 const t1 = new Tensor ( "float32" , [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 3 , 2 ] ) ;
6060 const t2 = t1 . slice ( 1 ) ;
6161 const target = new Tensor ( "float32" , [ 3 , 4 ] , [ 2 ] ) ;
62-
6362 compare ( t2 , target ) ;
6463 } ) ;
6564
6665 it ( "should return a range of rows" , ( ) => {
6766 const t1 = new Tensor ( "float32" , [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 3 , 2 ] ) ;
6867 const t2 = t1 . slice ( [ 1 , 3 ] ) ;
6968 const target = new Tensor ( "float32" , [ 3 , 4 , 5 , 6 ] , [ 2 , 2 ] ) ;
70-
7169 compare ( t2 , target ) ;
7270 } ) ;
7371
@@ -78,9 +76,67 @@ describe("Tensor operations", () => {
7876 [ 4 , 7 ] ,
7977 ) ;
8078 const t2 = t1 . slice ( [ 1 , - 1 ] , [ 1 , - 1 ] ) ;
81-
8279 const target = new Tensor ( "float32" , [ 9 , 10 , 11 , 12 , 13 , 16 , 17 , 18 , 19 , 20 ] , [ 2 , 5 ] ) ;
80+ compare ( t2 , target ) ;
81+ } ) ;
82+
83+ it ( "should return the whole tensor when all indices are null/unset" , ( ) => {
84+ const t1 = new Tensor ( "float32" , [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 3 , 2 ] ) ;
85+ const t2 = t1 . slice ( ) ;
86+ compare ( t2 , t1 ) ;
87+ } ) ;
88+
89+ it ( "should return the whole dimension when index is null" , ( ) => {
90+ const t1 = new Tensor ( "float32" , [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 3 , 2 ] ) ;
91+ const t2 = t1 . slice ( null ) ;
92+ compare ( t2 , t1 ) ;
93+ } ) ;
94+
95+ it ( "should slice from index to end when [start, null] is used" , ( ) => {
96+ const t1 = new Tensor ( "float32" , [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 3 , 2 ] ) ;
97+ const t2 = t1 . slice ( [ 1 , null ] ) ;
98+ const target = new Tensor ( "float32" , [ 3 , 4 , 5 , 6 ] , [ 2 , 2 ] ) ;
99+ compare ( t2 , target ) ;
100+ } ) ;
101+
102+ it ( "should slice from beginning to index when [null, end] is used" , ( ) => {
103+ const t1 = new Tensor ( "float32" , [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 3 , 2 ] ) ;
104+ const t2 = t1 . slice ( [ null , 2 ] ) ;
105+ const target = new Tensor ( "float32" , [ 1 , 2 , 3 , 4 ] , [ 2 , 2 ] ) ;
106+ compare ( t2 , target ) ;
107+ } ) ;
108+
109+ it ( "should handle [null, null] as full slice" , ( ) => {
110+ const t1 = new Tensor ( "float32" , [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 3 , 2 ] ) ;
111+ const t2 = t1 . slice ( [ null , null ] ) ;
112+ compare ( t2 , t1 ) ;
113+ } ) ;
114+
115+ it ( "should select a single element when a number is used in slice" , ( ) => {
116+ const t1 = new Tensor ( "float32" , [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 3 , 2 ] ) ;
117+ const t2 = t1 . slice ( 2 , 1 ) ;
118+ const target = new Tensor ( "float32" , [ 6 ] , [ ] ) ;
119+ compare ( t2 , target ) ;
120+ } ) ;
83121
122+ it ( "should select a single row when a number is used in slice" , ( ) => {
123+ const t1 = new Tensor ( "float32" , [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 3 , 2 ] ) ;
124+ const t2 = t1 . slice ( 0 ) ;
125+ const target = new Tensor ( "float32" , [ 1 , 2 ] , [ 2 ] ) ;
126+ compare ( t2 , target ) ;
127+ } ) ;
128+
129+ it ( "should select a single column when a number is used in slice" , ( ) => {
130+ const t1 = new Tensor ( "float32" , [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 3 , 2 ] ) ;
131+ const t2 = t1 . slice ( null , 1 ) ;
132+ const target = new Tensor ( "float32" , [ 2 , 4 , 6 ] , [ 3 ] ) ;
133+ compare ( t2 , target ) ;
134+ } ) ;
135+
136+ it ( "should handle negative indices in slice" , ( ) => {
137+ const t1 = new Tensor ( "float32" , [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 3 , 2 ] ) ;
138+ const t2 = t1 . slice ( - 1 ) ;
139+ const target = new Tensor ( "float32" , [ 5 , 6 ] , [ 2 ] ) ;
84140 compare ( t2 , target ) ;
85141 } ) ;
86142 } ) ;
0 commit comments