@@ -1,27 +1,40 @@
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using TorchSharp.Tensor;

 namespace TorchSharp.Examples
 {
     public class MNIST
     {
         private readonly static int _epochs = 10;
-        private readonly static long _batch = 64;
-        private readonly static string _trainDataset = @"E:/Source/Repos/LibTorchSharp/MNIST";
+        private readonly static long _trainBatchSize = 64;
+        private readonly static long _testBatchSize = 1000;
+        private readonly static string _dataLocation = @"E:/Source/Repos/LibTorchSharp/MNIST";
+
+        private readonly static int _logInterval = 10;

         static void Main(string[] args)
         {
-            using (var train = Data.Loader.MNIST(_trainDataset, _batch))
-            using (var test = Data.Loader.MNIST(_trainDataset, _batch, false))
+            Torch.SetSeed(1);
+
+            using (var train = Data.Loader.MNIST(_dataLocation, _trainBatchSize))
+            using (var test = Data.Loader.MNIST(_dataLocation, _testBatchSize, false))
             using (var model = new Model())
             using (var optimizer = NN.Optimizer.SGD(model.Parameters(), 0.01, 0.5))
             {
+                Stopwatch sw = new Stopwatch();
+                sw.Start();
+
                 for (var epoch = 1; epoch <= _epochs; epoch++)
                 {
-                    Train(model, optimizer, train, epoch, _batch, train.Size());
+                    Train(model, optimizer, train, epoch, _trainBatchSize, train.Size());
                     Test(model, test, test.Size());
                 }
+
+                sw.Stop();
+                Console.WriteLine($"Elapsed time {sw.ElapsedMilliseconds}.");
+                Console.ReadLine();
             }
         }

@@ -85,7 +98,10 @@ private static void Train(

                     optimizer.Step();

-                    Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {loss.Item}");
+                    if (batchId % _logInterval == 0)
+                    {
+                        Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {loss.Item}");
+                    }

                     batchId++;

@@ -116,16 +132,14 @@ private static void Test(

                     correct += pred.Eq(target).Sum().Item; // Memory leak here

-                    testLoss /= size;
-
                     data.Dispose();
                     target.Dispose();
                     pred.Dispose();
                 }

             }

-            Console.WriteLine($"\rTest set: Average loss {testLoss} | Accuracy {(double)correct / size}");
+            Console.WriteLine($"\rTest set: Average loss {testLoss / size} | Accuracy {(double)correct / size}");
         }
     }
 }
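
The "// Memory leak here" note above points at the temporary tensors produced by Eq() and Sum(), which are never disposed (only data, target, and pred are). A minimal sketch of one way to release them, assuming those calls return disposable tensor handles as the existing Dispose() calls suggest; this fragment is not part of the patch and would replace the accumulation line inside the test loop:

    // Hypothetical rewrite of the accumulation step: dispose the intermediates from
    // Eq() and Sum() instead of leaking them. pred, target, and correct are the
    // variables already in scope in the Test loop shown above.
    using (var eq = pred.Eq(target))
    using (var sum = eq.Sum())
    {
        correct += sum.Item;
    }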