Skip to content

Commit 1137173

Browse files
committed
disposed state of DisposeScope & linked list like DisposeScopeStack
1 parent 3eb8326 commit 1137173

File tree

2 files changed

+63
-22
lines changed

2 files changed

+63
-22
lines changed

src/TorchSharp/DisposeScope.cs

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,39 +14,47 @@ namespace TorchSharp
1414
/// </summary>
1515
public sealed class DisposeScope : IDisposable
1616
{
17-
private readonly DisposeScopeManager _disposeScopeManager;
17+
private DisposeScopeManager? _disposeScopeManager;
1818

19-
public DisposeScope(DisposeScopeManager disposeScopeManager)
19+
internal DisposeScope(DisposeScopeManager disposeScopeManager)
2020
{
21-
_disposeScopeManager = disposeScopeManager ?? throw new ArgumentNullException(nameof(disposeScopeManager));
22-
if (disposeScopeManager.DisposeScopeStack.Count > 0) {
23-
OuterScope = disposeScopeManager.DisposeScopeStack[
24-
disposeScopeManager.DisposeScopeStack.Count - 1];
25-
}
21+
_disposeScopeManager = disposeScopeManager;
22+
this.OuterScope = disposeScopeManager.CurrentDisposeScope;
2623
}
2724

2825
/// <summary>
2926
/// The outer scope with relation to this scope.
3027
/// </summary>
31-
internal DisposeScope? OuterScope { get; }
28+
internal DisposeScope? OuterScope { get; set; }
3229

3330
/// <summary>
3431
/// The disposables that are scheduled for disposing.
3532
/// </summary>
36-
/// TODO: There is a ReferenceEqualityComparer coming in .NET 6, use that!
3733
internal HashSet<IDisposable> Disposables { get; private set; } =
3834
new HashSet<IDisposable>(ReferenceEqualityComparer<IDisposable>.Default);
3935

4036
/// <summary>
4137
/// A view of the disposables in the scope - this list will not be kept in synch with the disposables
4238
/// in the scope.
4339
/// </summary>
44-
public IReadOnlyList<IDisposable> DisposablesView => Disposables.ToList();
40+
public IReadOnlyList<IDisposable> DisposablesView {
41+
get {
42+
if (this._disposeScopeManager is null)
43+
throw new ObjectDisposedException("The dispose scope has been disposed.");
44+
return Disposables.ToArray();
45+
}
46+
}
4547

4648
/// <summary>
4749
/// The number of disposables currently held in the scope
4850
/// </summary>
49-
public int DisposablesCount => Disposables.Count;
51+
public int DisposablesCount {
52+
get {
53+
if (this._disposeScopeManager is null)
54+
throw new ObjectDisposedException("The dispose scope has been disposed.");
55+
return Disposables.Count;
56+
}
57+
}
5058

5159
/// <summary>
5260
/// Includes a disposable in the scope - for tensors this is done automatically once the scope has been
@@ -57,6 +65,8 @@ public DisposeScope(DisposeScopeManager disposeScopeManager)
5765
/// <returns></returns>
5866
public T Include<T>(T disposable) where T : IDisposable
5967
{
68+
if (this._disposeScopeManager is null)
69+
throw new ObjectDisposedException("The dispose scope has been disposed.");
6070
Disposables.Add(disposable);
6171
return disposable;
6272
}
@@ -157,6 +167,8 @@ public void MoveToOther(DisposeScope? scope, params IDisposable[] disposables) =
157167
/// </summary>
158168
public void MoveToOther(DisposeScope? scope, IEnumerable<IDisposable> disposables)
159169
{
170+
if (this._disposeScopeManager is null)
171+
throw new ObjectDisposedException("The dispose scope has been disposed.");
160172
foreach (var disposable in disposables) {
161173
if (Disposables.Remove(disposable)) {
162174
AddToOther(scope, disposable);
@@ -208,6 +220,8 @@ public T Detach<T>(T disposable) where T : IDisposable
208220
/// </summary>
209221
public void Detach(IEnumerable<IDisposable> disposables)
210222
{
223+
if (this._disposeScopeManager is null)
224+
throw new ObjectDisposedException("The dispose scope has been disposed.");
211225
foreach (var disposable in disposables) {
212226
if (Disposables.Remove(disposable)) {
213227
_disposeScopeManager.StatisticsInstance.DetachedFromScopeCount++;
@@ -220,6 +234,8 @@ public void Detach(IEnumerable<IDisposable> disposables)
220234

221235
public void Attach(IDisposable disposable)
222236
{
237+
if (this._disposeScopeManager is null)
238+
throw new ObjectDisposedException("The dispose scope has been disposed.");
223239
if (disposable is torch.Tensor tensor) {
224240
if (tensor.OwningDisposeScope == null && !tensor.IsInvalid) {
225241
_disposeScopeManager.StatisticsInstance.DetachedFromScopeCount--;
@@ -241,6 +257,8 @@ public void Attach(IDisposable disposable)
241257
/// </summary>
242258
public void DisposeEverythingBut(IEnumerable<IDisposable> inKeep)
243259
{
260+
if (this._disposeScopeManager is null)
261+
throw new ObjectDisposedException("The dispose scope has been disposed.");
244262
// Avoiding multiple enumerations
245263
var oldList = Disposables;
246264
Disposables = inKeep.ToHashSet(ReferenceEqualityComparer<IDisposable>.Default);
@@ -316,8 +334,11 @@ public T DisposeEverythingBut<T>(T keep) where T : IDisposable
316334
/// </summary>
317335
public void Dispose()
318336
{
337+
if (this._disposeScopeManager is null)
338+
return;
319339
DisposeEverything();
320340
_disposeScopeManager.RemoveDisposeScope(this);
341+
this._disposeScopeManager = null;
321342
}
322343

323344
/// <summary>
@@ -329,6 +350,8 @@ public void Dispose()
329350
/// <param name="disposable">The disposable that was disposed</param>
330351
public void MarkAsDisposed(IDisposable disposable)
331352
{
353+
if (this._disposeScopeManager is null)
354+
throw new ObjectDisposedException("The dispose scope has been disposed.");
332355
_disposeScopeManager.StatisticsInstance.DisposedInScopeCount++;
333356
Disposables.Remove(disposable);
334357
if (disposable is torch.Tensor tensor) {
@@ -345,6 +368,8 @@ public void MarkAsDisposed(IDisposable disposable)
345368

346369
private void AddToOther(DisposeScope? scope, IDisposable disposable)
347370
{
371+
if (this._disposeScopeManager is null)
372+
throw new ObjectDisposedException("The dispose scope has been disposed.");
348373
if (scope != null) {
349374
scope.Disposables.Add(disposable);
350375
} else {

src/TorchSharp/DisposeScopeManager.cs

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,20 @@ public class DisposeScopeManager
1818
internal ThreadDisposeScopeStatistics StatisticsInstance { get; } = new ThreadDisposeScopeStatistics();
1919

2020
internal static DisposeScopeManager ThreadSingleton => (_threadSingleton ??= new DisposeScopeManager());
21-
internal List<DisposeScope> DisposeScopeStack { get; } = new();
21+
internal DisposeScope? CurrentDisposeScope { get; set; } = null;
2222

2323
public static ThreadDisposeScopeStatistics Statistics => ThreadSingleton.StatisticsInstance;
2424

2525
internal DisposeScope? RegisterOnCurrentDisposeScope(IDisposable disposable)
2626
{
27-
if (DisposeScopeStack.Count == 0) {
27+
if (this.CurrentDisposeScope is null) {
2828
StatisticsInstance.CreatedOutsideScopeCount++;
2929
return null;
3030
}
3131

3232
StatisticsInstance.CreatedInScopeCount++;
33-
var current = DisposeScopeStack[DisposeScopeStack.Count - 1];
34-
current.Include(disposable);
35-
return current;
33+
this.CurrentDisposeScope.Include(disposable);
34+
return CurrentDisposeScope;
3635
}
3736

3837
internal static DisposeScope NewDisposeScope()
@@ -42,16 +41,33 @@ internal static DisposeScope NewDisposeScope()
4241

4342
internal void RemoveDisposeScope(DisposeScope disposeScope)
4443
{
45-
var index = DisposeScopeStack.LastIndexOf(disposeScope);
46-
if (index is not -1)
47-
DisposeScopeStack.RemoveAt(index);
44+
var scope = this.CurrentDisposeScope;
45+
if (object.ReferenceEquals(scope, disposeScope)) {
46+
this.CurrentDisposeScope = scope.OuterScope;
47+
return;
48+
}
49+
if (scope is null) {
50+
return;
51+
}
52+
53+
for (; ; ) {
54+
var outerScope = scope.OuterScope;
55+
if (object.ReferenceEquals(outerScope, disposeScope)) {
56+
scope.OuterScope = outerScope.OuterScope;
57+
return;
58+
}
59+
60+
if (outerScope is null) {
61+
return;
62+
}
63+
scope = outerScope;
64+
}
4865
}
4966

5067
private DisposeScope InnerNewDisposeScope()
5168
{
52-
var disposeScope = new DisposeScope(this);
53-
DisposeScopeStack.Add(disposeScope);
54-
return disposeScope;
69+
this.CurrentDisposeScope = new DisposeScope(this);
70+
return this.CurrentDisposeScope;
5571
}
5672
}
5773
}

0 commit comments

Comments
 (0)