Skip to content

Commit d328e84

Browse files
Fix the creation and replacement of a Tensor with a Parameter instance
1 parent 569e332 commit d328e84

File tree

5 files changed

+121
-103
lines changed

5 files changed

+121
-103
lines changed

src/TorchSharp/DisposeScope.cs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,21 @@ public void Detach(IEnumerable<IDisposable> disposables)
224224
}
225225
}
226226

227+
/// <summary>
228+
/// Replaces registration of one tensor with another.
229+
/// </summary>
230+
/// <param name="original">The original tensor, possibly registered under a dispose scope.</param>
231+
/// <param name="replacement">The replacement tensor.</param>
232+
internal static void ReplaceWith(torch.Tensor original, torch.Tensor replacement)
233+
{
234+
DisposeScope? scope = original.OwningDisposeScope;
235+
236+
if (scope != null && scope.Disposables.Remove(original)) {
237+
original.OwningDisposeScope = null;
238+
AddToOther(scope, replacement);
239+
}
240+
}
241+
227242
public void Attach(IDisposable disposable)
228243
{
229244
_ = Attach((IEnumerable<IDisposable>)new[] { disposable });
@@ -369,10 +384,10 @@ public void MarkAsDisposed(IDisposable disposable)
369384
/// <returns></returns>
370385
public bool Contains(IDisposable disposable) => Disposables.Contains(disposable);
371386

372-
private bool AddToOther(DisposeScope scope, IDisposable disposable)
387+
private static bool AddToOther(DisposeScope scope, IDisposable disposable)
373388
{
374-
if (this._disposeScopeManager is null)
375-
throw new ObjectDisposedException(this.GetType().FullName);
389+
// if (this._disposeScopeManager is null)
390+
// throw new ObjectDisposedException(this.GetType().FullName);
376391

377392
DisposeScope? oldScope;
378393
if (disposable is torch.Tensor t) {

src/TorchSharp/NN/Parameter.cs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,11 @@ public class Parameter : Tensor
2424
/// <param name="data">A tensor, which will become empty.</param>
2525
/// <param name="requires_grad"></param>
2626
public Parameter(Tensor data, bool requires_grad = true) :
27-
base(data.with_requires_grad(requires_grad).MoveHandle())
27+
base(data.with_requires_grad(requires_grad).MoveHandle(), false)
2828
{
2929
var scope = data.OwningDisposeScope;
3030
if (scope is not null) {
31-
this.OwningDisposeScope = scope;
32-
scope.Attach(this);
33-
scope.Detach(data);
31+
DisposeScope.ReplaceWith(data, this);
3432
}
3533
}
3634

src/TorchSharp/Tensor/Tensor.cs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,14 @@ public partial class Tensor : IDisposable
3434

3535
internal DisposeScope? OwningDisposeScope { get; set; }
3636

37-
internal Tensor(IntPtr handle)
37+
internal Tensor(IntPtr handle, bool register = true)
3838
{
3939
this.handle = handle;
4040
System.Threading.Interlocked.Increment(ref _totalCount);
4141
_peakCount = Math.Max(_totalCount, _peakCount);
42-
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this);
42+
if (register) {
43+
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this);
44+
}
4345
}
4446

4547
/// <summary>
@@ -213,6 +215,11 @@ public IntPtr Handle {
213215
}
214216
}
215217

218+
/// <summary>
219+
/// Disassociates the native tensor handle from the managed Tensor.
220+
/// </summary>
221+
/// <remarks>Used to create a Parameter instance.</remarks>
222+
/// <returns>The handle to the underlying native tensor.</returns>
216223
internal IntPtr MoveHandle()
217224
{
218225
var h = handle;
@@ -2172,7 +2179,7 @@ public Tensor transpose_(long dim0, long dim1)
21722179
CheckForErrors();
21732180
return this;
21742181
}
2175-
2182+
21762183
public Tensor threshold(Scalar threshold, Scalar value)
21772184
{
21782185
var res = NativeMethods.THSTensor_threshold(Handle, threshold.Handle, value.Handle);
@@ -2717,7 +2724,7 @@ public Tensor softmax(long dim, ScalarType? dtype = null) =>
27172724
torch.special.softmax(this, dim, dtype);
27182725

27192726

2720-
public Tensor softplus(int beta = 1, int threshold = 20) =>
2727+
public Tensor softplus(int beta = 1, int threshold = 20) =>
27212728
softplus1(beta, threshold);
27222729

27232730
private Tensor softplus1(Scalar beta, Scalar threshold)
@@ -2787,9 +2794,9 @@ public Tensor rrelu_(double lower = one_eighth, double upper = one_third)
27872794
}
27882795

27892796
public Tensor celu() => this.celu(1.0);
2790-
2797+
27912798
public Tensor celu_() => this.celu_(1.0);
2792-
2799+
27932800
public Tensor celu(Scalar alpha)
27942801
{
27952802
var res = NativeMethods.THSTensor_celu(Handle, alpha.Handle);

test/TorchSharpTest/TestDisposeScopes.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ public void DisposeScopesCanBeNestled()
192192
DisposeScopeManager.Statistics.Reset();
193193
}
194194

195-
[Fact(Skip = "https://github.com/dotnet/TorchSharp/issues/1397")]
195+
[Fact]
196196
public void DisposeScopeWorksForTestTraining1()
197197
{
198198
DisposeScopeManager.Statistics.Reset();
@@ -207,7 +207,7 @@ public void DisposeScopeWorksForTestTraining1()
207207
DisposeScopeManager.Statistics.Reset();
208208
}
209209

210-
[Fact(Skip = "https://github.com/dotnet/TorchSharp/issues/1397")]
210+
[Fact]
211211
public void DisposeScopeWorksForTestTrainingConv2d()
212212
{
213213
DisposeScopeManager.Statistics.Reset();

0 commit comments

Comments
 (0)