Skip to content

Commit d35ff70

Browse files
committed
Fixed:
* naming with functional modules * iterator instead of memory copy. Still will need to figure out how to reset the iterator to implement epochs.
1 parent d641c36 commit d35ff70

17 files changed

+420
-97
lines changed

Examples/MNIST.cs

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,25 @@ namespace TorchSharp.Examples
66
{
77
public class MNIST
88
{
9+
private readonly static long _batch = 64;
10+
private readonly static string _trainDataset = @"E:/Source/Repos/LibTorchSharp/MNIST";
11+
912
static void Main(string[] args)
1013
{
11-
var train = Data.Loader.MNIST(@"E:/Source/Repos/LibTorchSharp/MNIST", 64, out int size);
12-
13-
var model = new Model();
14-
15-
var optimizer = NN.Optimizer.SGD(model.Parameters(), 0.01, 0.5);
16-
17-
for (var epoch = 1; epoch <= 10; epoch++)
14+
using (var train = Data.Loader.MNIST(_trainDataset, _batch))
15+
using (var model = new Model())
16+
using (var optimizer = NN.Optimizer.SGD(model.Parameters(), 0.01, 0.5))
1817
{
19-
Train(model, optimizer, train, epoch, size);
18+
Train(model, optimizer, train, _batch, train.Size());
2019
}
2120
}
2221

2322
private class Model : NN.Module
2423
{
25-
private NN.Module conv1 = NN.Module.Conv2D(1, 10, 5);
26-
private NN.Module conv2 = NN.Module.Conv2D(10, 20, 5);
27-
private NN.Module fc1 = NN.Module.Linear(320, 50);
28-
private NN.Module fc2 = NN.Module.Linear(50, 10);
24+
private NN.Module conv1 = Conv2D(1, 10, 5);
25+
private NN.Module conv2 = Conv2D(10, 20, 5);
26+
private NN.Module fc1 = Linear(320, 50);
27+
private NN.Module fc2 = Linear(50, 10);
2928

3029
public Model() : base(IntPtr.Zero)
3130
{
@@ -37,47 +36,57 @@ public Model() : base(IntPtr.Zero)
3736

3837
public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
3938
{
40-
var x = conv1.Forward(tensor);
41-
x = NN.Module.MaxPool2D(x, 2);
42-
x = NN.Module.Relu(x);
39+
using (var l11 = conv1.Forward(tensor))
40+
using (var l12 = MaxPool2D(l11, 2))
41+
using (var l13 = Relu(l12))
4342

44-
x = conv2.Forward(x);
45-
x = NN.Module.FeatureDropout(x);
46-
x = NN.Module.MaxPool2D(x, 2);
43+
using (var l21 = conv2.Forward(l13))
44+
using (var l22 = FeatureDropout(l21))
45+
using (var l23 = MaxPool2D(l22, 2))
4746

48-
x = x.View(new long[] { -1, 320 });
47+
using (var x = l23.View(new long[] { -1, 320 }))
4948

50-
x = fc1.Forward(x);
51-
x = NN.Module.Relu(x);
52-
x = NN.Module.Dropout(x, 0.5, _isTraining);
49+
using (var l31 = fc1.Forward(x))
50+
using (var l32 = Relu(l31))
51+
using (var l33 = Dropout(l32, 0.5, _isTraining))
5352

54-
x = fc2.Forward(x);
53+
using (var l41 = fc2.Forward(l33))
5554

56-
return NN.Module.LogSoftMax(x, 1);
55+
return LogSoftMax(l41, 1);
5756
}
5857
}
5958

60-
private static void Train(NN.Module model, NN.Optimizer optimizer, IEnumerable<(ITorchTensor<float>, ITorchTensor<float>)> dataLoader, int epoch, int size)
59+
private static void Train(
60+
NN.Module model,
61+
NN.Optimizer optimizer,
62+
IEnumerable<(ITorchTensor<int>, ITorchTensor<int>)> dataLoader,
63+
long batchSize,
64+
long size)
6165
{
6266
model.Train();
6367

64-
int batchId = 0;
68+
int batchId = 1;
6569

6670
foreach (var (data, target) in dataLoader)
6771
{
6872
optimizer.ZeroGrad();
6973

70-
var output = model.Forward(data);
71-
72-
var loss = NN.LossFunction.NLL(output, target);
74+
if (batchId == 937)
75+
{
76+
Console.WriteLine();
77+
}
7378

74-
loss.Backward();
79+
using (var output = model.Forward(data))
80+
using (var loss = NN.LossFunction.NLL(output, target))
81+
{
82+
loss.Backward();
7583

76-
optimizer.Step();
84+
optimizer.Step();
7785

78-
batchId++;
86+
batchId++;
7987

80-
Console.WriteLine($"\rTrain Epoch: {epoch} [{batchId} / {size}] Loss: {loss.Item}");
88+
Console.WriteLine($"\rTrain: [{batchId * batchSize} / {size}] Loss: {loss.Item}");
89+
}
8190
}
8291
}
8392
}

Test/TorchSharp.cs

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using System;
3-
using System.Collections.Generic;
43
using System.Linq;
54
using TorchSharp.Tensor;
65

@@ -16,6 +15,34 @@ public void CreateFloatTensorOnes()
1615
Assert.IsNotNull(ones);
1716
}
1817

18+
[TestMethod]
19+
public void CreateFloatTensorCheckDistructor()
20+
{
21+
ITorchTensor<float> ones = null;
22+
23+
using (var tmp = FloatTensor.Ones(new long[] { 2, 2 }))
24+
{
25+
ones = tmp;
26+
Assert.IsNotNull(ones);
27+
}
28+
Assert.ThrowsException<ObjectDisposedException>(ones.Grad);
29+
}
30+
31+
[TestMethod]
32+
public void CreateFloatTensorCheckMemory()
33+
{
34+
ITorchTensor<float> ones = null;
35+
36+
for (int i = 0; i < 10; i++)
37+
{
38+
using (var tmp = FloatTensor.Ones(new long[] { 1000, 1000, 1000 }))
39+
{
40+
ones = tmp;
41+
Assert.IsNotNull(ones);
42+
}
43+
}
44+
}
45+
1946
[TestMethod]
2047
public void CreateFloatTensorOnesCheckData()
2148
{
@@ -323,14 +350,27 @@ public void TestTrainingAdam()
323350
[TestMethod]
324351
public void TestMNISTLoader()
325352
{
326-
var train = Data.Loader.MNIST(@"E:/Source/Repos/LibTorchSharp/MNIST", 32, out int size);
327-
int i = 0;
328-
329-
foreach (var (data, target) in train.SkipLast(2))
353+
using (var train = Data.Loader.MNIST(@"E:/Source/Repos/LibTorchSharp/MNIST", 32))
330354
{
331-
CollectionAssert.AreEqual(data.Shape, new long[] { 32, 1, 28, 28 });
332-
CollectionAssert.AreEqual(target.Shape, new long[] { 32 });
333-
i++;
355+
var size = train.Size();
356+
357+
Assert.IsNotNull(train);
358+
Assert.IsNotNull(size);
359+
360+
int i = 0;
361+
362+
foreach (var (data, target) in train)
363+
{
364+
i++;
365+
366+
CollectionAssert.AreEqual(data.Shape, new long[] { 32, 1, 28, 28 });
367+
CollectionAssert.AreEqual(target.Shape, new long[] { 32 });
368+
369+
data.Dispose();
370+
target.Dispose();
371+
}
372+
373+
Assert.AreEqual(size, i * 32);
334374
}
335375
}
336376
}

TorchSharp.sln

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Test", "Test\Test.csproj",
1010
EndProject
1111
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AtenSharp", "AtenSharp\AtenSharp.csproj", "{42F79D9F-7122-47CC-B1FA-FDF849940824}"
1212
EndProject
13+
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Examples", "Examples\Examples.csproj", "{FA1CC1D4-6291-4D80-85F8-0A25F4BD1AE9}"
14+
EndProject
1315
Global
1416
GlobalSection(SolutionConfigurationPlatforms) = preSolution
1517
Debug|Any CPU = Debug|Any CPU
@@ -68,6 +70,18 @@ Global
6870
{42F79D9F-7122-47CC-B1FA-FDF849940824}.Release|x64.Build.0 = Release|Any CPU
6971
{42F79D9F-7122-47CC-B1FA-FDF849940824}.Release|x86.ActiveCfg = Release|Any CPU
7072
{42F79D9F-7122-47CC-B1FA-FDF849940824}.Release|x86.Build.0 = Release|Any CPU
73+
{FA1CC1D4-6291-4D80-85F8-0A25F4BD1AE9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
74+
{FA1CC1D4-6291-4D80-85F8-0A25F4BD1AE9}.Debug|Any CPU.Build.0 = Debug|Any CPU
75+
{FA1CC1D4-6291-4D80-85F8-0A25F4BD1AE9}.Debug|x64.ActiveCfg = Debug|Any CPU
76+
{FA1CC1D4-6291-4D80-85F8-0A25F4BD1AE9}.Debug|x64.Build.0 = Debug|Any CPU
77+
{FA1CC1D4-6291-4D80-85F8-0A25F4BD1AE9}.Debug|x86.ActiveCfg = Debug|Any CPU
78+
{FA1CC1D4-6291-4D80-85F8-0A25F4BD1AE9}.Debug|x86.Build.0 = Debug|Any CPU
79+
{FA1CC1D4-6291-4D80-85F8-0A25F4BD1AE9}.Release|Any CPU.ActiveCfg = Release|Any CPU
80+
{FA1CC1D4-6291-4D80-85F8-0A25F4BD1AE9}.Release|Any CPU.Build.0 = Release|Any CPU
81+
{FA1CC1D4-6291-4D80-85F8-0A25F4BD1AE9}.Release|x64.ActiveCfg = Release|Any CPU
82+
{FA1CC1D4-6291-4D80-85F8-0A25F4BD1AE9}.Release|x64.Build.0 = Release|Any CPU
83+
{FA1CC1D4-6291-4D80-85F8-0A25F4BD1AE9}.Release|x86.ActiveCfg = Release|Any CPU
84+
{FA1CC1D4-6291-4D80-85F8-0A25F4BD1AE9}.Release|x86.Build.0 = Release|Any CPU
7185
EndGlobalSection
7286
GlobalSection(SolutionProperties) = preSolution
7387
HideSolutionNode = FALSE

0 commit comments

Comments
 (0)