@@ -45,17 +45,17 @@ private class Model : NN.Module
         private NN.Module fc1 = Linear(320, 50);
         private NN.Module fc2 = Linear(50, 10);

-        public Model() : base()
+        public Model()
         {
             RegisterModule(conv1);
             RegisterModule(conv2);
             RegisterModule(fc1);
             RegisterModule(fc2);
         }

-        public override ITorchTensor<float> Forward<T>(params ITorchTensor<T>[] tensors)
+        public override TorchTensor Forward(TorchTensor input)
         {
-            using (var l11 = conv1.Forward(tensors))
+            using (var l11 = conv1.Forward(input))
             using (var l12 = MaxPool2D(l11, 2))
             using (var l13 = Relu(l12))
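Review note: Forward changes from a generic params-array override to a single strongly typed argument, so the compiler now checks that exactly one TorchTensor is passed. A minimal calling sketch under that assumption (how the input tensor is constructed is left out, since this diff does not show it):

    // Hypothetical caller of the revised signature: one tensor in, one out,
    // with no generic type parameter and no params array to unpack.
    using (var output = model.Forward(input))
    {
        var pred = output.Argmax(1);   // same follow-up call as in Test() below
        pred.Dispose();
    }
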
@@ -79,7 +79,7 @@ public override ITorchTensor<float> Forward<T>(params ITorchTensor<T>[] tensors)
         private static void Train(
             NN.Module model,
             NN.Optimizer optimizer,
-            IEnumerable<(ITorchTensor<int>, ITorchTensor<int>)> dataLoader,
+            IEnumerable<(TorchTensor, TorchTensor)> dataLoader,
             int epoch,
             long batchSize,
             long size)
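Review note: Train (and Test below) now consume IEnumerable<(TorchTensor, TorchTensor)> rather than IEnumerable<(ITorchTensor<int>, ITorchTensor<int>)>. A sketch of how a loop body deconstructs those tuples, assuming the loader yields (data, target) pairs as the Test hunk below does:

    // Each element is a (data, target) pair of concrete TorchTensors;
    // tuple deconstruction is unchanged, only the element type is simpler.
    foreach (var (data, target) in dataLoader)
    {
        using (var output = model.Forward(data))
        {
            // ... compute and log the loss against target ...
        }
        data.Dispose();
        target.Dispose();
    }
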
@@ -101,7 +101,7 @@ private static void Train(

                 if (batchId % _logInterval == 0)
                 {
-                    Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {loss.DataItem}");
+                    Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {loss.DataItem<float>()}");
                 }

                 batchId++;
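Review note: DataItem moves from a property to a generic method, so the scalar's element type is named at the call site instead of being carried by the tensor's generic parameter. Both uses in this diff follow the same pattern:

    float lossValue = loss.DataItem<float>();            // was: loss.DataItem
    int hits = pred.Eq(target).Sum().DataItem<int>();    // was: ...Sum().DataItem
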
@@ -114,7 +114,7 @@ private static void Train(

         private static void Test(
             NN.Module model,
-            IEnumerable<(ITorchTensor<int>, ITorchTensor<int>)> dataLoader,
+            IEnumerable<(TorchTensor, TorchTensor)> dataLoader,
             long size)
         {
             model.Eval();
@@ -127,11 +127,11 @@ private static void Test(
                 using (var output = model.Forward(data))
                 using (var loss = NN.LossFunction.NLL(output, target, reduction: NN.Reduction.Sum))
                 {
-                    testLoss += loss.DataItem;
+                    testLoss += loss.DataItem<float>();

                     var pred = output.Argmax(1);

-                    correct += pred.Eq(target).Sum().DataItem; // Memory leak here
+                    correct += pred.Eq(target).Sum().DataItem<int>(); // Memory leak here

                     data.Dispose();
                     target.Dispose();
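Review note: the "// Memory leak here" comment still applies after this change: pred.Eq(target) and .Sum() each allocate an intermediate tensor that is never disposed, so its native memory is freed only by the finalizer. A possible fix (a sketch, not part of this diff) is to scope the intermediates the same way the surrounding code scopes output and loss:

    // Dispose each intermediate tensor deterministically instead of
    // abandoning it to the garbage collector.
    using (var eq = pred.Eq(target))
    using (var sum = eq.Sum())
    {
        correct += sum.DataItem<int>();
    }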