1
+ // Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
2
+ using System ;
3
+ using System . Collections . Generic ;
4
+ using System . Linq ;
5
+ using TorchSharp . Modules ;
6
+
7
+ namespace TorchSharp
8
+ {
9
public static partial class torch
{
    public static partial class utils
    {
        public static partial class data
        {
            /// <summary>
            /// Creates a dataset that is the concatenation of the given datasets.
            /// </summary>
            /// <typeparam name="T">The sample type produced by the datasets.</typeparam>
            /// <param name="datasets">The datasets to concatenate, in order.</param>
            /// <returns>
            /// A new <see cref="Modules.ConcatDataset{T}"/> constructed from
            /// <paramref name="datasets"/>.
            /// </returns>
            public static ConcatDataset<T> ConcatDataset<T>(IEnumerable<IDataset<T>> datasets)
                => new ConcatDataset<T>(datasets);
        }
    }
}
22
+
23
+ namespace Modules
24
+ {
25
/// <summary>
/// A dataset that presents several datasets as a single one, laid end to end
/// in the order in which they were supplied.
/// </summary>
/// <typeparam name="T">The sample type produced by the wrapped datasets.</typeparam>
public class ConcatDataset<T> : torch.utils.data.Dataset<T>
{
    /// <summary>
    /// Lazily yields the running total of <c>Count</c> over <paramref name="datasets"/>:
    /// the i-th value is the combined size of datasets 0..i.
    /// </summary>
    private static IEnumerable<long> Cumsum(
        IEnumerable<torch.utils.data.IDataset<T>> datasets)
    {
        var s = 0L;
        foreach (var e in datasets) {
            s += e.Count;
            yield return s;
        }
    }

    /// <summary>
    /// Binary search over an ascending array: returns the position of the first
    /// element strictly greater than <paramref name="x"/>, or <c>a.Length</c>
    /// when no such element exists.
    /// </summary>
    private static long bisectRight(long[] a, long x)
    {
        var lo = 0;
        var hi = a.Length;
        while (lo < hi) {
            // lo + (hi - lo) / 2 instead of (lo + hi) / 2 so the midpoint
            // cannot overflow int on very large arrays.
            var mid = lo + (hi - lo) / 2;
            if (x < a[mid])
                hi = mid;
            else
                lo = mid + 1;
        }
        return lo;
    }


    private readonly torch.utils.data.IDataset<T>[] _datasets;

    /// <summary>The wrapped datasets, in concatenation order.</summary>
    public IReadOnlyList<torch.utils.data.IDataset<T>> datasets => _datasets;

    private readonly long[] _cumulativeSizes;

    /// <summary>
    /// Running totals of the wrapped datasets' sizes; entry i is the combined
    /// size of datasets 0..i, so the last entry equals <see cref="Count"/>.
    /// </summary>
    public IReadOnlyList<long> cumulative_sizes => _cumulativeSizes;

    // When true, disposing this instance also disposes every wrapped dataset.
    private readonly bool _autoDispose;

    /// <summary>
    /// Builds the concatenation of <paramref name="datasets"/>.
    /// </summary>
    /// <param name="datasets">The datasets to concatenate; must contain at least one.</param>
    /// <param name="autoDispose">
    /// Whether disposing this instance should also dispose the wrapped datasets.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="datasets"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="datasets"/> is empty.</exception>
    public ConcatDataset(
        IEnumerable<torch.utils.data.IDataset<T>> datasets,
        bool autoDispose = true)
    {
        if (datasets is null)
            throw new ArgumentNullException(nameof(datasets));

        this._datasets = datasets.ToArray();
        if (this._datasets.Length is 0)
            throw new ArgumentException(
                "datasets should not be an empty iterable", nameof(datasets));

        // PyTorch also rejects IterableDataset inputs at this point
        // ('ConcatDataset does not support IterableDataset'). That check is not
        // ported: PyTorch's IterableDataset is not the same thing as
        // torch.utils.data.IterableDataset in TorchSharp.

        // Sizes are computed from the materialized array, not from the incoming
        // sequence: enumerating `datasets` a second time could yield different
        // instances (or nothing at all) for a lazy or one-shot enumerable.
        this._cumulativeSizes = Cumsum(this._datasets).ToArray();

        this._autoDispose = autoDispose;
    }

    /// <summary>The total number of samples across all wrapped datasets.</summary>
    public override long Count => this._cumulativeSizes[this._cumulativeSizes.Length - 1];

    /// <summary>
    /// Returns the sample at <paramref name="index"/> in the concatenated sequence.
    /// A negative index counts from the end (-1 is the last sample).
    /// </summary>
    /// <param name="index">Position in the concatenated dataset.</param>
    /// <returns>The sample read from the dataset that owns that position.</returns>
    /// <exception cref="ArgumentException">
    /// <paramref name="index"/> is negative and its absolute value exceeds <see cref="Count"/>.
    /// </exception>
    /// <exception cref="IndexOutOfRangeException">
    /// <paramref name="index"/> is greater than or equal to <see cref="Count"/>.
    /// </exception>
    public override T GetTensor(long index)
    {
        if (index < 0) {
            var count = this.Count;
            if (-index > count) {
                throw new ArgumentException(
                    "absolute value of index should not exceed dataset length",
                    nameof(index));
            }
            index = count + index;
        }

        // The owning dataset is the first one whose cumulative size exceeds index.
        var datasetIdx = bisectRight(this._cumulativeSizes, index);

        // Convert the global position into one local to that dataset by
        // subtracting the combined size of everything before it.
        var sampleIdx = datasetIdx == 0
            ? index
            : index - this._cumulativeSizes[datasetIdx - 1];

        return this._datasets[datasetIdx][sampleIdx];
    }

    /// <summary>
    /// Disposes the wrapped datasets when called with <paramref name="disposing"/>
    /// set and auto-dispose was requested at construction, then runs the base disposal.
    /// </summary>
    protected override void Dispose(bool disposing)
    {
        if (disposing && _autoDispose) {
            foreach (var dataset in this._datasets)
                dataset.Dispose();
        }

        base.Dispose(disposing);
    }
}
107
+ }
108
+ }
0 commit comments