Skip to content

Commit 316b139

Browse files
committed
Added ability to iterate over iterators (altough memory management has a problem)
1 parent d35ff70 commit 316b139

File tree

4 files changed

+383
-29
lines changed

4 files changed

+383
-29
lines changed

Test/TorchSharp.cs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,5 +373,38 @@ public void TestMNISTLoader()
373373
Assert.AreEqual(size, i * 32);
374374
}
375375
}
376+
377+
[TestMethod]
378+
public void TestMNISTLoaderWithEpochs()
379+
{
380+
using (var train = Data.Loader.MNIST(@"E:/Source/Repos/LibTorchSharp/MNIST", 32))
381+
{
382+
var size = train.Size();
383+
var epochs = 10;
384+
385+
Assert.IsNotNull(train);
386+
Assert.IsNotNull(size);
387+
388+
int i = 0;
389+
390+
for (int e = 0; e < epochs; e++)
391+
{
392+
foreach (var (data, target) in train)
393+
{
394+
i++;
395+
396+
CollectionAssert.AreEqual(data.Shape, new long[] { 32, 1, 28, 28 });
397+
CollectionAssert.AreEqual(target.Shape, new long[] { 32 });
398+
399+
data.Dispose();
400+
target.Dispose();
401+
}
402+
403+
var t = i;
404+
}
405+
406+
Assert.AreEqual(size * epochs, i * 32);
407+
}
408+
}
376409
}
377410
}

TorchSharp/Data/DataIterator.cs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ internal static class ExternMethods
2121

2222
[DllImport("LibTorchSharp")]
2323
extern internal static long Data_Size(IntPtr iterator);
24+
25+
[DllImport("LibTorchSharp")]
26+
extern internal static void Data_Reset(IntPtr iterator);
27+
28+
[DllImport("LibTorchSharp")]
29+
extern internal static void Data_Dispose(IntPtr iterator);
2430
}
2531

2632
/// <summary>
@@ -65,6 +71,8 @@ protected override void Dispose(bool disposing)
6571

6672
protected HType handle;
6773

74+
protected IEnumerator<(ITorchTensor<TData> data, ITorchTensor<TTarget> target)> @enum = null;
75+
6876
/// <summary>
6977
/// Constructor.
7078
/// </summary>
@@ -95,6 +103,7 @@ protected void Dispose(bool disposing)
95103
{
96104
if (disposing)
97105
{
106+
ExternMethods.Data_Dispose(handle.DangerousGetHandle());
98107
handle.Dispose();
99108
handle.SetHandleAsInvalid();
100109
}
@@ -115,7 +124,9 @@ public long Size()
115124
/// <returns></returns>
116125
public IEnumerator<(ITorchTensor<TData> data, ITorchTensor<TTarget> target)> GetEnumerator()
117126
{
118-
return new DataIteratorEnumerator(this);
127+
@enum?.Reset();
128+
@enum = @enum ?? new DataIteratorEnumerator(this);
129+
return @enum;
119130
}
120131

121132

@@ -148,8 +159,7 @@ public DataIteratorEnumerator(DataIterator<TData, TTarget> iterator)
148159
{
149160
get
150161
{
151-
152-
ExternMethods.Data_Current(_iterator.handle.DangerousGetHandle(), _dRef, _tRef);
162+
ExternMethods.Data_Current(_iterator.handle.DangerousGetHandle(), _dRef, _tRef);
153163
return (_darray.Array[0].ToTorchTensor<TData>(), _tarray.Array[0].ToTorchTensor<TTarget>());
154164
}
155165
}
@@ -169,7 +179,8 @@ public bool MoveNext()
169179

170180
public void Reset()
171181
{
172-
throw new InvalidOperationException();
182+
_isFirst = true;
183+
ExternMethods.Data_Reset(_iterator.handle.DangerousGetHandle());
173184
}
174185

175186
public void Dispose()

TorchSharp/Generated/TorchTensor.generated.cs

Lines changed: 85 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,13 @@ public class ByteTensor : ITorchTensor<byte>
1515
[DllImport("LibTorchSharp")]
1616
extern static AtenSharp.ByteTensor.HType THS_getTHTensorUnsafe(HType handle);
1717

18+
[DllImport("LibTorchSharp")]
19+
extern static void THS_Delete(HType handle);
20+
1821
internal sealed class HType : SafeHandle
1922
{
23+
internal bool shouldClean = true;
24+
2025
public HType(IntPtr preexistingHandle, bool ownsHandle) : base(IntPtr.Zero, ownsHandle)
2126
{
2227
SetHandle(preexistingHandle);
@@ -31,8 +36,12 @@ internal HType() : base(IntPtr.Zero, true)
3136

3237
protected override bool ReleaseHandle()
3338
{
34-
var atenTensor = new AtenSharp.ByteTensor(THS_getTHTensorUnsafe(this));
35-
atenTensor.Dispose();
39+
// var atenTensor = new AtenSharp.ByteTensor(THS_getTHTensorUnsafe(this));
40+
// atenTensor.Dispose();
41+
if (shouldClean)
42+
{
43+
THS_Delete (this);
44+
}
3645
return true;
3746
}
3847

@@ -47,9 +56,10 @@ protected override void Dispose(bool disposing)
4756

4857
internal HType handle;
4958

50-
internal ByteTensor(HType handle)
59+
internal ByteTensor(HType handle, bool shouldClean = true)
5160
{
5261
this.handle = handle;
62+
this.handle.shouldClean = shouldClean;
5363
}
5464

5565
/// <summary>
@@ -313,8 +323,13 @@ public class ShortTensor : ITorchTensor<short>
313323
[DllImport("LibTorchSharp")]
314324
extern static AtenSharp.ShortTensor.HType THS_getTHTensorUnsafe(HType handle);
315325

326+
[DllImport("LibTorchSharp")]
327+
extern static void THS_Delete(HType handle);
328+
316329
internal sealed class HType : SafeHandle
317330
{
331+
internal bool shouldClean = true;
332+
318333
public HType(IntPtr preexistingHandle, bool ownsHandle) : base(IntPtr.Zero, ownsHandle)
319334
{
320335
SetHandle(preexistingHandle);
@@ -329,8 +344,12 @@ internal HType() : base(IntPtr.Zero, true)
329344

330345
protected override bool ReleaseHandle()
331346
{
332-
var atenTensor = new AtenSharp.ShortTensor(THS_getTHTensorUnsafe(this));
333-
atenTensor.Dispose();
347+
// var atenTensor = new AtenSharp.ShortTensor(THS_getTHTensorUnsafe(this));
348+
// atenTensor.Dispose();
349+
if (shouldClean)
350+
{
351+
THS_Delete (this);
352+
}
334353
return true;
335354
}
336355

@@ -345,9 +364,10 @@ protected override void Dispose(bool disposing)
345364

346365
internal HType handle;
347366

348-
internal ShortTensor(HType handle)
367+
internal ShortTensor(HType handle, bool shouldClean = true)
349368
{
350369
this.handle = handle;
370+
this.handle.shouldClean = shouldClean;
351371
}
352372

353373
/// <summary>
@@ -611,8 +631,13 @@ public class IntTensor : ITorchTensor<int>
611631
[DllImport("LibTorchSharp")]
612632
extern static AtenSharp.IntTensor.HType THS_getTHTensorUnsafe(HType handle);
613633

634+
[DllImport("LibTorchSharp")]
635+
extern static void THS_Delete(HType handle);
636+
614637
internal sealed class HType : SafeHandle
615638
{
639+
internal bool shouldClean = true;
640+
616641
public HType(IntPtr preexistingHandle, bool ownsHandle) : base(IntPtr.Zero, ownsHandle)
617642
{
618643
SetHandle(preexistingHandle);
@@ -627,8 +652,12 @@ internal HType() : base(IntPtr.Zero, true)
627652

628653
protected override bool ReleaseHandle()
629654
{
630-
var atenTensor = new AtenSharp.IntTensor(THS_getTHTensorUnsafe(this));
631-
atenTensor.Dispose();
655+
// var atenTensor = new AtenSharp.IntTensor(THS_getTHTensorUnsafe(this));
656+
// atenTensor.Dispose();
657+
if (shouldClean)
658+
{
659+
THS_Delete (this);
660+
}
632661
return true;
633662
}
634663

@@ -643,9 +672,10 @@ protected override void Dispose(bool disposing)
643672

644673
internal HType handle;
645674

646-
internal IntTensor(HType handle)
675+
internal IntTensor(HType handle, bool shouldClean = true)
647676
{
648677
this.handle = handle;
678+
this.handle.shouldClean = shouldClean;
649679
}
650680

651681
/// <summary>
@@ -909,8 +939,13 @@ public class LongTensor : ITorchTensor<long>
909939
[DllImport("LibTorchSharp")]
910940
extern static AtenSharp.LongTensor.HType THS_getTHTensorUnsafe(HType handle);
911941

942+
[DllImport("LibTorchSharp")]
943+
extern static void THS_Delete(HType handle);
944+
912945
internal sealed class HType : SafeHandle
913946
{
947+
internal bool shouldClean = true;
948+
914949
public HType(IntPtr preexistingHandle, bool ownsHandle) : base(IntPtr.Zero, ownsHandle)
915950
{
916951
SetHandle(preexistingHandle);
@@ -925,8 +960,12 @@ internal HType() : base(IntPtr.Zero, true)
925960

926961
protected override bool ReleaseHandle()
927962
{
928-
var atenTensor = new AtenSharp.LongTensor(THS_getTHTensorUnsafe(this));
929-
atenTensor.Dispose();
963+
// var atenTensor = new AtenSharp.LongTensor(THS_getTHTensorUnsafe(this));
964+
// atenTensor.Dispose();
965+
if (shouldClean)
966+
{
967+
THS_Delete (this);
968+
}
930969
return true;
931970
}
932971

@@ -941,9 +980,10 @@ protected override void Dispose(bool disposing)
941980

942981
internal HType handle;
943982

944-
internal LongTensor(HType handle)
983+
internal LongTensor(HType handle, bool shouldClean = true)
945984
{
946985
this.handle = handle;
986+
this.handle.shouldClean = shouldClean;
947987
}
948988

949989
/// <summary>
@@ -1207,8 +1247,13 @@ public class DoubleTensor : ITorchTensor<double>
12071247
[DllImport("LibTorchSharp")]
12081248
extern static AtenSharp.DoubleTensor.HType THS_getTHTensorUnsafe(HType handle);
12091249

1250+
[DllImport("LibTorchSharp")]
1251+
extern static void THS_Delete(HType handle);
1252+
12101253
internal sealed class HType : SafeHandle
12111254
{
1255+
internal bool shouldClean = true;
1256+
12121257
public HType(IntPtr preexistingHandle, bool ownsHandle) : base(IntPtr.Zero, ownsHandle)
12131258
{
12141259
SetHandle(preexistingHandle);
@@ -1223,8 +1268,12 @@ internal HType() : base(IntPtr.Zero, true)
12231268

12241269
protected override bool ReleaseHandle()
12251270
{
1226-
var atenTensor = new AtenSharp.DoubleTensor(THS_getTHTensorUnsafe(this));
1227-
atenTensor.Dispose();
1271+
// var atenTensor = new AtenSharp.DoubleTensor(THS_getTHTensorUnsafe(this));
1272+
// atenTensor.Dispose();
1273+
if (shouldClean)
1274+
{
1275+
THS_Delete (this);
1276+
}
12281277
return true;
12291278
}
12301279

@@ -1239,9 +1288,10 @@ protected override void Dispose(bool disposing)
12391288

12401289
internal HType handle;
12411290

1242-
internal DoubleTensor(HType handle)
1291+
internal DoubleTensor(HType handle, bool shouldClean = true)
12431292
{
12441293
this.handle = handle;
1294+
this.handle.shouldClean = shouldClean;
12451295
}
12461296

12471297
/// <summary>
@@ -1505,8 +1555,13 @@ public class FloatTensor : ITorchTensor<float>
15051555
[DllImport("LibTorchSharp")]
15061556
extern static AtenSharp.FloatTensor.HType THS_getTHTensorUnsafe(HType handle);
15071557

1558+
[DllImport("LibTorchSharp")]
1559+
extern static void THS_Delete(HType handle);
1560+
15081561
internal sealed class HType : SafeHandle
15091562
{
1563+
internal bool shouldClean = true;
1564+
15101565
public HType(IntPtr preexistingHandle, bool ownsHandle) : base(IntPtr.Zero, ownsHandle)
15111566
{
15121567
SetHandle(preexistingHandle);
@@ -1521,8 +1576,12 @@ internal HType() : base(IntPtr.Zero, true)
15211576

15221577
protected override bool ReleaseHandle()
15231578
{
1524-
var atenTensor = new AtenSharp.FloatTensor(THS_getTHTensorUnsafe(this));
1525-
atenTensor.Dispose();
1579+
// var atenTensor = new AtenSharp.FloatTensor(THS_getTHTensorUnsafe(this));
1580+
// atenTensor.Dispose();
1581+
if (shouldClean)
1582+
{
1583+
THS_Delete (this);
1584+
}
15261585
return true;
15271586
}
15281587

@@ -1537,9 +1596,10 @@ protected override void Dispose(bool disposing)
15371596

15381597
internal HType handle;
15391598

1540-
internal FloatTensor(HType handle)
1599+
internal FloatTensor(HType handle, bool shouldClean = true)
15411600
{
15421601
this.handle = handle;
1602+
this.handle.shouldClean = shouldClean;
15431603
}
15441604

15451605
/// <summary>
@@ -1806,33 +1866,33 @@ internal enum ATenScalarMapping : short
18061866

18071867
public static class TensorExtensionMethods
18081868
{
1809-
internal static ITorchTensor<T> ToTorchTensor<T>(this IntPtr rawTensor)
1869+
internal static ITorchTensor<T> ToTorchTensor<T>(this IntPtr rawTensor, bool shouldClean = true)
18101870
{
18111871
switch (true)
18121872
{
18131873
case bool _ when typeof(T) == typeof(byte):
18141874
{
1815-
return (ITorchTensor<T>)new ByteTensor(new ByteTensor.HType(rawTensor, true));
1875+
return (ITorchTensor<T>)new ByteTensor(new ByteTensor.HType(rawTensor, true), shouldClean);
18161876
}
18171877
case bool _ when typeof(T) == typeof(short):
18181878
{
1819-
return (ITorchTensor<T>)new ShortTensor(new ShortTensor.HType(rawTensor, true));
1879+
return (ITorchTensor<T>)new ShortTensor(new ShortTensor.HType(rawTensor, true), shouldClean);
18201880
}
18211881
case bool _ when typeof(T) == typeof(int):
18221882
{
1823-
return (ITorchTensor<T>)new IntTensor(new IntTensor.HType(rawTensor, true));
1883+
return (ITorchTensor<T>)new IntTensor(new IntTensor.HType(rawTensor, true), shouldClean);
18241884
}
18251885
case bool _ when typeof(T) == typeof(long):
18261886
{
1827-
return (ITorchTensor<T>)new LongTensor(new LongTensor.HType(rawTensor, true));
1887+
return (ITorchTensor<T>)new LongTensor(new LongTensor.HType(rawTensor, true), shouldClean);
18281888
}
18291889
case bool _ when typeof(T) == typeof(double):
18301890
{
1831-
return (ITorchTensor<T>)new DoubleTensor(new DoubleTensor.HType(rawTensor, true));
1891+
return (ITorchTensor<T>)new DoubleTensor(new DoubleTensor.HType(rawTensor, true), shouldClean);
18321892
}
18331893
case bool _ when typeof(T) == typeof(float):
18341894
{
1835-
return (ITorchTensor<T>)new FloatTensor(new FloatTensor.HType(rawTensor, true));
1895+
return (ITorchTensor<T>)new FloatTensor(new FloatTensor.HType(rawTensor, true), shouldClean);
18361896
}
18371897
default: throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
18381898
}

0 commit comments

Comments
 (0)