Skip to content

Commit 887e784

Browse files
committed
Removed ITorchTensor interface and types. Added a few additional features.
1 parent 437a95c commit 887e784

23 files changed

+350
-811
lines changed

Examples/MNIST.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,17 @@ private class Model : NN.Module
4545
private NN.Module fc1 = Linear(320, 50);
4646
private NN.Module fc2 = Linear(50, 10);
4747

48-
public Model() : base()
48+
public Model()
4949
{
5050
RegisterModule(conv1);
5151
RegisterModule(conv2);
5252
RegisterModule(fc1);
5353
RegisterModule(fc2);
5454
}
5555

56-
public override ITorchTensor<float> Forward<T>(params ITorchTensor<T>[] tensors)
56+
public override TorchTensor Forward(TorchTensor input)
5757
{
58-
using (var l11 = conv1.Forward(tensors))
58+
using (var l11 = conv1.Forward(input))
5959
using (var l12 = MaxPool2D(l11, 2))
6060
using (var l13 = Relu(l12))
6161

@@ -79,7 +79,7 @@ public override ITorchTensor<float> Forward<T>(params ITorchTensor<T>[] tensors)
7979
private static void Train(
8080
NN.Module model,
8181
NN.Optimizer optimizer,
82-
IEnumerable<(ITorchTensor<int>, ITorchTensor<int>)> dataLoader,
82+
IEnumerable<(TorchTensor, TorchTensor)> dataLoader,
8383
int epoch,
8484
long batchSize,
8585
long size)
@@ -101,7 +101,7 @@ private static void Train(
101101

102102
if (batchId % _logInterval == 0)
103103
{
104-
Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {loss.DataItem}");
104+
Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {loss.DataItem<float>()}");
105105
}
106106

107107
batchId++;
@@ -114,7 +114,7 @@ private static void Train(
114114

115115
private static void Test(
116116
NN.Module model,
117-
IEnumerable<(ITorchTensor<int>, ITorchTensor<int>)> dataLoader,
117+
IEnumerable<(TorchTensor, TorchTensor)> dataLoader,
118118
long size)
119119
{
120120
model.Eval();
@@ -127,11 +127,11 @@ private static void Test(
127127
using (var output = model.Forward(data))
128128
using (var loss = NN.LossFunction.NLL(output, target, reduction: NN.Reduction.Sum))
129129
{
130-
testLoss += loss.DataItem;
130+
testLoss += loss.DataItem<float>();
131131

132132
var pred = output.Argmax(1);
133133

134-
correct += pred.Eq(target).Sum().DataItem; // Memory leak here
134+
correct += pred.Eq(target).Sum().DataItem<int>(); // Memory leak here
135135

136136
data.Dispose();
137137
target.Dispose();

0 commit comments

Comments (0)