3
3
using System ;
4
4
using System . Collections ;
5
5
using System . Collections . Generic ;
6
+ using System . Diagnostics . CodeAnalysis ;
6
7
using System . Diagnostics . SymbolStore ;
7
8
using System . Linq ;
8
9
using System . Threading ;
@@ -158,20 +159,21 @@ public IterableDataLoader(IterableDataset dataset, int batchSize, bool shuffle =
158
159
}
159
160
}
160
161
162
+ #nullable enable
161
163
/// <summary>
162
164
/// This class supports creating batches from data sets.
163
165
/// </summary>
164
166
public class DataLoader < T , S > : IEnumerable < S > , IDisposable
165
167
{
166
- private Dataset < T > dataset ;
167
- private int batchSize ;
168
- private bool shuffle ;
169
- private bool drop_last ;
170
- private Device device ;
171
- private IEnumerable < long > shuffler ;
172
- private int num_worker ;
173
- private Func < IEnumerable < T > , torch . Device , S > collate_fn ;
174
- private bool autoDispose ;
168
+ private readonly Dataset < T > dataset ;
169
+ private readonly int batchSize ;
170
+ private readonly bool drop_last ;
171
+ private readonly Device device ;
172
+ private readonly IEnumerable < long > shuffler ;
173
+ private readonly int num_worker ;
174
+ private readonly Func < IEnumerable < T > , torch . Device , S > collate_fn ;
175
+ private readonly bool disposeBatch ;
176
+ private readonly bool disposeDataset ;
175
177
176
178
/// <summary>
177
179
/// Pytorch style dataloader
@@ -180,34 +182,38 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
180
182
/// <param name="batchSize">Size of batch</param>
181
183
/// <param name="collate_fn">Callback to merge items make to a batch</param>
182
184
/// <param name="device">device for output tensor</param>
183
- /// <param name="shuffler">Shuffler for dataloader</param>
185
+ /// <param name="shuffler">Shuffler for dataloader. </param>
184
186
/// <param name="num_worker">Count of worker</param>
185
187
/// <param name="drop_last">
186
188
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
187
189
/// If alse and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
188
190
/// </param>
189
- /// <param name="autoDispose">
190
- /// Indicates whether to automatically dispose the collated tensors (a batch) after an iteration.
191
+ /// <param name="disposeBatch">
192
+ /// Indicates whether to automatically dispose the collated tensors after an iteration.
193
+ /// </param>
194
+ /// <param name="disposeDataset">
195
+ /// Indicates whether to dispose the dataset when being disposed.
191
196
/// </param>
192
197
public DataLoader (
193
198
Dataset < T > dataset ,
194
199
int batchSize ,
195
200
Func < IEnumerable < T > , torch . Device , S > collate_fn ,
196
201
IEnumerable < long > shuffler ,
197
- Device device = null ,
202
+ Device ? device = null ,
198
203
int num_worker = 1 ,
199
204
bool drop_last = false ,
200
- bool autoDispose = true )
205
+ bool disposeBatch = true ,
206
+ bool disposeDataset = true )
201
207
{
202
208
this . dataset = dataset ;
203
209
this . batchSize = batchSize ;
204
- this . shuffle = true ;
205
210
this . drop_last = drop_last ;
206
211
this . device = device ?? CPU ;
207
212
this . shuffler = shuffler ;
208
- this . num_worker = num_worker ;
213
+ this . num_worker = Math . Max ( num_worker , 1 ) ;
209
214
this . collate_fn = collate_fn ;
210
- this . autoDispose = autoDispose ;
215
+ this . disposeBatch = disposeBatch ;
216
+ this . disposeDataset = disposeDataset ;
211
217
}
212
218
213
219
/// <summary>
@@ -224,39 +230,39 @@ public DataLoader(
224
230
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
225
231
/// If alse and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
226
232
/// </param>
227
- /// <param name="autoDispose ">
233
+ /// <param name="disposeBatch ">
228
234
/// Indicates whether to automatically dispose the collated tensors (a batch) after an iteration.
229
235
/// </param>
236
+ /// <param name="disposeDataset">
237
+ /// Indicates whether to dispose the dataset when being disposed.
238
+ /// </param>
230
239
public DataLoader (
231
240
Dataset < T > dataset ,
232
241
int batchSize ,
233
242
Func < IEnumerable < T > , torch . Device , S > collate_fn ,
234
243
bool shuffle = false ,
235
- Device device = null ,
244
+ Device ? device = null ,
236
245
int ? seed = null ,
237
246
int num_worker = 1 ,
238
247
bool drop_last = false ,
239
- bool autoDispose = true )
248
+ bool disposeBatch = true ,
249
+ bool disposeDataset = true ) :
250
+ this ( dataset , batchSize , collate_fn ,
251
+ shuffle ? new FisherYatesShuffler ( dataset . Count , seed ) : LongRange ( dataset . Count ) ,
252
+ device , num_worker , drop_last , disposeBatch , disposeDataset )
253
+ { }
254
+
255
+ static IEnumerable < long > LongRange ( long count )
240
256
{
241
- this . dataset = dataset ;
242
- this . batchSize = batchSize ;
243
- this . shuffle = shuffle ;
244
- this . drop_last = drop_last ;
245
- this . device = device ?? CPU ;
246
- this . shuffler = seed is null ? new FisherYatesShuffler ( dataset . Count ) : new FisherYatesShuffler ( dataset . Count , seed ) ;
247
- this . num_worker = num_worker ;
248
- this . collate_fn = collate_fn ;
249
- this . autoDispose = autoDispose ;
257
+ for ( long i = 0 ; i < count ; i ++ )
258
+ yield return i ;
250
259
}
251
260
252
261
/// <summary>
253
262
/// Generate enumerator
254
263
/// </summary>
255
264
/// <returns>Enumerator for batch</returns>
256
- public IEnumerator < S > GetEnumerator ( ) =>
257
- new DataLoaderEnumerator (
258
- dataset , batchSize , shuffle , device ,
259
- shuffler , num_worker , collate_fn , autoDispose ) ;
265
+ public IEnumerator < S > GetEnumerator ( ) => new DataLoaderEnumerator ( this ) ;
260
266
261
267
IEnumerator IEnumerable . GetEnumerator ( ) => GetEnumerator ( ) ;
262
268
@@ -265,50 +271,38 @@ public IEnumerator<S> GetEnumerator() =>
265
271
/// </summary>
266
272
public long Count => drop_last ? ( dataset . Count / batchSize ) : ( ( dataset . Count - 1 ) / batchSize + 1 ) ;
267
273
274
+ public void Dispose ( )
275
+ {
276
+ Dispose ( true ) ;
277
+ GC . SuppressFinalize ( this ) ;
278
+ }
279
+
280
+ protected virtual void Dispose ( bool disposing )
281
+ {
282
+ if ( disposing && disposeDataset ) {
283
+ dataset . Dispose ( ) ;
284
+ }
285
+ }
286
+
268
287
private class DataLoaderEnumerator : IEnumerator < S >
269
288
{
270
- private Dataset < T > dataset ;
271
- private int batchSize ;
272
- private Device device ;
273
- private bool shuffle ;
274
- private IEnumerable < long > shuffleEnumerable ;
289
+ private readonly DataLoader < T , S > loader ;
275
290
private IEnumerator < long > shuffler ;
276
- private long currentVal = 0 ;
277
- private int num_worker = 0 ;
278
- private List < IDisposable > currentDisposables ;
279
- private Func < IEnumerable < T > , torch . Device , S > collate_fn ;
280
- public DataLoaderEnumerator (
281
- Dataset < T > dataset ,
282
- int batchSize ,
283
- bool shuffle ,
284
- Device device ,
285
- IEnumerable < long > shuffleEnumerable ,
286
- int num_worker ,
287
- Func < IEnumerable < T > , torch . Device , S > collate_fn ,
288
- bool autoDispose )
291
+ private List < IDisposable > ? currentDisposables ;
292
+ public DataLoaderEnumerator ( DataLoader < T , S > loader )
289
293
{
290
- this . dataset = dataset ;
291
- this . batchSize = batchSize ;
292
- this . device = device ;
293
- this . shuffle = shuffle ;
294
- this . shuffleEnumerable = shuffleEnumerable ;
295
- if ( num_worker < 1 ) num_worker = 1 ;
296
- this . num_worker = num_worker ;
297
- this . collate_fn = collate_fn ;
298
- this . currentDisposables = autoDispose ? new List < IDisposable > ( ) : null ;
294
+ this . loader = loader ;
295
+ if ( loader . disposeBatch )
296
+ this . currentDisposables = new List < IDisposable > ( ) ;
299
297
Reset ( ) ;
300
298
}
301
299
302
- private bool MoveNextValue ( )
300
+ private long ? MoveNextValue ( )
303
301
{
304
- if ( shuffle ) {
305
- if ( ! shuffler . MoveNext ( ) ) return false ;
306
- currentVal = shuffler . Current ;
307
- return true ;
308
- } else {
309
- currentVal ++ ;
310
- return currentVal < dataset . Count ;
302
+ if ( ! shuffler . MoveNext ( ) ) {
303
+ return null ;
311
304
}
305
+ return shuffler . Current ;
312
306
}
313
307
314
308
/// <summary>
@@ -318,107 +312,76 @@ private bool MoveNextValue()
318
312
public bool MoveNext ( )
319
313
{
320
314
DisposeCurrent ( ) ;
321
- using ( var scope = DisposeScopeManager . NewDisposeScope ( ) ) {
322
- if ( ! MoveNextValue ( ) ) return false ;
323
315
324
- var tensorIndexList = new List < long > { currentVal } ;
325
- for ( int i = 1 ; i < batchSize ; i ++ ) {
326
- if ( ! MoveNextValue ( ) ) break ;
327
- tensorIndexList . Add ( currentVal ) ;
316
+ using ( var scope = torch . NewDisposeScope ( ) ) {
317
+ var indices = Enumerable . Range ( 0 , loader . batchSize )
318
+ . Select ( _ => MoveNextValue ( ) )
319
+ . Where ( x => x . HasValue )
320
+ . Cast < long > ( )
321
+ . ToArray ( ) ;
322
+ if ( loader . drop_last && indices . Length < loader . batchSize ) {
323
+ return false ;
328
324
}
329
325
330
- var items = new List < T > ( new T [ tensorIndexList . Count ] ) ;
331
- var taskedBatchCount = 0 ;
332
-
333
- //Run Async
334
- var tasks = new List < Task > ( ) ;
335
- foreach ( var _ in Enumerable . Range ( 1 , num_worker - 1 ) )
336
- tasks . Add ( new ( ProcessPendingBatches ) ) ;
337
- tasks . ForEach ( x => x . Start ( ) ) ;
338
-
339
- ProcessPendingBatches ( ) ;
326
+ var tensors = new T [ indices . Length ] ;
327
+ Enumerable . Range ( 0 , indices . Length )
328
+ . AsParallel ( )
329
+ . WithDegreeOfParallelism ( loader . num_worker )
330
+ . ForAll ( ( i ) => {
331
+ tensors [ i ] = loader . dataset . GetTensor ( indices [ i ] ) ;
332
+ } ) ;
340
333
341
- foreach ( var task in tasks )
342
- task . Wait ( ) ;
343
-
344
- if ( this . currentDisposables is not null ) {
334
+ if ( this . currentDisposables is null ) {
335
+ current = loader . collate_fn ( tensors , loader . device ) ;
336
+ }
337
+ else {
345
338
using ( var collate_scope = DisposeScopeManager . NewDisposeScope ( ) ) {
346
- Current = collate_fn ( items , device ) ;
339
+ current = loader . collate_fn ( tensors , loader . device ) ;
347
340
currentDisposables . AddRange ( collate_scope . DisposablesView ) ;
348
341
collate_scope . Detach ( currentDisposables ) ;
349
342
}
350
343
}
351
- else {
352
- Current = collate_fn ( items , device ) ;
353
- }
354
344
355
- foreach ( var item in items ) {
356
- dataset . DisposeTensor ( item ) ;
345
+ foreach ( var item in tensors ) {
346
+ loader . dataset . DisposeTensor ( item ) ;
357
347
}
358
348
359
349
return true ;
360
-
361
- void ProcessPendingBatches ( )
362
- {
363
- while ( true ) {
364
- var idx = ScheduleBatch ( ) ;
365
- if ( idx is null ) break ;
366
- items [ idx . Value . Item1 ] = dataset . GetTensor ( idx . Value . Item2 ) ;
367
- }
368
- }
369
-
370
- ( int , long ) ? ScheduleBatch ( )
371
- {
372
- var t = Interlocked . Increment ( ref taskedBatchCount ) - 1 ;
373
- if ( t < tensorIndexList . Count )
374
- return ( t , tensorIndexList [ t ] ) ;
375
- return null ;
376
- }
377
350
}
378
351
}
379
352
380
353
/// <summary>
381
354
/// Reset enumerator
382
355
/// </summary>
356
+ [ MemberNotNull ( nameof ( shuffler ) ) ]
383
357
public void Reset ( )
384
358
{
385
359
DisposeCurrent ( ) ;
386
- if ( shuffle ) shuffler = shuffleEnumerable . GetEnumerator ( ) ;
387
- currentVal = - 1 ;
360
+ shuffler ? . Dispose ( ) ;
361
+ shuffler = loader . shuffler . GetEnumerator ( ) ;
388
362
}
389
363
364
+ S ? current ;
390
365
/// <summary>
391
366
/// Current tensor
392
367
/// </summary>
393
- public S Current { get ; private set ; }
368
+ public S Current => current ! ;
394
369
395
- object IEnumerator . Current => Current ;
370
+ object IEnumerator . Current => current ! ;
396
371
397
372
public void Dispose ( )
398
373
{
374
+ shuffler . Dispose ( ) ;
399
375
DisposeCurrent ( ) ;
400
376
}
401
377
402
378
private void DisposeCurrent ( )
403
379
{
404
- if ( currentDisposables is null ) return ;
380
+ if ( currentDisposables is null )
381
+ return ;
405
382
foreach ( var x in currentDisposables )
406
383
x . Dispose ( ) ;
407
384
currentDisposables . Clear ( ) ;
408
- shuffler ? . Dispose ( ) ;
409
- }
410
- }
411
-
412
- public void Dispose ( )
413
- {
414
- Dispose ( true ) ;
415
- GC . SuppressFinalize ( this ) ;
416
- }
417
-
418
- protected virtual void Dispose ( bool disposing )
419
- {
420
- if ( disposing ) {
421
- dataset . Dispose ( ) ;
422
385
}
423
386
}
424
387
}
0 commit comments