Skip to content

Commit d641c36

Browse files
committed
Done with MNIST. Added example project. Now need to make it work!
1 parent 91f9bc2 commit d641c36

File tree

14 files changed

+304
-47
lines changed

14 files changed

+304
-47
lines changed

Examples/Examples.csproj

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<OutputType>Exe</OutputType>
5+
<TargetFramework>netcoreapp2.1</TargetFramework>
6+
</PropertyGroup>
7+
8+
<ItemGroup>
9+
<ProjectReference Include="..\TorchSharp\TorchSharp.csproj" />
10+
</ItemGroup>
11+
12+
</Project>

Examples/MNIST.cs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using TorchSharp.Tensor;
4+
5+
namespace TorchSharp.Examples
6+
{
7+
public class MNIST
8+
{
9+
static void Main(string[] args)
10+
{
11+
var train = Data.Loader.MNIST(@"E:/Source/Repos/LibTorchSharp/MNIST", 64, out int size);
12+
13+
var model = new Model();
14+
15+
var optimizer = NN.Optimizer.SGD(model.Parameters(), 0.01, 0.5);
16+
17+
for (var epoch = 1; epoch <= 10; epoch++)
18+
{
19+
Train(model, optimizer, train, epoch, size);
20+
}
21+
}
22+
23+
private class Model : NN.Module
24+
{
25+
private NN.Module conv1 = NN.Module.Conv2D(1, 10, 5);
26+
private NN.Module conv2 = NN.Module.Conv2D(10, 20, 5);
27+
private NN.Module fc1 = NN.Module.Linear(320, 50);
28+
private NN.Module fc2 = NN.Module.Linear(50, 10);
29+
30+
public Model() : base(IntPtr.Zero)
31+
{
32+
RegisterModule(conv1);
33+
RegisterModule(conv2);
34+
RegisterModule(fc1);
35+
RegisterModule(fc2);
36+
}
37+
38+
public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
39+
{
40+
var x = conv1.Forward(tensor);
41+
x = NN.Module.MaxPool2D(x, 2);
42+
x = NN.Module.Relu(x);
43+
44+
x = conv2.Forward(x);
45+
x = NN.Module.FeatureDropout(x);
46+
x = NN.Module.MaxPool2D(x, 2);
47+
48+
x = x.View(new long[] { -1, 320 });
49+
50+
x = fc1.Forward(x);
51+
x = NN.Module.Relu(x);
52+
x = NN.Module.Dropout(x, 0.5, _isTraining);
53+
54+
x = fc2.Forward(x);
55+
56+
return NN.Module.LogSoftMax(x, 1);
57+
}
58+
}
59+
60+
private static void Train(NN.Module model, NN.Optimizer optimizer, IEnumerable<(ITorchTensor<float>, ITorchTensor<float>)> dataLoader, int epoch, int size)
61+
{
62+
model.Train();
63+
64+
int batchId = 0;
65+
66+
foreach (var (data, target) in dataLoader)
67+
{
68+
optimizer.ZeroGrad();
69+
70+
var output = model.Forward(data);
71+
72+
var loss = NN.LossFunction.NLL(output, target);
73+
74+
loss.Backward();
75+
76+
optimizer.Step();
77+
78+
batchId++;
79+
80+
Console.WriteLine($"\rTrain Epoch: {epoch} [{batchId} / {size}] Loss: {loss.Item}");
81+
}
82+
}
83+
}
84+
}

Test/TorchSharp.cs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
3+
using System.Collections.Generic;
24
using System.Linq;
35
using TorchSharp.Tensor;
46

@@ -321,12 +323,14 @@ public void TestTrainingAdam()
321323
[TestMethod]
322324
public void TestMNISTLoader()
323325
{
324-
var train = Data.Loader.MNIST(@"E:/Source/Repos/LibTorchSharp/MNIST", 64);
326+
var train = Data.Loader.MNIST(@"E:/Source/Repos/LibTorchSharp/MNIST", 32, out int size);
327+
int i = 0;
325328

326-
foreach (var (data, target) in train.Take(10))
329+
foreach (var (data, target) in train.SkipLast(2))
327330
{
328-
CollectionAssert.AreEqual(data.Shape, new long[] { 64, 1, 28, 28 });
329-
CollectionAssert.AreEqual(target.Shape, new long[] { 64 });
331+
CollectionAssert.AreEqual(data.Shape, new long[] { 32, 1, 28, 28 });
332+
CollectionAssert.AreEqual(target.Shape, new long[] { 32 });
333+
i++;
330334
}
331335
}
332336
}

TorchSharp/Data/Loader.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ extern static void Data_LoaderMNIST(
1616
AllocatePinnedArray dataAllocator,
1717
AllocatePinnedArray targetAllocator);
1818

19-
static public IEnumerable<(ITorchTensor<float> data, ITorchTensor<float> target)> MNIST(string filename, long batchSize, bool isTrain = true)
19+
static public IEnumerable<(ITorchTensor<float> data, ITorchTensor<float> target)> MNIST(string filename, long batchSize, out int size, bool isTrain = true)
2020
{
2121
IntPtr[] dataPtrArray;
2222
IntPtr[] targetPtrArray;
@@ -27,6 +27,7 @@ extern static void Data_LoaderMNIST(
2727
Data_LoaderMNIST(filename, batchSize, isTrain, data.CreateArray, target.CreateArray);
2828
dataPtrArray = data.Array;
2929
targetPtrArray = target.Array;
30+
size = data.Array.Length;
3031
}
3132

3233
return dataPtrArray

TorchSharp/Generated/TorchTensor.generated.cs

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,20 @@ public ITorchTensor<float> Grad()
252252
return new FloatTensor(THS_Grad(handle));
253253
}
254254

255+
[DllImport("LibTorchSharp")]
256+
extern static HType THS_View(HType src, IntPtr shape, int length);
257+
258+
public ITorchTensor<byte> View(params long[] shape)
259+
{
260+
unsafe
261+
{
262+
fixed (long* pshape = shape)
263+
{
264+
return new ByteTensor (THS_View (handle, (IntPtr)pshape, shape.Length));
265+
}
266+
}
267+
}
268+
255269
[DllImport("LibTorchSharp")]
256270
extern static HType THS_Sub_(HType src, IntPtr trg, bool is_grad);
257271

@@ -536,6 +550,20 @@ public ITorchTensor<float> Grad()
536550
return new FloatTensor(THS_Grad(handle));
537551
}
538552

553+
[DllImport("LibTorchSharp")]
554+
extern static HType THS_View(HType src, IntPtr shape, int length);
555+
556+
public ITorchTensor<short> View(params long[] shape)
557+
{
558+
unsafe
559+
{
560+
fixed (long* pshape = shape)
561+
{
562+
return new ShortTensor (THS_View (handle, (IntPtr)pshape, shape.Length));
563+
}
564+
}
565+
}
566+
539567
[DllImport("LibTorchSharp")]
540568
extern static HType THS_Sub_(HType src, IntPtr trg, bool is_grad);
541569

@@ -820,6 +848,20 @@ public ITorchTensor<float> Grad()
820848
return new FloatTensor(THS_Grad(handle));
821849
}
822850

851+
[DllImport("LibTorchSharp")]
852+
extern static HType THS_View(HType src, IntPtr shape, int length);
853+
854+
public ITorchTensor<int> View(params long[] shape)
855+
{
856+
unsafe
857+
{
858+
fixed (long* pshape = shape)
859+
{
860+
return new IntTensor (THS_View (handle, (IntPtr)pshape, shape.Length));
861+
}
862+
}
863+
}
864+
823865
[DllImport("LibTorchSharp")]
824866
extern static HType THS_Sub_(HType src, IntPtr trg, bool is_grad);
825867

@@ -1104,6 +1146,20 @@ public ITorchTensor<float> Grad()
11041146
return new FloatTensor(THS_Grad(handle));
11051147
}
11061148

1149+
[DllImport("LibTorchSharp")]
1150+
extern static HType THS_View(HType src, IntPtr shape, int length);
1151+
1152+
public ITorchTensor<long> View(params long[] shape)
1153+
{
1154+
unsafe
1155+
{
1156+
fixed (long* pshape = shape)
1157+
{
1158+
return new LongTensor (THS_View (handle, (IntPtr)pshape, shape.Length));
1159+
}
1160+
}
1161+
}
1162+
11071163
[DllImport("LibTorchSharp")]
11081164
extern static HType THS_Sub_(HType src, IntPtr trg, bool is_grad);
11091165

@@ -1388,6 +1444,20 @@ public ITorchTensor<float> Grad()
13881444
return new FloatTensor(THS_Grad(handle));
13891445
}
13901446

1447+
[DllImport("LibTorchSharp")]
1448+
extern static HType THS_View(HType src, IntPtr shape, int length);
1449+
1450+
public ITorchTensor<double> View(params long[] shape)
1451+
{
1452+
unsafe
1453+
{
1454+
fixed (long* pshape = shape)
1455+
{
1456+
return new DoubleTensor (THS_View (handle, (IntPtr)pshape, shape.Length));
1457+
}
1458+
}
1459+
}
1460+
13911461
[DllImport("LibTorchSharp")]
13921462
extern static HType THS_Sub_(HType src, IntPtr trg, bool is_grad);
13931463

@@ -1435,6 +1505,9 @@ public class FloatTensor : ITorchTensor<float>
14351505
[DllImport("LibTorchSharp")]
14361506
extern static AtenSharp.FloatTensor.HType THS_getTHTensorUnsafe(HType handle);
14371507

1508+
[DllImport("LibTorchSharp")]
1509+
extern static void THS_Delete(HType handle);
1510+
14381511
internal sealed class HType : SafeHandle
14391512
{
14401513
public HType(IntPtr preexistingHandle, bool ownsHandle) : base(IntPtr.Zero, ownsHandle)
@@ -1451,8 +1524,9 @@ internal HType() : base(IntPtr.Zero, true)
14511524

14521525
protected override bool ReleaseHandle()
14531526
{
1454-
var atenTensor = new AtenSharp.FloatTensor(THS_getTHTensorUnsafe(this));
1455-
atenTensor.Dispose();
1527+
//var atenTensor = new AtenSharp.FloatTensor(THS_getTHTensorUnsafe(this));
1528+
//atenTensor.Dispose();
1529+
THS_Delete(this);
14561530
return true;
14571531
}
14581532

@@ -1672,6 +1746,20 @@ public ITorchTensor<float> Grad()
16721746
return new FloatTensor(THS_Grad(handle));
16731747
}
16741748

1749+
[DllImport("LibTorchSharp")]
1750+
extern static HType THS_View(HType src, IntPtr shape, int length);
1751+
1752+
public ITorchTensor<float> View(params long[] shape)
1753+
{
1754+
unsafe
1755+
{
1756+
fixed (long* pshape = shape)
1757+
{
1758+
return new FloatTensor (THS_View (handle, (IntPtr)pshape, shape.Length));
1759+
}
1760+
}
1761+
}
1762+
16751763
[DllImport("LibTorchSharp")]
16761764
extern static HType THS_Sub_(HType src, IntPtr trg, bool is_grad);
16771765

TorchSharp/Generated/TorchTensor.tt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,20 @@ foreach (var type in TorchTypeDef.Types) {
259259
return new FloatTensor(THS_Grad(handle));
260260
}
261261

262+
[DllImport("LibTorchSharp")]
263+
extern static HType THS_View(HType src, IntPtr shape, int length);
264+
265+
public ITorchTensor<<#=type.Storage#>> View(params long[] shape)
266+
{
267+
unsafe
268+
{
269+
fixed (long* pshape = shape)
270+
{
271+
return new <#=type.Name#>Tensor (THS_View (handle, (<#=type.Ptr#>)pshape, shape.Length));
272+
}
273+
}
274+
}
275+
262276
[DllImport("LibTorchSharp")]
263277
extern static HType THS_Sub_(HType src, IntPtr trg, bool is_grad);
264278

TorchSharp/NN/Dropout.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,19 @@ namespace TorchSharp.NN
1010
public class Dropout : FunctionalModule
1111
{
1212
private double _probability;
13-
private Func<bool> _isTraining;
1413

15-
internal Dropout(double probability, Func<bool> isTraining) : base()
14+
internal Dropout(double probability, bool isTraining) : base()
1615
{
1716
_probability = probability;
1817
_isTraining = isTraining;
1918
}
2019

2120
[DllImport("LibTorchSharp")]
22-
extern static FloatTensor.HType NN_LogSoftMaxModule_Forward(IntPtr tensor, double probability, bool isTraining);
21+
extern static FloatTensor.HType NN_DropoutModule_Forward(IntPtr tensor, double probability, bool isTraining);
2322

2423
public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
2524
{
26-
return new FloatTensor(NN_LogSoftMaxModule_Forward(tensor.Handle, _probability, _isTraining.Invoke()));
25+
return new FloatTensor(NN_DropoutModule_Forward(tensor.Handle, _probability, _isTraining));
2726
}
2827
}
2928
}

TorchSharp/NN/FeatureDropout.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using System;
2+
using System.Runtime.InteropServices;
3+
using TorchSharp.Tensor;
4+
5+
namespace TorchSharp.NN
6+
{
7+
/// <summary>
8+
/// This class is used to represent a dropout module for 2d/3d convolutational layers.
9+
/// </summary>
10+
public class FeatureDropout : FunctionalModule
11+
{
12+
internal FeatureDropout() : base()
13+
{
14+
}
15+
16+
[DllImport("LibTorchSharp")]
17+
extern static FloatTensor.HType NN_FeatureDropout_Forward(IntPtr tensor);
18+
19+
public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
20+
{
21+
return new FloatTensor(NN_FeatureDropout_Forward(tensor.Handle));
22+
}
23+
}
24+
}

TorchSharp/NN/FunctionalModule.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,20 @@ internal FunctionalModule() : base(IntPtr.Zero)
1414
{
1515
}
1616

17-
public override void ZeroGrad()
17+
public override void RegisterModule(Module module)
1818
{
1919
}
2020

21-
public override bool IsTraining()
21+
public override void ZeroGrad()
2222
{
23-
return true;
2423
}
2524

2625
public override IEnumerable<ITorchTensor<float>> Parameters()
2726
{
2827
return new List<ITorchTensor<float>>();
2928
}
3029

31-
public override string[] GetModules()
30+
public override IEnumerable<string> GetModules()
3231
{
3332
return new string[0];
3433
}

TorchSharp/NN/LossFunction.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@ public static ITorchTensor<float> MSE<T>(ITorchTensor<T> src, ITorchTensor<T> ta
1616
{
1717
return new FloatTensor(NN_LossMSE(src.Handle, target.Handle, (long)reduction));
1818
}
19+
20+
[DllImport("LibTorchSharp")]
21+
extern static FloatTensor.HType NN_LossNLL(IntPtr srct, IntPtr trgt);
22+
23+
public static ITorchTensor<float> NLL<T>(ITorchTensor<T> src, ITorchTensor<T> target)
24+
{
25+
return new FloatTensor(NN_LossNLL(src.Handle, target.Handle));
26+
}
1927
}
2028

2129
public enum Reduction : long

0 commit comments

Comments
 (0)