Skip to content

Commit 01f0b7b

Browse files
committed
Added MNIST loader.
1 parent 79fde58 commit 01f0b7b

File tree

3 files changed

+56
-3
lines changed

3 files changed

+56
-3
lines changed

Test/TorchSharp.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System.Linq;
23
using TorchSharp.Tensor;
34

45
namespace TorchSharp.Test
@@ -316,5 +317,17 @@ public void TestTrainingAdam()
316317
optimizer.Step();
317318
}
318319
}
320+
321+
[TestMethod]
322+
public void TestMNISTLoader()
323+
{
324+
var train = Data.Loader.MNIST(@"E:/Source/Repos/LibTorchSharp/MNIST", 64);
325+
326+
foreach (var (data, target) in train.Take(10))
327+
{
328+
CollectionAssert.AreEqual(data.Shape, new long[] { 64, 1, 28, 28 });
329+
CollectionAssert.AreEqual(target.Shape, new long[] { 64 });
330+
}
331+
}
319332
}
320333
}

TorchSharp/Data/Loader.cs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Runtime.InteropServices;
5+
using TorchSharp.Tensor;
6+
7+
namespace TorchSharp.Data
8+
{
9+
public class Loader
10+
{
11+
[DllImport("LibTorchSharp")]
12+
extern static void Data_LoaderMNIST(
13+
string filename,
14+
long batchSize,
15+
bool isTrain,
16+
AllocatePinnedArray dataAllocator,
17+
AllocatePinnedArray targetAllocator);
18+
19+
static public IEnumerable<(ITorchTensor<float> data, ITorchTensor<float> target)> MNIST(string filename, long batchSize, bool isTrain = true)
20+
{
21+
IntPtr[] dataPtrArray;
22+
IntPtr[] targetPtrArray;
23+
24+
using (var data = new PinnedArray<IntPtr>())
25+
using (var target = new PinnedArray<IntPtr>())
26+
{
27+
Data_LoaderMNIST(filename, batchSize, isTrain, data.CreateArray, target.CreateArray);
28+
dataPtrArray = data.Array;
29+
targetPtrArray = target.Array;
30+
}
31+
32+
return dataPtrArray
33+
.Zip(
34+
targetPtrArray,
35+
(d, t) => (
36+
(ITorchTensor<float>)new FloatTensor(new FloatTensor.HType(d, true)),
37+
(ITorchTensor<float>)new FloatTensor(new FloatTensor.HType(t, true))));
38+
}
39+
}
40+
}

TorchSharp/NN/Module.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,14 @@ public virtual void ZeroGrad()
111111

112112
public virtual IEnumerable<ITorchTensor<float>> Parameters()
113113
{
114-
IntPtr[] ros;
114+
IntPtr[] ptrArray;
115115

116116
using (var pa = new PinnedArray<IntPtr>())
117117
{
118118
NN_GetParameters(handle, pa.CreateArray);
119-
ros = pa.Array;
119+
ptrArray = pa.Array;
120120
}
121-
return ros.Select(x => new FloatTensor(new FloatTensor.HType(x, true)));
121+
return ptrArray.Select(x => new FloatTensor(new FloatTensor.HType(x, true)));
122122
}
123123

124124
[DllImport("LibTorchSharp")]

0 commit comments

Comments
 (0)