Skip to content

Commit b0afc3f

Browse files
committed
DetachAllAndDispose
1 parent 5542eda commit b0afc3f

File tree

2 files changed

+25
-11
lines changed

2 files changed

+25
-11
lines changed

src/TorchSharp/DataLoader.cs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -287,11 +287,10 @@ sealed class DataLoaderEnumerator : IEnumerator<S>
287287
{
288288
private readonly DataLoader<T, S> loader;
289289
private IEnumerator<long> shuffler;
290-
private IReadOnlyList<IDisposable> currentDisposables;
290+
private HashSet<IDisposable>? currentDisposables;
291291
public DataLoaderEnumerator(DataLoader<T, S> loader)
292292
{
293293
this.loader = loader;
294-
this.currentDisposables = Array.Empty<IDisposable>();
295294
// TODO: Use MemberNotNull instead.
296295
shuffler = null!;
297296
Reset();
@@ -333,14 +332,11 @@ public bool MoveNext()
333332
tensors[i] = loader.dataset.GetTensor(indices[i]);
334333
});
335334

336-
using var collate_scope = DisposeScopeManager.NewDisposeScope();
335+
using var collate_scope = torch.NewDisposeScope();
337336
current = loader.collate_fn(tensors, loader.device);
338-
339-
// TODO: Will be better if we have something like DetachAll
340-
var view = collate_scope.DisposablesView;
341-
collate_scope.Detach(view);
337+
var disposables = collate_scope.DetachAllAndDispose();
342338
if (loader.disposeBatch) {
343-
this.currentDisposables = view;
339+
this.currentDisposables = disposables;
344340
}
345341

346342
return true;
@@ -373,9 +369,11 @@ public void Dispose()
373369

374370
private void DisposeCurrent()
375371
{
376-
foreach (var x in this.currentDisposables)
377-
x.Dispose();
378-
this.currentDisposables = Array.Empty<IDisposable>();
372+
if (this.currentDisposables is not null) {
373+
foreach (var x in this.currentDisposables)
374+
x.Dispose();
375+
this.currentDisposables = null;
376+
}
379377
}
380378
}
381379
}

src/TorchSharp/DisposeScope.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,5 +365,21 @@ private void AddToOther(DisposeScope? scope, IDisposable disposable)
365365
tensor.OwningDisposeScope = scope;
366366
}
367367
}
368+
369+
internal HashSet<IDisposable> DetachAllAndDispose()
370+
{
371+
var disposables = this.Disposables;
372+
foreach (var disposable in this.Disposables) {
373+
this._disposeScopeManager!.StatisticsInstance.DetachedFromScopeCount++;
374+
if (disposable is torch.Tensor tensor) {
375+
tensor.OwningDisposeScope = null;
376+
}
377+
}
378+
379+
this.Disposables = new();
380+
this.Dispose();
381+
382+
return disposables;
383+
}
368384
}
369385
}

0 commit comments

Comments
 (0)