@@ -396,12 +396,11 @@ public DataLoaderEnumerator(DataLoader<T, S> loader)
396
396
Reset ( ) ;
397
397
}
398
398
399
/// <summary>
/// Disposes every element of the given set.
/// </summary>
/// <param name="disposables">Disposable objects to release; the set itself is not modified.</param>
private static void DisposeAll(HashSet<IDisposable> disposables)
{
    foreach (IDisposable item in disposables) {
        item.Dispose();
    }
}
406
405
407
406
/// <summary>
/// <summary>
/// Advances the enumerator to the next batch: releases the previous batch's
/// disposables, draws up to <c>batchSize</c> indices from the shuffler, loads
/// the corresponding tensors in parallel, and collates them into <c>current</c>.
/// </summary>
/// <returns>
/// true when a batch was produced; false when the shuffler is exhausted, or
/// when <c>drop_last</c> is set and only a partial final batch remains.
/// </returns>
public bool MoveNext()
{
    // Free whatever the previous batch still owned before building a new one.
    DisposeCurrent();

    // Take up to batchSize indices. Once the shuffler is exhausted,
    // MoveNext() returns false and the Select yields null, which the
    // Where filter drops — so the last batch may be smaller than batchSize.
    var indices = Enumerable.Range(0, loader.batchSize)
        .Select(_ => shuffler.MoveNext() ? shuffler.Current : (long?)null)
        .Where(x => x.HasValue)
        .Cast<long>()
        .ToArray();

    // No indices left at all: enumeration is finished.
    if (indices.Length is 0)
        return false;

    // Partial final batch and the loader was configured to drop it.
    if (loader.drop_last && indices.Length < loader.batchSize) {
        return false;
    }

    var tensors = new T[indices.Length];
    var getTensorDisposables = new HashSet<IDisposable>[indices.Length];
    // Load each item in parallel. Each worker runs inside its own dispose
    // scope so the intermediates GetTensor creates are captured per item;
    // they are detached here and only disposed after collation below.
    // Each parallel task writes to a distinct slot i, so no locking is needed.
    Enumerable.Range(0, indices.Length)
        .AsParallel()
        .WithDegreeOfParallelism(loader.num_worker)
        .ForAll((i) => {
            using var getTensorScope = torch.NewDisposeScope();
            tensors[i] = loader.dataset.GetTensor(indices[i]);
            getTensorDisposables[i] = getTensorScope.DetachAllAndDispose();
        });

    // Collate inside a fresh scope so the batch's own tensors can be
    // detached and tracked for disposal when the NEXT batch is requested
    // (only when the loader asks for batches to be disposed).
    using var collateScope = torch.NewDisposeScope();
    this.current = loader.collate_fn(tensors, loader.device);
    var collateDisposables = collateScope.DetachAllAndDispose();
    if (loader.disposeBatch) {
        this.currentDisposables = collateDisposables;
    }

    // The per-item intermediates are no longer needed once collate_fn has
    // consumed the tensors; release them now.
    // NOTE(review): this assumes collate_fn does not keep references into the
    // per-item intermediates beyond the tensors array itself — TODO confirm.
    foreach (var set in getTensorDisposables) {
        DisposeAll(set);
    }

    return true;
}
445
451
446
452
/// <summary>
@@ -470,8 +476,7 @@ public void Dispose()
470
476
/// <summary>
/// Releases the disposables tracked for the current batch, if any,
/// and clears the tracking set so a second call is a no-op.
/// </summary>
private void DisposeCurrent()
{
    if (this.currentDisposables == null) {
        return;
    }
    DisposeAll(this.currentDisposables);
    this.currentDisposables = null;
}
0 commit comments