@@ -17,9 +17,8 @@ public static partial class utils
17
17
{
18
18
public static partial class data
19
19
{
20
-
21
20
public static Modules . DataLoader DataLoader (
22
- Dataset dataset ,
21
+ IDataset < IReadOnlyDictionary < string , torch . Tensor > > dataset ,
23
22
int batchSize , IEnumerable < long > shuffler ,
24
23
Device device = null ,
25
24
int num_worker = 1 , bool drop_last = false ,
@@ -34,7 +33,7 @@ public static Modules.DataLoader DataLoader(
34
33
}
35
34
36
35
public static Modules . DataLoader DataLoader (
37
- Dataset dataset ,
36
+ IDataset < IReadOnlyDictionary < string , torch . Tensor > > dataset ,
38
37
int batchSize , bool shuffle = false ,
39
38
Device device = null , int ? seed = null ,
40
39
int num_worker = 1 , bool drop_last = false ,
@@ -49,7 +48,7 @@ public static Modules.DataLoader DataLoader(
49
48
}
50
49
51
50
public static Modules . IterableDataLoader DataLoader (
52
- IterableDataset dataset ,
51
+ IDataset < IEnumerable < Tensor > > dataset ,
53
52
int batchSize , IEnumerable < long > shuffler ,
54
53
Device device = null ,
55
54
int num_worker = 1 , bool drop_last = false ,
@@ -64,7 +63,7 @@ public static Modules.IterableDataLoader DataLoader(
64
63
}
65
64
66
65
public static Modules . IterableDataLoader DataLoader (
67
- IterableDataset dataset ,
66
+ IDataset < IEnumerable < Tensor > > dataset ,
68
67
int batchSize , bool shuffle = false ,
69
68
Device device = null , int ? seed = null ,
70
69
int num_worker = 1 , bool drop_last = false ,
@@ -90,7 +89,8 @@ namespace Modules
90
89
/// Data loader. Combines a dataset and a sampler, and provides an enumerator over the given dataset.
91
90
/// </summary>
92
91
/// <remarks>This class is used for map-style data sets</remarks>
93
- public class DataLoader : DataLoader < Dictionary < string , torch . Tensor > , Dictionary < string , torch . Tensor > >
92
+ public class DataLoader : DataLoader < IReadOnlyDictionary < string , torch . Tensor > ,
93
+ Dictionary < string , torch . Tensor > >
94
94
{
95
95
/// <summary>
96
96
/// Pytorch style dataloader
@@ -111,7 +111,7 @@ public class DataLoader : DataLoader<Dictionary<string, torch.Tensor>, Dictionar
111
111
/// Indicates whether to dispose the dataset when being disposed.
112
112
/// </param>
113
113
public DataLoader (
114
- Dataset dataset ,
114
+ IDataset < IReadOnlyDictionary < string , torch . Tensor > > dataset ,
115
115
int batchSize , IEnumerable < long > shuffler ,
116
116
Device device = null ,
117
117
int num_worker = 1 , bool drop_last = false ,
@@ -144,7 +144,7 @@ public DataLoader(
144
144
/// Indicates whether to dispose the dataset when being disposed.
145
145
/// </param>
146
146
public DataLoader (
147
- Dataset dataset ,
147
+ IDataset < IReadOnlyDictionary < string , torch . Tensor > > dataset ,
148
148
int batchSize , bool shuffle = false ,
149
149
Device device = null , int ? seed = null ,
150
150
int num_worker = 1 , bool drop_last = false ,
@@ -157,7 +157,8 @@ public DataLoader(
157
157
{
158
158
}
159
159
160
- private static Dictionary < string , torch . Tensor > Collate ( IEnumerable < Dictionary < string , torch . Tensor > > dic , torch . Device device )
160
+ private static Dictionary < string , torch . Tensor > Collate (
161
+ IEnumerable < IReadOnlyDictionary < string , torch . Tensor > > dic , torch . Device device )
161
162
{
162
163
using ( torch . NewDisposeScope ( ) ) {
163
164
Dictionary < string , torch . Tensor > batch = new ( ) ;
@@ -176,7 +177,8 @@ public DataLoader(
176
177
/// Data loader. Combines a dataset and a sampler, and provides an enumerator over the given dataset.
177
178
/// </summary>
178
179
/// <remarks>This class is used for list-style data sets</remarks>
179
- public class IterableDataLoader : DataLoader < IList < torch . Tensor > , IList < torch . Tensor > >
180
+ public class IterableDataLoader :
181
+ DataLoader < IEnumerable < torch . Tensor > , IList < torch . Tensor > >
180
182
{
181
183
/// <summary>
182
184
/// Pytorch style dataloader
@@ -197,7 +199,7 @@ public class IterableDataLoader : DataLoader<IList<torch.Tensor>, IList<torch.Te
197
199
/// Indicates whether to dispose the dataset when being disposed.
198
200
/// </param>
199
201
public IterableDataLoader (
200
- IterableDataset dataset ,
202
+ IDataset < IEnumerable < Tensor > > dataset ,
201
203
int batchSize , IEnumerable < long > shuffler ,
202
204
Device device = null ,
203
205
int num_worker = 1 , bool drop_last = false ,
@@ -230,7 +232,7 @@ public IterableDataLoader(
230
232
/// Indicates whether to dispose the dataset when being disposed.
231
233
/// </param>
232
234
public IterableDataLoader (
233
- IterableDataset dataset ,
235
+ IDataset < IEnumerable < Tensor > > dataset ,
234
236
int batchSize , bool shuffle = false ,
235
237
Device device = null , int ? seed = null ,
236
238
int num_worker = 1 , bool drop_last = false ,
@@ -243,12 +245,18 @@ public IterableDataLoader(
243
245
{
244
246
}
245
247
246
- private static IList < torch . Tensor > Collate ( IEnumerable < IList < torch . Tensor > > dic , torch . Device device )
248
+ private static IList < torch . Tensor > Collate (
249
+ IReadOnlyList < IEnumerable < torch . Tensor > > dic , torch . Device device )
247
250
{
251
+ var dicCopy = new List < torch . Tensor [ ] > ( ) ;
252
+ foreach ( var e in dic ) {
253
+ dicCopy . Add ( e . ToArray ( ) ) ;
254
+ }
255
+
248
256
using ( torch . NewDisposeScope ( ) ) {
249
257
List < torch . Tensor > batch = new ( ) ;
250
- for ( var x = 0 ; x < dic . First ( ) . Count ; x ++ ) {
251
- var t = cat ( dic . Select ( k => k [ x ] . unsqueeze ( 0 ) ) . ToArray ( ) , 0 ) ;
258
+ for ( var x = 0 ; x < dicCopy [ 0 ] . Length ; x ++ ) {
259
+ var t = cat ( dicCopy . Select ( k => k [ x ] . unsqueeze ( 0 ) ) . ToArray ( ) , 0 ) ;
252
260
if ( t . device_type != device . type || t . device_index != device . index )
253
261
t = t . to ( device ) ;
254
262
batch . Add ( t . MoveToOuterDisposeScope ( ) ) ;
@@ -264,12 +272,12 @@ public IterableDataLoader(
264
272
/// </summary>
265
273
public class DataLoader < T , S > : IEnumerable < S > , IDisposable
266
274
{
267
- public Dataset < T > dataset { get ; }
275
+ public IDataset < T > dataset { get ; }
268
276
public int batch_size { get ; }
269
277
public bool drop_last { get ; }
270
278
public IEnumerable < long > sampler { get ; }
271
279
public int num_workers { get ; }
272
- public Func < IEnumerable < T > , Device , S > collate_fn { get ; }
280
+ public Func < IReadOnlyList < T > , Device , S > collate_fn { get ; }
273
281
274
282
public Device Device { get ; }
275
283
public bool DisposeBatch { get ; }
@@ -295,9 +303,9 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
295
303
/// Indicates whether to dispose the dataset when being disposed.
296
304
/// </param>
297
305
public DataLoader (
298
- Dataset < T > dataset ,
306
+ IDataset < T > dataset ,
299
307
int batchSize ,
300
- Func < IEnumerable < T > , torch . Device , S > collate_fn ,
308
+ Func < IReadOnlyList < T > , torch . Device , S > collate_fn ,
301
309
IEnumerable < long > shuffler ,
302
310
Device ? device = null ,
303
311
int num_worker = 1 ,
@@ -337,9 +345,9 @@ public DataLoader(
337
345
/// Indicates whether to dispose the dataset when being disposed.
338
346
/// </param>
339
347
public DataLoader (
340
- Dataset < T > dataset ,
348
+ IDataset < T > dataset ,
341
349
int batchSize ,
342
- Func < IEnumerable < T > , torch . Device , S > collate_fn ,
350
+ Func < IReadOnlyList < T > , torch . Device , S > collate_fn ,
343
351
bool shuffle = false ,
344
352
Device ? device = null ,
345
353
int ? seed = null ,
@@ -432,7 +440,7 @@ public bool MoveNext()
432
440
. WithDegreeOfParallelism ( loader . num_workers )
433
441
. ForAll ( ( i ) => {
434
442
using var getTensorScope = torch . NewDisposeScope ( ) ;
435
- tensors [ i ] = loader . dataset . GetTensor ( indices [ i ] ) ;
443
+ tensors [ i ] = loader . dataset [ indices [ i ] ] ;
436
444
getTensorDisposables [ i ] = getTensorScope . DetachAllAndDispose ( ) ;
437
445
} ) ;
438
446
0 commit comments