@@ -264,15 +264,16 @@ public IterableDataLoader(
264
264
/// </summary>
265
265
public class DataLoader < T , S > : IEnumerable < S > , IDisposable
266
266
{
267
- private readonly Dataset < T > dataset ;
268
- private readonly int batchSize ;
269
- private readonly bool drop_last ;
270
- private readonly Device device ;
271
- private readonly IEnumerable < long > shuffler ;
272
- private readonly int num_worker ;
273
- private readonly Func < IEnumerable < T > , torch . Device , S > collate_fn ;
274
- private readonly bool disposeBatch ;
275
- private readonly bool disposeDataset ;
267
+ public Dataset < T > dataset { get ; }
268
+ public int batch_size { get ; }
269
+ public bool drop_last { get ; }
270
+ public IEnumerable < long > sampler { get ; }
271
+ public int num_workers { get ; }
272
+ public Func < IEnumerable < T > , Device , S > collate_fn { get ; }
273
+
274
+ public Device Device { get ; }
275
+ public bool DisposeBatch { get ; }
276
+ public bool DisposeDataset { get ; }
276
277
277
278
/// <summary>
278
279
/// Pytorch style dataloader
@@ -305,14 +306,14 @@ public DataLoader(
305
306
bool disposeDataset = true )
306
307
{
307
308
this . dataset = dataset ;
308
- this . batchSize = batchSize ;
309
+ this . batch_size = batchSize ;
309
310
this . drop_last = drop_last ;
310
- this . device = device ?? CPU ;
311
- this . shuffler = shuffler ;
312
- this . num_worker = Math . Max ( num_worker , 1 ) ;
311
+ this . Device = device ?? CPU ;
312
+ this . sampler = shuffler ;
313
+ this . num_workers = Math . Max ( num_worker , 1 ) ;
313
314
this . collate_fn = collate_fn ;
314
- this . disposeBatch = disposeBatch ;
315
- this . disposeDataset = disposeDataset ;
315
+ this . DisposeBatch = disposeBatch ;
316
+ this . DisposeDataset = disposeDataset ;
316
317
}
317
318
318
319
/// <summary>
@@ -368,7 +369,7 @@ static IEnumerable<long> LongRange(long count)
368
369
/// <summary>
369
370
/// Size of batch
370
371
/// </summary>
371
- public long Count => drop_last ? ( dataset . Count / batchSize ) : ( ( dataset . Count - 1 ) / batchSize + 1 ) ;
372
+ public long Count => drop_last ? ( dataset . Count / batch_size ) : ( ( dataset . Count - 1 ) / batch_size + 1 ) ;
372
373
373
374
public void Dispose ( )
374
375
{
@@ -378,7 +379,7 @@ public void Dispose()
378
379
379
380
protected virtual void Dispose ( bool disposing )
380
381
{
381
- if ( disposing && disposeDataset ) {
382
+ if ( disposing && DisposeDataset ) {
382
383
dataset . Dispose ( ) ;
383
384
}
384
385
}
@@ -411,7 +412,7 @@ public bool MoveNext()
411
412
{
412
413
DisposeCurrent ( ) ;
413
414
414
- var indices = Enumerable . Range ( 0 , loader . batchSize )
415
+ var indices = Enumerable . Range ( 0 , loader . batch_size )
415
416
. Select ( _ => shuffler . MoveNext ( ) ? shuffler . Current : ( long ? ) null )
416
417
. Where ( x => x . HasValue )
417
418
. Cast < long > ( )
@@ -420,25 +421,25 @@ public bool MoveNext()
420
421
if ( indices . Length is 0 )
421
422
return false ;
422
423
423
- if ( loader . drop_last && indices . Length < loader . batchSize ) {
424
+ if ( loader . drop_last && indices . Length < loader . batch_size ) {
424
425
return false ;
425
426
}
426
427
427
428
var tensors = new T [ indices . Length ] ;
428
429
var getTensorDisposables = new HashSet < IDisposable > [ indices . Length ] ;
429
430
Enumerable . Range ( 0 , indices . Length )
430
431
. AsParallel ( )
431
- . WithDegreeOfParallelism ( loader . num_worker )
432
+ . WithDegreeOfParallelism ( loader . num_workers )
432
433
. ForAll ( ( i ) => {
433
434
using var getTensorScope = torch . NewDisposeScope ( ) ;
434
435
tensors [ i ] = loader . dataset . GetTensor ( indices [ i ] ) ;
435
436
getTensorDisposables [ i ] = getTensorScope . DetachAllAndDispose ( ) ;
436
437
} ) ;
437
438
438
439
using var collateScope = torch . NewDisposeScope ( ) ;
439
- this . current = loader . collate_fn ( tensors , loader . device ) ;
440
+ this . current = loader . collate_fn ( tensors , loader . Device ) ;
440
441
var collateDisposables = collateScope . DetachAllAndDispose ( ) ;
441
- if ( loader . disposeBatch ) {
442
+ if ( loader . DisposeBatch ) {
442
443
this . currentDisposables = collateDisposables ;
443
444
}
444
445
@@ -456,7 +457,7 @@ public void Reset()
456
457
{
457
458
DisposeCurrent ( ) ;
458
459
shuffler ? . Dispose ( ) ;
459
- shuffler = loader . shuffler . GetEnumerator ( ) ;
460
+ shuffler = loader . sampler . GetEnumerator ( ) ;
460
461
}
461
462
462
463
S ? current ;
0 commit comments