Skip to content

Commit 1720934

Browse files
committed
* Added missing dipose methods.
* Fixed MNIST example
1 parent 05b4384 commit 1720934

17 files changed

+277
-506
lines changed

Examples/MNIST.cs

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,40 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Diagnostics;
34
using TorchSharp.Tensor;
45

56
namespace TorchSharp.Examples
67
{
78
public class MNIST
89
{
910
private readonly static int _epochs = 10;
10-
private readonly static long _batch = 64;
11-
private readonly static string _trainDataset = @"E:/Source/Repos/LibTorchSharp/MNIST";
11+
private readonly static long _trainBatchSize = 64;
12+
private readonly static long _testBatchSize = 1000;
13+
private readonly static string _dataLocation = @"E:/Source/Repos/LibTorchSharp/MNIST";
14+
15+
private readonly static int _logInterval = 10;
1216

1317
static void Main(string[] args)
1418
{
15-
using (var train = Data.Loader.MNIST(_trainDataset, _batch))
16-
using (var test = Data.Loader.MNIST(_trainDataset, _batch, false))
19+
Torch.SetSeed(1);
20+
21+
using (var train = Data.Loader.MNIST(_dataLocation, _trainBatchSize))
22+
using (var test = Data.Loader.MNIST(_dataLocation, _testBatchSize, false))
1723
using (var model = new Model())
1824
using (var optimizer = NN.Optimizer.SGD(model.Parameters(), 0.01, 0.5))
1925
{
26+
Stopwatch sw = new Stopwatch();
27+
sw.Start();
28+
2029
for (var epoch = 1; epoch <= _epochs; epoch++)
2130
{
22-
Train(model, optimizer, train, epoch, _batch, train.Size());
31+
Train(model, optimizer, train, epoch, _trainBatchSize, train.Size());
2332
Test(model, test, test.Size());
2433
}
34+
35+
sw.Stop();
36+
Console.WriteLine($"Elapsed time {sw.ElapsedMilliseconds}.");
37+
Console.ReadLine();
2538
}
2639
}
2740

@@ -85,7 +98,10 @@ private static void Train(
8598

8699
optimizer.Step();
87100

88-
Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {loss.Item}");
101+
if (batchId % _logInterval == 0)
102+
{
103+
Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {loss.Item}");
104+
}
89105

90106
batchId++;
91107

@@ -116,16 +132,14 @@ private static void Test(
116132

117133
correct += pred.Eq(target).Sum().Item; // Memory leak here
118134

119-
testLoss /= size;
120-
121135
data.Dispose();
122136
target.Dispose();
123137
pred.Dispose();
124138
}
125139

126140
}
127141

128-
Console.WriteLine($"\rTest set: Average loss {testLoss} | Accuracy {(double)correct / size}");
142+
Console.WriteLine($"\rTest set: Average loss {testLoss / size} | Accuracy {(double)correct / size}");
129143
}
130144
}
131145
}

Test/Test.csproj

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
<Project Sdk="Microsoft.NET.Sdk">
22

33
<PropertyGroup>
4-
<TargetFramework>netcoreapp2.0</TargetFramework>
4+
<TargetFramework>netcoreapp2.1</TargetFramework>
55

66
<IsPackable>false</IsPackable>
77
</PropertyGroup>
@@ -10,10 +10,14 @@
1010
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
1111
</PropertyGroup>
1212

13+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
14+
<PlatformTarget>x64</PlatformTarget>
15+
</PropertyGroup>
16+
1317
<ItemGroup>
14-
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.8.0" />
15-
<PackageReference Include="MSTest.TestAdapter" Version="1.3.2" />
16-
<PackageReference Include="MSTest.TestFramework" Version="1.3.2" />
18+
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.0.1" />
19+
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" />
20+
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" />
1721
</ItemGroup>
1822

1923
<ItemGroup>

Test/TorchSharp.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using System;
3-
using System.Linq;
43
using TorchSharp.Tensor;
54

65
namespace TorchSharp.Test

0 commit comments

Comments
 (0)