@@ -17,8 +17,9 @@ public static partial class utils
17
17
{
18
18
public static partial class data
19
19
{
20
+
20
21
public static Modules . DataLoader DataLoader (
21
- IDataset < IReadOnlyDictionary < string , torch . Tensor > > dataset ,
22
+ Dataset dataset ,
22
23
int batchSize , IEnumerable < long > shuffler ,
23
24
Device device = null ,
24
25
int num_worker = 1 , bool drop_last = false ,
@@ -33,7 +34,7 @@ public static Modules.DataLoader DataLoader(
33
34
}
34
35
35
36
public static Modules . DataLoader DataLoader (
36
- IDataset < IReadOnlyDictionary < string , torch . Tensor > > dataset ,
37
+ Dataset dataset ,
37
38
int batchSize , bool shuffle = false ,
38
39
Device device = null , int ? seed = null ,
39
40
int num_worker = 1 , bool drop_last = false ,
@@ -48,7 +49,7 @@ public static Modules.DataLoader DataLoader(
48
49
}
49
50
50
51
public static Modules . IterableDataLoader DataLoader (
51
- IDataset < IEnumerable < Tensor > > dataset ,
52
+ IterableDataset dataset ,
52
53
int batchSize , IEnumerable < long > shuffler ,
53
54
Device device = null ,
54
55
int num_worker = 1 , bool drop_last = false ,
@@ -63,7 +64,7 @@ public static Modules.IterableDataLoader DataLoader(
63
64
}
64
65
65
66
public static Modules . IterableDataLoader DataLoader (
66
- IDataset < IEnumerable < Tensor > > dataset ,
67
+ IterableDataset dataset ,
67
68
int batchSize , bool shuffle = false ,
68
69
Device device = null , int ? seed = null ,
69
70
int num_worker = 1 , bool drop_last = false ,
@@ -89,8 +90,7 @@ namespace Modules
89
90
/// Data loader. Combines a dataset and a sampler, and provides an enumerator over the given dataset.
90
91
/// </summary>
91
92
/// <remarks>This class is used for map-style data sets</remarks>
92
- public class DataLoader : DataLoader < IReadOnlyDictionary < string , torch . Tensor > ,
93
- Dictionary < string , torch . Tensor > >
93
+ public class DataLoader : DataLoader < Dictionary < string , torch . Tensor > , Dictionary < string , torch . Tensor > >
94
94
{
95
95
/// <summary>
96
96
/// PyTorch-style data loader.
@@ -111,7 +111,7 @@ public class DataLoader : DataLoader<IReadOnlyDictionary<string, torch.Tensor>,
111
111
/// Indicates whether to dispose the dataset when being disposed.
112
112
/// </param>
113
113
public DataLoader (
114
- IDataset < IReadOnlyDictionary < string , torch . Tensor > > dataset ,
114
+ Dataset dataset ,
115
115
int batchSize , IEnumerable < long > shuffler ,
116
116
Device device = null ,
117
117
int num_worker = 1 , bool drop_last = false ,
@@ -144,7 +144,7 @@ public DataLoader(
144
144
/// Indicates whether to dispose the dataset when being disposed.
145
145
/// </param>
146
146
public DataLoader (
147
- IDataset < IReadOnlyDictionary < string , torch . Tensor > > dataset ,
147
+ Dataset dataset ,
148
148
int batchSize , bool shuffle = false ,
149
149
Device device = null , int ? seed = null ,
150
150
int num_worker = 1 , bool drop_last = false ,
@@ -157,8 +157,7 @@ public DataLoader(
157
157
{
158
158
}
159
159
160
- private static Dictionary < string , torch . Tensor > Collate (
161
- IEnumerable < IReadOnlyDictionary < string , torch . Tensor > > dic , torch . Device device )
160
+ private static Dictionary < string , torch . Tensor > Collate ( IEnumerable < Dictionary < string , torch . Tensor > > dic , torch . Device device )
162
161
{
163
162
using ( torch . NewDisposeScope ( ) ) {
164
163
Dictionary < string , torch . Tensor > batch = new ( ) ;
@@ -177,8 +176,7 @@ public DataLoader(
177
176
/// Data loader. Combines a dataset and a sampler, and provides an enumerator over the given dataset.
178
177
/// </summary>
179
178
/// <remarks>This class is used for list-style data sets</remarks>
180
- public class IterableDataLoader :
181
- DataLoader < IEnumerable < torch . Tensor > , IList < torch . Tensor > >
179
+ public class IterableDataLoader : DataLoader < IList < torch . Tensor > , IList < torch . Tensor > >
182
180
{
183
181
/// <summary>
184
182
/// PyTorch-style data loader.
@@ -199,7 +197,7 @@ public class IterableDataLoader :
199
197
/// Indicates whether to dispose the dataset when being disposed.
200
198
/// </param>
201
199
public IterableDataLoader (
202
- IDataset < IEnumerable < Tensor > > dataset ,
200
+ IterableDataset dataset ,
203
201
int batchSize , IEnumerable < long > shuffler ,
204
202
Device device = null ,
205
203
int num_worker = 1 , bool drop_last = false ,
@@ -232,7 +230,7 @@ public IterableDataLoader(
232
230
/// Indicates whether to dispose the dataset when being disposed.
233
231
/// </param>
234
232
public IterableDataLoader (
235
- IDataset < IEnumerable < Tensor > > dataset ,
233
+ IterableDataset dataset ,
236
234
int batchSize , bool shuffle = false ,
237
235
Device device = null , int ? seed = null ,
238
236
int num_worker = 1 , bool drop_last = false ,
@@ -245,18 +243,12 @@ public IterableDataLoader(
245
243
{
246
244
}
247
245
248
- private static IList < torch . Tensor > Collate (
249
- IReadOnlyList < IEnumerable < torch . Tensor > > dic , torch . Device device )
246
+ private static IList < torch . Tensor > Collate ( IEnumerable < IList < torch . Tensor > > dic , torch . Device device )
250
247
{
251
- var dicCopy = new List < torch . Tensor [ ] > ( ) ;
252
- foreach ( var e in dic ) {
253
- dicCopy . Add ( e . ToArray ( ) ) ;
254
- }
255
-
256
248
using ( torch . NewDisposeScope ( ) ) {
257
249
List < torch . Tensor > batch = new ( ) ;
258
- for ( var x = 0 ; x < dicCopy [ 0 ] . Length ; x ++ ) {
259
- var t = cat ( dicCopy . Select ( k => k [ x ] . unsqueeze ( 0 ) ) . ToArray ( ) , 0 ) ;
250
+ for ( var x = 0 ; x < dic . First ( ) . Count ; x ++ ) {
251
+ var t = cat ( dic . Select ( k => k [ x ] . unsqueeze ( 0 ) ) . ToArray ( ) , 0 ) ;
260
252
if ( t . device_type != device . type || t . device_index != device . index )
261
253
t = t . to ( device ) ;
262
254
batch . Add ( t . MoveToOuterDisposeScope ( ) ) ;
@@ -272,12 +264,12 @@ public IterableDataLoader(
272
264
/// </summary>
273
265
public class DataLoader < T , S > : IEnumerable < S > , IDisposable
274
266
{
275
- public IDataset < T > dataset { get ; }
267
+ public Dataset < T > dataset { get ; }
276
268
public int batch_size { get ; }
277
269
public bool drop_last { get ; }
278
270
public IEnumerable < long > sampler { get ; }
279
271
public int num_workers { get ; }
280
- public Func < IReadOnlyList < T > , Device , S > collate_fn { get ; }
272
+ public Func < IEnumerable < T > , Device , S > collate_fn { get ; }
281
273
282
274
public Device Device { get ; }
283
275
public bool DisposeBatch { get ; }
@@ -303,9 +295,9 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
303
295
/// Indicates whether to dispose the dataset when being disposed.
304
296
/// </param>
305
297
public DataLoader (
306
- IDataset < T > dataset ,
298
+ Dataset < T > dataset ,
307
299
int batchSize ,
308
- Func < IReadOnlyList < T > , torch . Device , S > collate_fn ,
300
+ Func < IEnumerable < T > , torch . Device , S > collate_fn ,
309
301
IEnumerable < long > shuffler ,
310
302
Device ? device = null ,
311
303
int num_worker = 1 ,
@@ -345,9 +337,9 @@ public DataLoader(
345
337
/// Indicates whether to dispose the dataset when being disposed.
346
338
/// </param>
347
339
public DataLoader (
348
- IDataset < T > dataset ,
340
+ Dataset < T > dataset ,
349
341
int batchSize ,
350
- Func < IReadOnlyList < T > , torch . Device , S > collate_fn ,
342
+ Func < IEnumerable < T > , torch . Device , S > collate_fn ,
351
343
bool shuffle = false ,
352
344
Device ? device = null ,
353
345
int ? seed = null ,
@@ -440,7 +432,7 @@ public bool MoveNext()
440
432
. WithDegreeOfParallelism ( loader . num_workers )
441
433
. ForAll ( ( i ) => {
442
434
using var getTensorScope = torch . NewDisposeScope ( ) ;
443
- tensors [ i ] = loader . dataset [ indices [ i ] ] ;
435
+ tensors [ i ] = loader . dataset . GetTensor ( indices [ i ] ) ;
444
436
getTensorDisposables [ i ] = getTensorScope . DetachAllAndDispose ( ) ;
445
437
} ) ;
446
438
0 commit comments