@@ -287,11 +287,10 @@ sealed class DataLoaderEnumerator : IEnumerator<S>
287
287
{
288
288
private readonly DataLoader < T , S > loader ;
289
289
private IEnumerator < long > shuffler ;
290
- private IReadOnlyList < IDisposable > currentDisposables ;
290
+ private HashSet < IDisposable > ? currentDisposables ;
291
291
public DataLoaderEnumerator ( DataLoader < T , S > loader )
292
292
{
293
293
this . loader = loader ;
294
- this . currentDisposables = Array . Empty < IDisposable > ( ) ;
295
294
// TODO: Use MemberNotNull instead.
296
295
shuffler = null ! ;
297
296
Reset ( ) ;
@@ -333,14 +332,11 @@ public bool MoveNext()
333
332
tensors [ i ] = loader . dataset . GetTensor ( indices [ i ] ) ;
334
333
} ) ;
335
334
336
- using var collate_scope = DisposeScopeManager . NewDisposeScope ( ) ;
335
+ using var collate_scope = torch . NewDisposeScope ( ) ;
337
336
current = loader . collate_fn ( tensors , loader . device ) ;
338
-
339
- // TODO: Will be better if we have something like DetachAll
340
- var view = collate_scope . DisposablesView ;
341
- collate_scope . Detach ( view ) ;
337
+ var disposables = collate_scope . DetachAllAndDispose ( ) ;
342
338
if ( loader . disposeBatch ) {
343
- this . currentDisposables = view ;
339
+ this . currentDisposables = disposables ;
344
340
}
345
341
346
342
return true ;
@@ -373,9 +369,11 @@ public void Dispose()
373
369
374
370
private void DisposeCurrent ( )
375
371
{
376
- foreach ( var x in this . currentDisposables )
377
- x . Dispose ( ) ;
378
- this . currentDisposables = Array . Empty < IDisposable > ( ) ;
372
+ if ( this . currentDisposables is not null ) {
373
+ foreach ( var x in this . currentDisposables )
374
+ x . Dispose ( ) ;
375
+ this . currentDisposables = null ;
376
+ }
379
377
}
380
378
}
381
379
}
0 commit comments