Skip to content

Commit e3d39c2

Browse files
committed
Added epochs and dispose to mnist example
1 parent e0466b4 commit e3d39c2

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

Examples/MNIST.cs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ namespace TorchSharp.Examples
66
{
77
public class MNIST
88
{
9+
private readonly static int _epochs = 10;
910
private readonly static long _batch = 64;
1011
private readonly static string _trainDataset = @"E:/Source/Repos/LibTorchSharp/MNIST";
1112

@@ -15,7 +16,10 @@ static void Main(string[] args)
1516
using (var model = new Model())
1617
using (var optimizer = NN.Optimizer.SGD(model.Parameters(), 0.01, 0.5))
1718
{
18-
Train(model, optimizer, train, _batch, train.Size());
19+
for (var epoch = 1; epoch <= _epochs; epoch++)
20+
{
21+
Train(model, optimizer, train, epoch, _batch, train.Size());
22+
}
1923
}
2024
}
2125

@@ -59,7 +63,8 @@ public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
5963
private static void Train(
6064
NN.Module model,
6165
NN.Optimizer optimizer,
62-
IEnumerable<(ITorchTensor<int>, ITorchTensor<int>)> dataLoader,
66+
IEnumerable<(ITorchTensor<int>, ITorchTensor<int>)> dataLoader,
67+
int epoch,
6368
long batchSize,
6469
long size)
6570
{
@@ -71,11 +76,6 @@ private static void Train(
7176
{
7277
optimizer.ZeroGrad();
7378

74-
if (batchId == 937)
75-
{
76-
Console.WriteLine();
77-
}
78-
7979
using (var output = model.Forward(data))
8080
using (var loss = NN.LossFunction.NLL(output, target))
8181
{
@@ -85,7 +85,10 @@ private static void Train(
8585

8686
batchId++;
8787

88-
Console.WriteLine($"\rTrain: [{batchId * batchSize} / {size}] Loss: {loss.Item}");
88+
Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {loss.Item}");
89+
90+
data.Dispose();
91+
target.Dispose();
8992
}
9093
}
9194
}

0 commit comments

Comments
 (0)