@@ -6,26 +6,25 @@ namespace TorchSharp.Examples
6
6
{
7
7
public class MNIST
8
8
{
9
+ private readonly static long _batch = 64 ;
10
+ private readonly static string _trainDataset = @"E:/Source/Repos/LibTorchSharp/MNIST" ;
11
+
9
12
static void Main ( string [ ] args )
10
13
{
11
- var train = Data . Loader . MNIST ( @"E:/Source/Repos/LibTorchSharp/MNIST" , 64 , out int size ) ;
12
-
13
- var model = new Model ( ) ;
14
-
15
- var optimizer = NN . Optimizer . SGD ( model . Parameters ( ) , 0.01 , 0.5 ) ;
16
-
17
- for ( var epoch = 1 ; epoch <= 10 ; epoch ++ )
14
+ using ( var train = Data . Loader . MNIST ( _trainDataset , _batch ) )
15
+ using ( var model = new Model ( ) )
16
+ using ( var optimizer = NN . Optimizer . SGD ( model . Parameters ( ) , 0.01 , 0.5 ) )
18
17
{
19
- Train ( model , optimizer , train , epoch , size ) ;
18
+ Train ( model , optimizer , train , _batch , train . Size ( ) ) ;
20
19
}
21
20
}
22
21
23
22
private class Model : NN . Module
24
23
{
25
- private NN . Module conv1 = NN . Module . Conv2D ( 1 , 10 , 5 ) ;
26
- private NN . Module conv2 = NN . Module . Conv2D ( 10 , 20 , 5 ) ;
27
- private NN . Module fc1 = NN . Module . Linear ( 320 , 50 ) ;
28
- private NN . Module fc2 = NN . Module . Linear ( 50 , 10 ) ;
24
+ private NN . Module conv1 = Conv2D ( 1 , 10 , 5 ) ;
25
+ private NN . Module conv2 = Conv2D ( 10 , 20 , 5 ) ;
26
+ private NN . Module fc1 = Linear ( 320 , 50 ) ;
27
+ private NN . Module fc2 = Linear ( 50 , 10 ) ;
29
28
30
29
public Model ( ) : base ( IntPtr . Zero )
31
30
{
@@ -37,47 +36,57 @@ public Model() : base(IntPtr.Zero)
37
36
38
37
public override ITorchTensor < float > Forward < T > ( ITorchTensor < T > tensor )
39
38
{
40
- var x = conv1 . Forward ( tensor ) ;
41
- x = NN . Module . MaxPool2D ( x , 2 ) ;
42
- x = NN . Module . Relu ( x ) ;
39
+ using ( var l11 = conv1 . Forward ( tensor ) )
40
+ using ( var l12 = MaxPool2D ( l11 , 2 ) )
41
+ using ( var l13 = Relu ( l12 ) )
43
42
44
- x = conv2 . Forward ( x ) ;
45
- x = NN . Module . FeatureDropout ( x ) ;
46
- x = NN . Module . MaxPool2D ( x , 2 ) ;
43
+ using ( var l21 = conv2 . Forward ( l13 ) )
44
+ using ( var l22 = FeatureDropout ( l21 ) )
45
+ using ( var l23 = MaxPool2D ( l22 , 2 ) )
47
46
48
- x = x . View ( new long [ ] { - 1 , 320 } ) ;
47
+ using ( var x = l23 . View ( new long [ ] { - 1 , 320 } ) )
49
48
50
- x = fc1 . Forward ( x ) ;
51
- x = NN . Module . Relu ( x ) ;
52
- x = NN . Module . Dropout ( x , 0.5 , _isTraining ) ;
49
+ using ( var l31 = fc1 . Forward ( x ) )
50
+ using ( var l32 = Relu ( l31 ) )
51
+ using ( var l33 = Dropout ( l32 , 0.5 , _isTraining ) )
53
52
54
- x = fc2 . Forward ( x ) ;
53
+ using ( var l41 = fc2 . Forward ( l33 ) )
55
54
56
- return NN . Module . LogSoftMax ( x , 1 ) ;
55
+ return LogSoftMax ( l41 , 1 ) ;
57
56
}
58
57
}
59
58
60
- private static void Train ( NN . Module model , NN . Optimizer optimizer , IEnumerable < ( ITorchTensor < float > , ITorchTensor < float > ) > dataLoader , int epoch , int size )
59
+ private static void Train (
60
+ NN . Module model ,
61
+ NN . Optimizer optimizer ,
62
+ IEnumerable < ( ITorchTensor < int > , ITorchTensor < int > ) > dataLoader ,
63
+ long batchSize ,
64
+ long size )
61
65
{
62
66
model . Train ( ) ;
63
67
64
- int batchId = 0 ;
68
+ int batchId = 1 ;
65
69
66
70
foreach ( var ( data , target ) in dataLoader )
67
71
{
68
72
optimizer . ZeroGrad ( ) ;
69
73
70
- var output = model . Forward ( data ) ;
71
-
72
- var loss = NN . LossFunction . NLL ( output , target ) ;
74
+ if ( batchId == 937 )
75
+ {
76
+ Console . WriteLine ( ) ;
77
+ }
73
78
74
- loss . Backward ( ) ;
79
+ using ( var output = model . Forward ( data ) )
80
+ using ( var loss = NN . LossFunction . NLL ( output , target ) )
81
+ {
82
+ loss . Backward ( ) ;
75
83
76
- optimizer . Step ( ) ;
84
+ optimizer . Step ( ) ;
77
85
78
- batchId ++ ;
86
+ batchId ++ ;
79
87
80
- Console . WriteLine ( $ "\r Train Epoch: { epoch } [{ batchId } / { size } ] Loss: { loss . Item } ") ;
88
+ Console . WriteLine ( $ "\r Train: [{ batchId * batchSize } / { size } ] Loss: { loss . Item } ") ;
89
+ }
81
90
}
82
91
}
83
92
}
0 commit comments