@@ -6,6 +6,7 @@ namespace TorchSharp.Examples
6
6
{
7
7
public class MNIST
8
8
{
9
+ private readonly static int _epochs = 10 ;
9
10
private readonly static long _batch = 64 ;
10
11
private readonly static string _trainDataset = @"E:/Source/Repos/LibTorchSharp/MNIST" ;
11
12
@@ -15,7 +16,10 @@ static void Main(string[] args)
15
16
using ( var model = new Model ( ) )
16
17
using ( var optimizer = NN . Optimizer . SGD ( model . Parameters ( ) , 0.01 , 0.5 ) )
17
18
{
18
- Train ( model , optimizer , train , _batch , train . Size ( ) ) ;
19
+ for ( var epoch = 1 ; epoch <= _epochs ; epoch ++ )
20
+ {
21
+ Train ( model , optimizer , train , epoch , _batch , train . Size ( ) ) ;
22
+ }
19
23
}
20
24
}
21
25
@@ -59,7 +63,8 @@ public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
59
63
private static void Train (
60
64
NN . Module model ,
61
65
NN . Optimizer optimizer ,
62
- IEnumerable < ( ITorchTensor < int > , ITorchTensor < int > ) > dataLoader ,
66
+ IEnumerable < ( ITorchTensor < int > , ITorchTensor < int > ) > dataLoader ,
67
+ int epoch ,
63
68
long batchSize ,
64
69
long size )
65
70
{
@@ -71,11 +76,6 @@ private static void Train(
71
76
{
72
77
optimizer . ZeroGrad ( ) ;
73
78
74
- if ( batchId == 937 )
75
- {
76
- Console . WriteLine ( ) ;
77
- }
78
-
79
79
using ( var output = model . Forward ( data ) )
80
80
using ( var loss = NN . LossFunction . NLL ( output , target ) )
81
81
{
@@ -85,7 +85,10 @@ private static void Train(
85
85
86
86
batchId ++ ;
87
87
88
- Console . WriteLine ( $ "\r Train: [{ batchId * batchSize } / { size } ] Loss: { loss . Item } ") ;
88
+ Console . WriteLine ( $ "\r Train: epoch { epoch } [{ batchId * batchSize } / { size } ] Loss: { loss . Item } ") ;
89
+
90
+ data . Dispose ( ) ;
91
+ target . Dispose ( ) ;
89
92
}
90
93
}
91
94
}
0 commit comments