Skip to content

Commit c22c9d3

Browse files
authored
Merge branch '1303' into temp
2 parents e220b28 + 5cd1dff commit c22c9d3

File tree

4 files changed

+71
-46
lines changed

4 files changed

+71
-46
lines changed

RELEASENOTES.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@ __Breaking Changes__:
88

99
- `torchvision.dataset.MNIST` will try more mirrors.
1010
- The thrown exception might be changed when it fails to download `MNIST`, `FashionMNIST` or `KMNIST`.
11+
- `ObjectDisposedException` will now be thrown when trying to use the disposed dispose scopes.
12+
- The constructor of dispose scopes is no longer `public`. Use `torch.NewDisposeScope` instead.
1113

1214
__API Changes__:
1315

1416
- #1291 `Tensor.grad()` and `Tensor.set_grad()` have been replaced by a new property `Tensor.grad`.
1517
- A potential memory leak caused by `set_grad` has been resolved.
18+
- `Include` method of dispose scopes has been removed. Use `Attach` instead.
1619

1720
__Bug Fixes__:
1821

@@ -25,6 +28,7 @@ __Bug Fixes__:
2528

2629
- `TensorDataset` will now keep the aliases detached from dispose scopes, to avoid the unexpected disposal.
2730
- `DataLoaderEnumerator` has been completely rewritten to resolve the unexpected shuffler disposal, the ignorance of drop last and the incorrect count of worker.
31+
- #1303 Allow dispose scopes to be disposed out of LIFO order.
2832

2933
# NuGet Version 0.102.4
3034

src/TorchSharp/DisposeScope.cs

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,50 +14,46 @@ namespace TorchSharp
1414
/// </summary>
1515
public sealed class DisposeScope : IDisposable
1616
{
17-
private readonly DisposeScopeManager _disposeScopeManager;
17+
private DisposeScopeManager? _disposeScopeManager;
1818

19-
public DisposeScope(DisposeScopeManager disposeScopeManager)
19+
internal DisposeScope(DisposeScopeManager disposeScopeManager)
2020
{
21-
_disposeScopeManager = disposeScopeManager ?? throw new ArgumentNullException(nameof(disposeScopeManager));
22-
if (disposeScopeManager.DisposeScopeStack.Count > 0) {
23-
OuterScope = disposeScopeManager.DisposeScopeStack.Peek();
24-
}
21+
_disposeScopeManager = disposeScopeManager;
22+
this.OuterScope = disposeScopeManager.CurrentDisposeScope;
2523
}
2624

2725
/// <summary>
2826
/// The outer scope with relation to this scope.
2927
/// </summary>
30-
internal DisposeScope? OuterScope { get; }
28+
internal DisposeScope? OuterScope { get; set; }
3129

3230
/// <summary>
3331
/// The disposables that are scheduled for disposing.
3432
/// </summary>
35-
/// TODO: There is a ReferenceEqualityComparer coming in .NET 6, use that!
3633
internal HashSet<IDisposable> Disposables { get; private set; } =
3734
new HashSet<IDisposable>(ReferenceEqualityComparer<IDisposable>.Default);
3835

3936
/// <summary>
4037
/// A view of the disposables in the scope - this list will not be kept in synch with the disposables
4138
/// in the scope.
4239
/// </summary>
43-
public IReadOnlyList<IDisposable> DisposablesView => Disposables.ToList();
40+
public IReadOnlyList<IDisposable> DisposablesView {
41+
get {
42+
if (this._disposeScopeManager is null)
43+
throw new ObjectDisposedException("The dispose scope has been disposed.");
44+
return Disposables.ToArray();
45+
}
46+
}
4447

4548
/// <summary>
4649
/// The number of disposables currently held in the scope
4750
/// </summary>
48-
public int DisposablesCount => Disposables.Count;
49-
50-
/// <summary>
51-
/// Includes a disposable in the scope - for tensors this is done automatically once the scope has been
52-
/// created. Use this method to add additional disposables that should be disposed, but you typically
53-
/// don't need to call this method.
54-
/// </summary>
55-
/// <param name="disposable">The disposable to keep in the scope</param>
56-
/// <returns></returns>
57-
public T Include<T>(T disposable) where T : IDisposable
58-
{
59-
Disposables.Add(disposable);
60-
return disposable;
51+
public int DisposablesCount {
52+
get {
53+
if (this._disposeScopeManager is null)
54+
throw new ObjectDisposedException("The dispose scope has been disposed.");
55+
return Disposables.Count;
56+
}
6157
}
6258

6359
/// <summary>
@@ -156,6 +152,8 @@ public void MoveToOther(DisposeScope? scope, params IDisposable[] disposables) =
156152
/// </summary>
157153
public void MoveToOther(DisposeScope? scope, IEnumerable<IDisposable> disposables)
158154
{
155+
if (this._disposeScopeManager is null)
156+
throw new ObjectDisposedException("The dispose scope has been disposed.");
159157
foreach (var disposable in disposables) {
160158
if (Disposables.Remove(disposable)) {
161159
AddToOther(scope, disposable);
@@ -207,6 +205,8 @@ public T Detach<T>(T disposable) where T : IDisposable
207205
/// </summary>
208206
public void Detach(IEnumerable<IDisposable> disposables)
209207
{
208+
if (this._disposeScopeManager is null)
209+
throw new ObjectDisposedException("The dispose scope has been disposed.");
210210
foreach (var disposable in disposables) {
211211
if (Disposables.Remove(disposable)) {
212212
_disposeScopeManager.StatisticsInstance.DetachedFromScopeCount++;
@@ -219,6 +219,8 @@ public void Detach(IEnumerable<IDisposable> disposables)
219219

220220
public void Attach(IDisposable disposable)
221221
{
222+
if (this._disposeScopeManager is null)
223+
throw new ObjectDisposedException("The dispose scope has been disposed.");
222224
if (disposable is torch.Tensor tensor) {
223225
if (tensor.OwningDisposeScope == null && !tensor.IsInvalid) {
224226
_disposeScopeManager.StatisticsInstance.DetachedFromScopeCount--;
@@ -240,6 +242,8 @@ public void Attach(IDisposable disposable)
240242
/// </summary>
241243
public void DisposeEverythingBut(IEnumerable<IDisposable> inKeep)
242244
{
245+
if (this._disposeScopeManager is null)
246+
throw new ObjectDisposedException("The dispose scope has been disposed.");
243247
// Avoiding multiple enumerations
244248
var oldList = Disposables;
245249
Disposables = inKeep.ToHashSet(ReferenceEqualityComparer<IDisposable>.Default);
@@ -315,8 +319,11 @@ public T DisposeEverythingBut<T>(T keep) where T : IDisposable
315319
/// </summary>
316320
public void Dispose()
317321
{
322+
if (this._disposeScopeManager is null)
323+
return;
318324
DisposeEverything();
319325
_disposeScopeManager.RemoveDisposeScope(this);
326+
this._disposeScopeManager = null;
320327
}
321328

322329
/// <summary>
@@ -328,6 +335,8 @@ public void Dispose()
328335
/// <param name="disposable">The disposable that was disposed</param>
329336
public void MarkAsDisposed(IDisposable disposable)
330337
{
338+
if (this._disposeScopeManager is null)
339+
throw new ObjectDisposedException("The dispose scope has been disposed.");
331340
_disposeScopeManager.StatisticsInstance.DisposedInScopeCount++;
332341
Disposables.Remove(disposable);
333342
if (disposable is torch.Tensor tensor) {
@@ -344,6 +353,8 @@ public void MarkAsDisposed(IDisposable disposable)
344353

345354
private void AddToOther(DisposeScope? scope, IDisposable disposable)
346355
{
356+
if (this._disposeScopeManager is null)
357+
throw new ObjectDisposedException("The dispose scope has been disposed.");
347358
if (scope != null) {
348359
scope.Disposables.Add(disposable);
349360
} else {

src/TorchSharp/DisposeScopeManager.cs

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
22

33
using System;
4-
using System.Collections.Generic;
5-
using System.Diagnostics;
64

75
#nullable enable
86
namespace TorchSharp
@@ -15,43 +13,54 @@ namespace TorchSharp
1513
public class DisposeScopeManager
1614
{
1715
[ThreadStatic] private static DisposeScopeManager? _threadSingleton;
18-
internal ThreadDisposeScopeStatistics StatisticsInstance { get; } = new ThreadDisposeScopeStatistics();
19-
2016
internal static DisposeScopeManager ThreadSingleton => (_threadSingleton ??= new DisposeScopeManager());
21-
internal Stack<DisposeScope> DisposeScopeStack { get; } = new();
2217

23-
public static ThreadDisposeScopeStatistics Statistics => ThreadSingleton.StatisticsInstance;
18+
internal ThreadDisposeScopeStatistics StatisticsInstance { get; } = new ThreadDisposeScopeStatistics();
19+
internal DisposeScope? CurrentDisposeScope { get; private set; } = null;
2420

25-
internal DisposeScope? RegisterOnCurrentDisposeScope(IDisposable disposable)
21+
internal DisposeScope? RegisterOnCurrentDisposeScope(torch.Tensor tensor)
2622
{
27-
if (DisposeScopeStack.Count == 0) {
23+
if (this.CurrentDisposeScope is null) {
2824
StatisticsInstance.CreatedOutsideScopeCount++;
2925
return null;
3026
}
3127

3228
StatisticsInstance.CreatedInScopeCount++;
33-
var current = DisposeScopeStack.Peek();
34-
current.Include(disposable);
35-
return current;
36-
}
37-
38-
internal static DisposeScope NewDisposeScope()
39-
{
40-
return ThreadSingleton.InnerNewDisposeScope();
29+
this.CurrentDisposeScope.Disposables.Add(tensor);
30+
return CurrentDisposeScope;
4131
}
4232

4333
internal void RemoveDisposeScope(DisposeScope disposeScope)
4434
{
45-
Debug.Assert(DisposeScopeStack.Count > 0);
46-
Debug.Assert(DisposeScopeStack.Peek() == disposeScope);
47-
DisposeScopeStack.Pop();
35+
var scope = this.CurrentDisposeScope;
36+
if (object.ReferenceEquals(scope, disposeScope)) {
37+
this.CurrentDisposeScope = scope.OuterScope;
38+
return;
39+
}
40+
if (scope is null) {
41+
return;
42+
}
43+
44+
for (; ; ) {
45+
var outerScope = scope.OuterScope;
46+
if (object.ReferenceEquals(outerScope, disposeScope)) {
47+
scope.OuterScope = outerScope.OuterScope;
48+
return;
49+
}
50+
51+
if (outerScope is null) {
52+
return;
53+
}
54+
scope = outerScope;
55+
}
4856
}
4957

50-
private DisposeScope InnerNewDisposeScope()
58+
internal DisposeScope NewDisposeScope()
5159
{
52-
var disposeScope = new DisposeScope(this);
53-
DisposeScopeStack.Push(disposeScope);
54-
return disposeScope;
60+
this.CurrentDisposeScope = new DisposeScope(this);
61+
return this.CurrentDisposeScope;
5562
}
63+
64+
public static ThreadDisposeScopeStatistics Statistics => ThreadSingleton.StatisticsInstance;
5665
}
5766
}

src/TorchSharp/Tensor/Tensor.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7374,7 +7374,8 @@ public static long max_int_value(ScalarType type)
73747374
/// Creates a new dispose scope for the current thread. Any tensor created within the dispose scope will
73757375
/// be automatically disposed once the dispose scope is disposed.
73767376
/// </summary>
7377-
public static DisposeScope NewDisposeScope() => DisposeScopeManager.NewDisposeScope();
7377+
public static DisposeScope NewDisposeScope() =>
7378+
DisposeScopeManager.ThreadSingleton.NewDisposeScope();
73787379

73797380
/// <summary>
73807381
/// Creates a new dispose scope for the current thread, wrapping an expression.

0 commit comments

Comments
 (0)