Commit 05b4384

Added per epoch test of MNIST model.
1 parent 8558317 commit 05b4384

5 files changed (+229 -20 lines)


Examples/MNIST.cs

Lines changed: 36 additions & 1 deletion
@@ -13,12 +13,14 @@ public class MNIST
         static void Main(string[] args)
         {
             using (var train = Data.Loader.MNIST(_trainDataset, _batch))
+            using (var test = Data.Loader.MNIST(_trainDataset, _batch, false))
             using (var model = new Model())
             using (var optimizer = NN.Optimizer.SGD(model.Parameters(), 0.01, 0.5))
             {
                 for (var epoch = 1; epoch <= _epochs; epoch++)
                 {
                     Train(model, optimizer, train, epoch, _batch, train.Size());
+                    Test(model, test, test.Size());
                 }
             }
         }
@@ -83,14 +85,47 @@ private static void Train(
 
                     optimizer.Step();
 
+                    Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {loss.Item}");
+
                     batchId++;
 
-                    Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {loss.Item}");
+                    data.Dispose();
+                    target.Dispose();
+                }
+            }
+        }
+
+        private static void Test(
+            NN.Module model,
+            IEnumerable<(ITorchTensor<int>, ITorchTensor<int>)> dataLoader,
+            long size)
+        {
+            model.Eval();
+
+            double testLoss = 0;
+            int correct = 0;
+
+            foreach (var (data, target) in dataLoader)
+            {
+                using (var output = model.Forward(data))
+                using (var loss = NN.LossFunction.NLL(output, target, NN.Reduction.Sum))
+                {
+                    testLoss += loss.Item;
+
+                    var pred = output.Argmax(1);
+
+                    correct += pred.Eq(target).Sum().Item; // Memory leak here
+
+                    testLoss /= size;
 
                     data.Dispose();
                     target.Dispose();
+                    pred.Dispose();
                 }
+
             }
+
+            Console.WriteLine($"\rTest set: Average loss {testLoss} | Accuracy {(double)correct / size}");
         }
     }
 }
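
Note: the "// Memory leak here" comment in the new Test method flags that pred.Eq(target) and .Sum() each allocate an intermediate tensor that is never disposed. Below is a minimal leak-free sketch of that accumulation, not part of this commit, assuming both intermediates are disposable ITorchTensor instances like the other tensors here:

    // Hypothetical rewrite of the flagged line; variable names match the Test method above.
    var pred = output.Argmax(1);

    using (var eq = pred.Eq(target))      // intermediate element-wise comparison tensor
    using (var hits = eq.Sum())           // intermediate one-element sum tensor
    {
        correct += hits.Item;             // read the count out before the intermediates are disposed
    }

    pred.Dispose();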

TorchSharp/Generated/TorchTensor.generated.cs

Lines changed: 156 additions & 12 deletions
@@ -268,6 +268,22 @@ public ITorchTensor<byte> View(params long[] shape)
             }
         }
 
+        [DllImport("LibTorchSharp")]
+        extern static HType THS_Sum(HType src);
+
+        public ITorchTensor<byte> Sum()
+        {
+            return new ByteTensor(THS_Sum(handle));
+        }
+
+        [DllImport("LibTorchSharp")]
+        extern static IntPtr THS_Eq(HType src, IntPtr trg);
+
+        public ITorchTensor<U> Eq<U>(ITorchTensor<U> target)
+        {
+            return TensorExtensionMethods.ToTorchTensor<U>(THS_Eq(handle, target.Handle));
+        }
+
         [DllImport("LibTorchSharp")]
         extern static HType THS_Sub_(HType src, IntPtr trg, bool is_grad);
 
@@ -279,9 +295,17 @@ public ITorchTensor<byte> SubInPlace(ITorchTensor<byte> target, bool no_grad = t
         [DllImport("LibTorchSharp")]
         extern static HType THS_Mul(HType src, byte scalar, bool is_grad);
 
-        public ITorchTensor<byte> Mul(byte scalar, bool no_grad = true)
+        public ITorchTensor<byte> Mul(byte scalar, bool noGrad = true)
         {
-            return new ByteTensor(THS_Mul(handle, scalar, !no_grad));
+            return new ByteTensor(THS_Mul(handle, scalar, !noGrad));
+        }
+
+        [DllImport("LibTorchSharp")]
+        extern static HType THS_Argmax(HType src, long dimension, bool keep_dim);
+
+        public ITorchTensor<byte> Argmax(long dimension, bool keepDim = false)
+        {
+            return new ByteTensor(THS_Argmax(handle, dimension, keepDim));
         }
 
         /// <summary>
@@ -568,6 +592,22 @@ public ITorchTensor<short> View(params long[] shape)
             }
         }
 
+        [DllImport("LibTorchSharp")]
+        extern static HType THS_Sum(HType src);
+
+        public ITorchTensor<short> Sum()
+        {
+            return new ShortTensor(THS_Sum(handle));
+        }
+
+        [DllImport("LibTorchSharp")]
+        extern static IntPtr THS_Eq(HType src, IntPtr trg);
+
+        public ITorchTensor<U> Eq<U>(ITorchTensor<U> target)
+        {
+            return TensorExtensionMethods.ToTorchTensor<U>(THS_Eq(handle, target.Handle));
+        }
+
         [DllImport("LibTorchSharp")]
         extern static HType THS_Sub_(HType src, IntPtr trg, bool is_grad);
 
@@ -579,9 +619,17 @@ public ITorchTensor<short> SubInPlace(ITorchTensor<short> target, bool no_grad =
         [DllImport("LibTorchSharp")]
         extern static HType THS_Mul(HType src, short scalar, bool is_grad);
 
-        public ITorchTensor<short> Mul(short scalar, bool no_grad = true)
+        public ITorchTensor<short> Mul(short scalar, bool noGrad = true)
         {
-            return new ShortTensor(THS_Mul(handle, scalar, !no_grad));
+            return new ShortTensor(THS_Mul(handle, scalar, !noGrad));
+        }
+
+        [DllImport("LibTorchSharp")]
+        extern static HType THS_Argmax(HType src, long dimension, bool keep_dim);
+
+        public ITorchTensor<short> Argmax(long dimension, bool keepDim = false)
+        {
+            return new ShortTensor(THS_Argmax(handle, dimension, keepDim));
         }
 
         /// <summary>
@@ -868,6 +916,22 @@ public ITorchTensor<int> View(params long[] shape)
             }
         }
 
+        [DllImport("LibTorchSharp")]
+        extern static HType THS_Sum(HType src);
+
+        public ITorchTensor<int> Sum()
+        {
+            return new IntTensor(THS_Sum(handle));
+        }
+
+        [DllImport("LibTorchSharp")]
+        extern static IntPtr THS_Eq(HType src, IntPtr trg);
+
+        public ITorchTensor<U> Eq<U>(ITorchTensor<U> target)
+        {
+            return TensorExtensionMethods.ToTorchTensor<U>(THS_Eq(handle, target.Handle));
+        }
+
         [DllImport("LibTorchSharp")]
         extern static HType THS_Sub_(HType src, IntPtr trg, bool is_grad);
 
@@ -879,9 +943,17 @@ public ITorchTensor<int> SubInPlace(ITorchTensor<int> target, bool no_grad = tru
         [DllImport("LibTorchSharp")]
        extern static HType THS_Mul(HType src, int scalar, bool is_grad);
 
-        public ITorchTensor<int> Mul(int scalar, bool no_grad = true)
+        public ITorchTensor<int> Mul(int scalar, bool noGrad = true)
+        {
+            return new IntTensor(THS_Mul(handle, scalar, !noGrad));
+        }
+
+        [DllImport("LibTorchSharp")]
+        extern static HType THS_Argmax(HType src, long dimension, bool keep_dim);
+
+        public ITorchTensor<int> Argmax(long dimension, bool keepDim = false)
         {
-            return new IntTensor(THS_Mul(handle, scalar, !no_grad));
+            return new IntTensor(THS_Argmax(handle, dimension, keepDim));
         }
 
         /// <summary>
@@ -1168,6 +1240,22 @@ public ITorchTensor<long> View(params long[] shape)
             }
         }
 
+        [DllImport("LibTorchSharp")]
+        extern static HType THS_Sum(HType src);
+
+        public ITorchTensor<long> Sum()
+        {
+            return new LongTensor(THS_Sum(handle));
+        }
+
+        [DllImport("LibTorchSharp")]
+        extern static IntPtr THS_Eq(HType src, IntPtr trg);
+
+        public ITorchTensor<U> Eq<U>(ITorchTensor<U> target)
+        {
+            return TensorExtensionMethods.ToTorchTensor<U>(THS_Eq(handle, target.Handle));
+        }
+
         [DllImport("LibTorchSharp")]
         extern static HType THS_Sub_(HType src, IntPtr trg, bool is_grad);
 
@@ -1179,9 +1267,17 @@ public ITorchTensor<long> SubInPlace(ITorchTensor<long> target, bool no_grad = t
         [DllImport("LibTorchSharp")]
         extern static HType THS_Mul(HType src, long scalar, bool is_grad);
 
-        public ITorchTensor<long> Mul(long scalar, bool no_grad = true)
+        public ITorchTensor<long> Mul(long scalar, bool noGrad = true)
         {
-            return new LongTensor(THS_Mul(handle, scalar, !no_grad));
+            return new LongTensor(THS_Mul(handle, scalar, !noGrad));
+        }
+
+        [DllImport("LibTorchSharp")]
+        extern static HType THS_Argmax(HType src, long dimension, bool keep_dim);
+
+        public ITorchTensor<long> Argmax(long dimension, bool keepDim = false)
+        {
+            return new LongTensor(THS_Argmax(handle, dimension, keepDim));
         }
 
         /// <summary>
@@ -1468,6 +1564,22 @@ public ITorchTensor<double> View(params long[] shape)
             }
         }
 
+        [DllImport("LibTorchSharp")]
+        extern static HType THS_Sum(HType src);
+
+        public ITorchTensor<double> Sum()
+        {
+            return new DoubleTensor(THS_Sum(handle));
+        }
+
+        [DllImport("LibTorchSharp")]
+        extern static IntPtr THS_Eq(HType src, IntPtr trg);
+
+        public ITorchTensor<U> Eq<U>(ITorchTensor<U> target)
+        {
+            return TensorExtensionMethods.ToTorchTensor<U>(THS_Eq(handle, target.Handle));
+        }
+
         [DllImport("LibTorchSharp")]
         extern static HType THS_Sub_(HType src, IntPtr trg, bool is_grad);
 
@@ -1479,9 +1591,17 @@ public ITorchTensor<double> SubInPlace(ITorchTensor<double> target, bool no_grad
         [DllImport("LibTorchSharp")]
         extern static HType THS_Mul(HType src, double scalar, bool is_grad);
 
-        public ITorchTensor<double> Mul(double scalar, bool no_grad = true)
+        public ITorchTensor<double> Mul(double scalar, bool noGrad = true)
        {
-            return new DoubleTensor(THS_Mul(handle, scalar, !no_grad));
+            return new DoubleTensor(THS_Mul(handle, scalar, !noGrad));
+        }
+
+        [DllImport("LibTorchSharp")]
+        extern static HType THS_Argmax(HType src, long dimension, bool keep_dim);
+
+        public ITorchTensor<double> Argmax(long dimension, bool keepDim = false)
+        {
+            return new DoubleTensor(THS_Argmax(handle, dimension, keepDim));
         }
 
         /// <summary>
@@ -1768,6 +1888,22 @@ public ITorchTensor<float> View(params long[] shape)
             }
         }
 
+        [DllImport("LibTorchSharp")]
+        extern static HType THS_Sum(HType src);
+
+        public ITorchTensor<float> Sum()
+        {
+            return new FloatTensor(THS_Sum(handle));
+        }
+
+        [DllImport("LibTorchSharp")]
+        extern static IntPtr THS_Eq(HType src, IntPtr trg);
+
+        public ITorchTensor<U> Eq<U>(ITorchTensor<U> target)
+        {
+            return TensorExtensionMethods.ToTorchTensor<U>(THS_Eq(handle, target.Handle));
+        }
+
         [DllImport("LibTorchSharp")]
         extern static HType THS_Sub_(HType src, IntPtr trg, bool is_grad);
 
@@ -1779,9 +1915,17 @@ public ITorchTensor<float> SubInPlace(ITorchTensor<float> target, bool no_grad =
         [DllImport("LibTorchSharp")]
         extern static HType THS_Mul(HType src, float scalar, bool is_grad);
 
-        public ITorchTensor<float> Mul(float scalar, bool no_grad = true)
+        public ITorchTensor<float> Mul(float scalar, bool noGrad = true)
+        {
+            return new FloatTensor(THS_Mul(handle, scalar, !noGrad));
+        }
+
+        [DllImport("LibTorchSharp")]
+        extern static HType THS_Argmax(HType src, long dimension, bool keep_dim);
+
+        public ITorchTensor<float> Argmax(long dimension, bool keepDim = false)
         {
-            return new FloatTensor(THS_Mul(handle, scalar, !no_grad));
+            return new FloatTensor(THS_Argmax(handle, dimension, keepDim));
         }
 
         /// <summary>
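
Note: the surface added to every generated tensor variant above is Sum(), Eq<U>(), and Argmax(dimension, keepDim). A small illustrative sketch against the float variant, assuming the THS_* natives follow the usual argmax / element-wise equality / full-sum semantics; logits and labels are placeholders, not code from the commit:

    // Illustrative only: logits is assumed to be a [batch, 10] tensor of class scores,
    // labels a [batch] tensor of expected class ids.
    using (var pred = logits.Argmax(1))        // index of the best class per row
    using (var matches = pred.Eq(labels))      // 1 where prediction == label, 0 elsewhere
    using (var total = matches.Sum())          // one-element tensor holding the match count
    {
        Console.WriteLine($"correct = {total.Item}");
    }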

TorchSharp/Generated/TorchTensor.tt

Lines changed: 26 additions & 2 deletions
@@ -275,6 +275,22 @@ foreach (var type in TorchTypeDef.Types) {
             }
         }
 
+        [DllImport("LibTorchSharp")]
+        extern static HType THS_Sum(HType src);
+
+        public ITorchTensor<<#=type.Storage#>> Sum()
+        {
+            return new <#=type.Name#>Tensor(THS_Sum(handle));
+        }
+
+        [DllImport("LibTorchSharp")]
+        extern static IntPtr THS_Eq(HType src, IntPtr trg);
+
+        public ITorchTensor<U> Eq<U>(ITorchTensor<U> target)
+        {
+            return TensorExtensionMethods.ToTorchTensor<U>(THS_Eq(handle, target.Handle));
+        }
+
         [DllImport("LibTorchSharp")]
         extern static HType THS_Sub_(HType src, IntPtr trg, bool is_grad);
 
@@ -286,9 +302,17 @@ foreach (var type in TorchTypeDef.Types) {
         [DllImport("LibTorchSharp")]
         extern static HType THS_Mul(HType src, <#=type.Storage#> scalar, bool is_grad);
 
-        public ITorchTensor<<#=type.Storage#>> Mul(<#=type.Storage#> scalar, bool no_grad = true)
+        public ITorchTensor<<#=type.Storage#>> Mul(<#=type.Storage#> scalar, bool noGrad = true)
+        {
+            return new <#=type.Name#>Tensor(THS_Mul(handle, scalar, !noGrad));
+        }
+
+        [DllImport("LibTorchSharp")]
+        extern static HType THS_Argmax(HType src, long dimension, bool keep_dim);
+
+        public ITorchTensor<<#=type.Storage#>> Argmax(long dimension, bool keepDim = false)
         {
-            return new <#=type.Name#>Tensor(THS_Mul(handle, scalar, !no_grad));
+            return new <#=type.Name#>Tensor(THS_Argmax(handle, dimension, keepDim));
         }
 
         /// <summary>
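
Note: the <#=type.Name#> and <#=type.Storage#> placeholders in this T4 template are expanded once per entry of TorchTypeDef.Types, which is how the additions above fan out into the six variants in TorchTensor.generated.cs. For example, an entry with Name = "Float" and Storage = "float" (consistent with the generated file in this commit) expands the Sum() block to:

    public ITorchTensor<float> Sum()
    {
        return new FloatTensor(THS_Sum(handle));
    }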

TorchSharp/NN/LossFunction.cs

Lines changed: 3 additions & 3 deletions
@@ -18,11 +18,11 @@ public static ITorchTensor<float> MSE<T>(ITorchTensor<T> src, ITorchTensor<T> ta
         }
 
         [DllImport("LibTorchSharp")]
-        extern static FloatTensor.HType NN_LossNLL(IntPtr srct, IntPtr trgt);
+        extern static FloatTensor.HType NN_LossNLL(IntPtr srct, IntPtr trgt, long reduction);
 
-        public static ITorchTensor<float> NLL<T, U>(ITorchTensor<T> src, ITorchTensor<U> target)
+        public static ITorchTensor<float> NLL<T, U>(ITorchTensor<T> src, ITorchTensor<U> target, Reduction reduction = Reduction.None)
         {
-            return new FloatTensor(NN_LossNLL(src.Handle, target.Handle));
+            return new FloatTensor(NN_LossNLL(src.Handle, target.Handle, (long)reduction));
         }
     }
 
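Note: with the new reduction parameter the caller chooses the reduction at the call site; the MNIST Test loop above passes NN.Reduction.Sum so per-batch losses can be accumulated and averaged over the dataset. A hedged sketch of the two call shapes (output, target, and testLoss are assumed to exist as in that loop):

    using (var perSample = NN.LossFunction.NLL(output, target))                   // Reduction.None, the new default
    using (var summed = NN.LossFunction.NLL(output, target, NN.Reduction.Sum))    // single summed loss for the batch
    {
        testLoss += summed.Item;   // accumulate, then divide by the dataset size once per epoch
    }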

TorchSharp/Tensor/ITorchTensor.cs

Lines changed: 8 additions & 2 deletions
@@ -28,8 +28,14 @@ public interface ITorchTensor<T> : IDisposable
 
         ITorchTensor<T> View(params long[] shape);
 
-        ITorchTensor<T> SubInPlace(ITorchTensor<T> target, bool no_grad = true);
+        ITorchTensor<U> Eq<U>(ITorchTensor<U> target);
 
-        ITorchTensor<T> Mul(T scalar, bool no_grad = true);
+        ITorchTensor<T> SubInPlace(ITorchTensor<T> target, bool noGrad = true);
+
+        ITorchTensor<T> Mul(T scalar, bool noGrad = true);
+
+        ITorchTensor<T> Sum();
+
+        ITorchTensor<T> Argmax(long dimension, bool keepDimension = false);
     }
 }
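
Note: a short sketch of how the renamed noGrad parameters and the widened interface read at a call site written purely against ITorchTensor<T>; this helper is hypothetical and not part of the repository:

    static class TensorOps
    {
        // Assumes float tensors of matching shape; noGrad: true keeps the update out of gradient tracking.
        public static void SgdStep(ITorchTensor<float> weights, ITorchTensor<float> grad, float learningRate)
        {
            using (var scaled = grad.Mul(learningRate, noGrad: true))   // scalar multiply
            {
                weights.SubInPlace(scaled, noGrad: true);               // in-place weight update
            }
        }
    }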
