1
1
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
2
2
3
3
using System . Collections . Generic ;
4
+ using System . Linq ;
5
+ using TorchSharp . Modules ;
4
6
using Xunit ;
5
7
6
8
@@ -27,6 +29,26 @@ private class TestIterableDataset : torch.utils.data.IterableDataset
27
29
}
28
30
}
29
31
32
+ private class TestDatasetFromIEnumerable < T > : torch . utils . data . IDataset < T >
33
+ {
34
+ private readonly T [ ] values ;
35
+ public TestDatasetFromIEnumerable ( IEnumerable < T > values )
36
+ {
37
+ this . values = values . ToArray ( ) ;
38
+ this . Disposed = false ;
39
+ }
40
+
41
+ public bool Disposed { get ; set ; }
42
+
43
+ public T this [ long index ] => values [ index ] ;
44
+
45
+ public long Count => values . LongLength ;
46
+
47
+ public void Dispose ( ) {
48
+ this . Disposed = true ;
49
+ }
50
+ }
51
+
30
52
[ Fact ]
31
53
public void DatasetTest ( )
32
54
{
@@ -230,5 +252,59 @@ public void CustomSeedTest()
230
252
iterator . Dispose ( ) ;
231
253
iterator2 . Dispose ( ) ;
232
254
}
255
+
256
+ [ Fact ]
257
+ public void ConcatDatasetTest ( )
258
+ {
259
+ using var dataset1 = new TestDatasetFromIEnumerable < ( int , int ) > ( new [ ] {
260
+ ( 1 , 1 ) , // dataset 1 value 1
261
+ ( 1 , 2 ) , // dataset 1 value 2
262
+ ( 1 , 3 ) ,
263
+ } ) ;
264
+ using var dataset2 = new TestDatasetFromIEnumerable < ( int , int ) > ( new [ ] {
265
+ ( 2 , 1 ) ,
266
+ ( 2 , 2 ) ,
267
+ } ) ;
268
+ using var dataset3 = new TestDatasetFromIEnumerable < ( int , int ) > ( new [ ] {
269
+ ( 3 , 1 ) ,
270
+ ( 3 , 2 ) ,
271
+ ( 3 , 3 ) ,
272
+ ( 3 , 4 ) ,
273
+ } ) ;
274
+
275
+ using var dataset = new ConcatDataset < ( int , int ) > ( new [ ] {
276
+ dataset1 , dataset2 , dataset3
277
+ } ) ;
278
+
279
+ Assert . Equal ( 3 + 2 + 4 , dataset . Count ) ;
280
+
281
+ Assert . Equal ( ( 1 , 1 ) , dataset [ 0 ] ) ;
282
+ Assert . Equal ( ( 1 , 2 ) , dataset [ 1 ] ) ;
283
+ Assert . Equal ( ( 1 , 3 ) , dataset [ 2 ] ) ;
284
+ Assert . Equal ( ( 2 , 1 ) , dataset [ 3 ] ) ;
285
+ Assert . Equal ( ( 2 , 2 ) , dataset [ 4 ] ) ;
286
+ Assert . Equal ( ( 3 , 1 ) , dataset [ 5 ] ) ;
287
+ Assert . Equal ( ( 3 , 2 ) , dataset [ 6 ] ) ;
288
+ Assert . Equal ( ( 3 , 3 ) , dataset [ 7 ] ) ;
289
+ Assert . Equal ( ( 3 , 4 ) , dataset [ 8 ] ) ;
290
+
291
+ Assert . Equal ( ( 1 , 1 ) , dataset [ - 9 ] ) ;
292
+ Assert . Equal ( ( 1 , 2 ) , dataset [ - 8 ] ) ;
293
+ Assert . Equal ( ( 1 , 3 ) , dataset [ - 7 ] ) ;
294
+ Assert . Equal ( ( 2 , 1 ) , dataset [ - 6 ] ) ;
295
+ Assert . Equal ( ( 2 , 2 ) , dataset [ - 5 ] ) ;
296
+ Assert . Equal ( ( 3 , 1 ) , dataset [ - 4 ] ) ;
297
+ Assert . Equal ( ( 3 , 2 ) , dataset [ - 3 ] ) ;
298
+ Assert . Equal ( ( 3 , 3 ) , dataset [ - 2 ] ) ;
299
+ Assert . Equal ( ( 3 , 4 ) , dataset [ - 1 ] ) ;
300
+
301
+ Assert . False ( dataset1 . Disposed ) ;
302
+ Assert . False ( dataset2 . Disposed ) ;
303
+ Assert . False ( dataset3 . Disposed ) ;
304
+ dataset . Dispose ( ) ;
305
+ Assert . True ( dataset1 . Disposed ) ;
306
+ Assert . True ( dataset2 . Disposed ) ;
307
+ Assert . True ( dataset3 . Disposed ) ;
308
+ }
233
309
}
234
310
}
0 commit comments