Skip to content

Commit b545520

Browse files
committed
Enable DisposeScope for PackedSequence
1 parent 1dd3ae5 commit b545520

File tree

4 files changed

+126
-3
lines changed

4 files changed

+126
-3
lines changed

src/TorchSharp/DisposeScope.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,9 @@ public void Detach(IEnumerable<IDisposable> disposables)
213213
if (disposable is torch.Tensor tensor) {
214214
tensor.OwningDisposeScope = null;
215215
}
216+
else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
217+
sequence.OwningDisposeScope = null;
218+
}
216219
}
217220
}
218221
}
@@ -239,9 +242,16 @@ public IReadOnlyList<IDisposable> Attach(IEnumerable<IDisposable> disposables)
239242
_disposeScopeManager.StatisticsInstance.DetachedFromScopeCount--;
240243
}
241244
}
245+
else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
246+
if (sequence.OwningDisposeScope == null) {
247+
_disposeScopeManager.StatisticsInstance.DetachedFromScopeCount--;
248+
}
249+
}
250+
242251
AddToOther(this, disposable);
243252
result.Add(disposable);
244253
}
254+
245255
return result;
246256
}
247257

@@ -274,6 +284,12 @@ public void DisposeEverythingBut(IEnumerable<IDisposable> inKeep)
274284
if (!tensor.IsInvalid) {
275285
_disposeScopeManager.StatisticsInstance.DisposedInScopeCount++;
276286
}
287+
} else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
288+
// No need to have the disposable call back to the scope
289+
sequence.OwningDisposeScope = null;
290+
if (!sequence.IsInvalid) {
291+
_disposeScopeManager.StatisticsInstance.DisposedInScopeCount++;
292+
}
277293
} else {
278294
_disposeScopeManager.StatisticsInstance.DisposedInScopeCount++;
279295
}
@@ -358,6 +374,9 @@ public void MarkAsDisposed(IDisposable disposable)
358374
if (disposable is torch.Tensor tensor) {
359375
tensor.OwningDisposeScope = null;
360376
}
377+
else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
378+
sequence.OwningDisposeScope = null;
379+
}
361380
}
362381

363382
/// <summary>
@@ -380,6 +399,9 @@ private void AddToOther(DisposeScope? scope, IDisposable disposable)
380399
if (disposable is torch.Tensor tensor) {
381400
tensor.OwningDisposeScope = scope;
382401
}
402+
else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
403+
sequence.OwningDisposeScope = scope;
404+
}
383405
}
384406

385407
internal HashSet<IDisposable> DetachAllAndDispose()
@@ -390,6 +412,9 @@ internal HashSet<IDisposable> DetachAllAndDispose()
390412
if (disposable is torch.Tensor tensor) {
391413
tensor.OwningDisposeScope = null;
392414
}
415+
else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
416+
sequence.OwningDisposeScope = null;
417+
}
393418
}
394419

395420
this.Disposables = new();

src/TorchSharp/DisposeScopeManager.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public class DisposeScopeManager
1818
internal ThreadDisposeScopeStatistics StatisticsInstance { get; } = new ThreadDisposeScopeStatistics();
1919
internal DisposeScope? CurrentDisposeScope { get; private set; } = null;
2020

21-
internal DisposeScope? RegisterOnCurrentDisposeScope(torch.Tensor tensor)
21+
internal DisposeScope? RegisterOnCurrentDisposeScope(IDisposable tensor)
2222
{
2323
if (this.CurrentDisposeScope is null) {
2424
StatisticsInstance.CreatedOutsideScopeCount++;

src/TorchSharp/NN/Utils/PackedSequence.cs

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
2+
23
using System;
34
using System.Runtime.InteropServices;
45
using static TorchSharp.PInvoke.NativeMethods;
@@ -18,6 +19,8 @@ public static partial class rnn
1819
/// </summary>
1920
public sealed class PackedSequence : IDisposable
2021
{
22+
internal DisposeScope OwningDisposeScope { get; set; }
23+
2124
/// <summary>
2225
/// Class wrapping PyTorch's packedsequence object reference.
2326
/// </summary>
@@ -39,6 +42,7 @@ internal HType() : base(IntPtr.Zero, true)
3942
protected override bool ReleaseHandle()
4043
{
4144
THSNN_PackedSequence_dispose(handle);
45+
handle = IntPtr.Zero;
4246
return true;
4347
}
4448
}
@@ -62,6 +66,10 @@ protected override bool ReleaseHandle()
6266
/// The original indices
6367
/// </summary>
6468
public readonly Tensor unsorted_indices;
69+
/// <summary>
70+
/// Is true if the PackedSequence has been disposed, false otherwise.
71+
/// </summary>
72+
public bool IsInvalid => handle.IsInvalid;
6573
private HType handle;
6674

6775
internal PackedSequence(HType handle)
@@ -71,6 +79,11 @@ internal PackedSequence(HType handle)
7179
this.batch_sizes = new Tensor(THSNN_PackedSequence_batch_sizes(handle));
7280
this.sorted_indices = new Tensor(THSNN_PackedSequence_sorted_indices(handle));
7381
this.unsorted_indices = new Tensor(THSNN_PackedSequence_unsorted_indices(handle));
82+
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this.data);
83+
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this.batch_sizes);
84+
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this.sorted_indices);
85+
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this.unsorted_indices);
86+
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this);
7487
}
7588

7689
internal HType Handle => handle;
@@ -84,15 +97,69 @@ public void Dispose()
8497
this.batch_sizes.Dispose();
8598
this.sorted_indices.Dispose();
8699
this.unsorted_indices.Dispose();
100+
OwningDisposeScope?.MarkAsDisposed(this);
87101

88102
if (handle != null && !handle.IsInvalid) {
89103
handle.Dispose();
90104
handle.SetHandleAsInvalid();
105+
106+
}
107+
}
108+
/// <summary>
109+
/// Moves PackedSequence to the outer DisposeScope. If there is no outer DisposeScope, it's detached from the
110+
/// DisposeScope system.
111+
/// </summary>
112+
/// <returns>The same PackedSequence that the method was called on</returns>
113+
public PackedSequence MoveToOuterDisposeScope()
114+
{
115+
OwningDisposeScope?.MoveToOuter(this.data);
116+
OwningDisposeScope?.MoveToOuter(this.batch_sizes);
117+
OwningDisposeScope?.MoveToOuter(this.sorted_indices);
118+
OwningDisposeScope?.MoveToOuter(this.unsorted_indices);
119+
OwningDisposeScope?.MoveToOuter(this);
120+
return this;
121+
}
122+
123+
/// <summary>
124+
/// Detaches the PackedSequence completely from the DisposeScope system.
125+
/// </summary>
126+
/// <returns>The same PackedSequence that the method was called on</returns>
127+
public PackedSequence DetachFromDisposeScope()
128+
{
129+
OwningDisposeScope?.Detach(this.data);
130+
OwningDisposeScope?.Detach(this.batch_sizes);
131+
OwningDisposeScope?.Detach(this.sorted_indices);
132+
OwningDisposeScope?.Detach(this.unsorted_indices);
133+
OwningDisposeScope?.Detach(this);
134+
return this;
135+
}
136+
137+
public PackedSequence MoveToOtherDisposeScope(PackedSequence other)
138+
{
139+
return MoveToOtherDisposeScope(other.OwningDisposeScope);
140+
}
141+
142+
public PackedSequence MoveToOtherDisposeScope(DisposeScope other)
143+
{
144+
if (OwningDisposeScope == null && other != null) {
145+
other.Attach(this.data);
146+
other.Attach(this.batch_sizes);
147+
other.Attach(this.sorted_indices);
148+
other.Attach(this.unsorted_indices);
149+
other.Attach(this);
150+
}
151+
else {
152+
OwningDisposeScope?.MoveToOther(other, this.data);
153+
OwningDisposeScope?.MoveToOther(other, this.batch_sizes);
154+
OwningDisposeScope?.MoveToOther(other, this.sorted_indices);
155+
OwningDisposeScope?.MoveToOther(other, this.unsorted_indices);
156+
OwningDisposeScope?.MoveToOther(other, this);
91157
}
158+
return this;
92159
}
93-
}
160+
}
94161
}
95162
}
96163
}
97164
}
98-
}
165+
}

test/TorchSharpTest/TestNNUtils.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,37 @@ public void TestPackSequence()
5454
Assert.True(torch.max(torch.square(inverted_sequences - padded_sequences)).item<long>() == 0);
5555
}
5656

57+
[Fact]
58+
public void TestPackSequenceMoveDisposeScope()
59+
{
60+
nn.utils.rnn.PackedSequence packed_sequence;
61+
var otherScope = NewDisposeScope();
62+
using (var outerScope = NewDisposeScope())
63+
{
64+
using (var innerScope = NewDisposeScope()) {
65+
var (sequences, sequences_len) = make_test();
66+
packed_sequence = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: false);
67+
AssertPackedSequenceValid(packed_sequence);
68+
packed_sequence.MoveToOuterDisposeScope();
69+
}
70+
AssertPackedSequenceValid(packed_sequence);
71+
packed_sequence.MoveToOtherDisposeScope(otherScope);
72+
}
73+
AssertPackedSequenceValid(packed_sequence);
74+
otherScope.Dispose();
75+
Assert.True(packed_sequence.IsInvalid);
76+
Assert.True(packed_sequence.data.IsInvalid);
77+
}
78+
79+
private static void AssertPackedSequenceValid(nn.utils.rnn.PackedSequence packed_sequence)
80+
{
81+
Assert.False(packed_sequence.IsInvalid);
82+
Assert.False(packed_sequence.batch_sizes.IsInvalid);
83+
Assert.False(packed_sequence.data.IsInvalid);
84+
Assert.False(packed_sequence.sorted_indices.IsInvalid);
85+
Assert.False(packed_sequence.unsorted_indices.IsInvalid);
86+
}
87+
5788
[Fact]
5889
public void TestAutoGradGrad()
5990
{

0 commit comments

Comments
 (0)