Skip to content

Commit 471053f

Browse files
committed
Edits following Artidoro review
1 parent 1c4a162 commit 471053f

File tree

3 files changed

+32
-75
lines changed

3 files changed

+32
-75
lines changed

Examples/MNIST.cs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Collections.Generic;
33
using System.Diagnostics;
44
using TorchSharp.Tensor;
5+
using static TorchSharp.NN.LossFunction;
56

67
namespace TorchSharp.Examples
78
{
@@ -28,8 +29,8 @@ static void Main(string[] args)
2829

2930
for (var epoch = 1; epoch <= _epochs; epoch++)
3031
{
31-
Train(model, optimizer, train, epoch, _trainBatchSize, train.Size());
32-
Test(model, test, test.Size());
32+
Train(model, optimizer, NLL(), train, epoch, _trainBatchSize, train.Size());
33+
Test(model, NLL(reduction: NN.Reduction.Sum), test, test.Size());
3334
}
3435

3536
sw.Stop();
@@ -79,6 +80,7 @@ public override TorchTensor Forward(TorchTensor input)
7980
private static void Train(
8081
NN.Module model,
8182
NN.Optimizer optimizer,
83+
Loss loss,
8284
IEnumerable<(TorchTensor, TorchTensor)> dataLoader,
8385
int epoch,
8486
long batchSize,
@@ -92,16 +94,16 @@ private static void Train(
9294
{
9395
optimizer.ZeroGrad();
9496

95-
using (var output = model.Forward(data))
96-
using (var loss = NN.LossFunction.NLL(output, target))
97+
using (var prediction = model.Forward(data))
98+
using (var output = loss(prediction, target))
9799
{
98-
loss.Backward();
100+
output.Backward();
99101

100102
optimizer.Step();
101103

102104
if (batchId % _logInterval == 0)
103105
{
104-
Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {loss.DataItem<float>()}");
106+
Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {output.DataItem<float>()}");
105107
}
106108

107109
batchId++;
@@ -114,6 +116,7 @@ private static void Train(
114116

115117
private static void Test(
116118
NN.Module model,
119+
Loss loss,
117120
IEnumerable<(TorchTensor, TorchTensor)> dataLoader,
118121
long size)
119122
{
@@ -124,10 +127,10 @@ private static void Test(
124127

125128
foreach (var (data, target) in dataLoader)
126129
{
127-
using (var output = model.Forward(data))
128-
using (var loss = NN.LossFunction.NLL(output, target, reduction: NN.Reduction.Sum))
130+
using (var prediction = model.Forward(data))
131+
using (var output = loss(prediction, target))
129132
{
130-
testLoss += loss.DataItem<float>();
133+
testLoss += output.DataItem<float>();
131134

132135
var pred = output.Argmax(1);
133136

Test/TorchSharp/TorchSharp.cs

Lines changed: 20 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -431,9 +431,10 @@ public void EvalLossSequence()
431431
var y = FloatTensor.RandomN(new long[] { 64, 10 }, device: "cpu:0");
432432

433433
var eval = seq.Forward(x);
434-
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.Sum);
434+
var loss = NN.LossFunction.MSE(NN.Reduction.Sum);
435+
var output = loss(eval, y);
435436

436-
var result = loss.DataItem<float>();
437+
var result = output.DataItem<float>();
437438
Assert.IsNotNull(result);
438439
}
439440

@@ -444,9 +445,9 @@ public void TestPoissonNLLLoss()
444445
using (TorchTensor target = FloatTensor.From(new float[] { 1f, 2f, 3f }))
445446
{
446447
var componentWiseLoss = ((TorchTensor)input.Exp()) - target * input;
447-
Assert.IsTrue(componentWiseLoss.Equal(NN.LossFunction.PoissonNLL(input, target, reduction: NN.Reduction.None)));
448-
Assert.IsTrue(componentWiseLoss.Sum().Equal(NN.LossFunction.PoissonNLL(input, target, reduction: NN.Reduction.Sum)));
449-
Assert.IsTrue(componentWiseLoss.Mean().Equal(NN.LossFunction.PoissonNLL(input, target, reduction: NN.Reduction.Mean)));
448+
Assert.IsTrue(componentWiseLoss.Equal(NN.LossFunction.PoissonNLL(reduction: NN.Reduction.None)(input, target)));
449+
Assert.IsTrue(componentWiseLoss.Sum().Equal(NN.LossFunction.PoissonNLL(reduction: NN.Reduction.Sum)(input, target)));
450+
Assert.IsTrue(componentWiseLoss.Mean().Equal(NN.LossFunction.PoissonNLL(reduction: NN.Reduction.Mean)(input, target)));
450451
}
451452
}
452453

@@ -456,7 +457,7 @@ public void TestPoissonNLLLoss2()
456457
using (TorchTensor input = FloatTensor.Random(new long[] { 5, 2 }))
457458
using (TorchTensor target = FloatTensor.Random(new long[] { 5, 2 }))
458459
{
459-
Assert.IsNotNull(NN.LossFunction.PoissonNLL(input, target, true, true));
460+
Assert.IsNotNull(NN.LossFunction.PoissonNLL(true, true)(input, target));
460461
}
461462
}
462463

@@ -481,11 +482,12 @@ public void TestBackward()
481482
var y = FloatTensor.RandomN(new long[] { 64, 10 }, device: "cpu:0");
482483

483484
var eval = seq.Forward(x);
484-
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.None);
485+
var loss = NN.LossFunction.MSE(NN.Reduction.None);
486+
var output = loss(eval, y);
485487

486488
seq.ZeroGrad();
487489

488-
loss.Backward();
490+
output.Backward();
489491
}
490492

491493
[TestMethod]
@@ -499,11 +501,12 @@ public void TestGettingParameters()
499501
var y = FloatTensor.RandomN(new long[] { 64, 10 }, device: "cpu:0");
500502

501503
var eval = seq.Forward(x);
502-
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.None);
504+
var loss = NN.LossFunction.MSE(NN.Reduction.None);
505+
var output = loss(eval, y);
503506

504507
seq.ZeroGrad();
505508

506-
loss.Backward();
509+
output.Backward();
507510

508511
foreach (var parm in seq.Parameters())
509512
{
@@ -522,11 +525,12 @@ public void TestGrad()
522525
var y = FloatTensor.RandomN(new long[] { 64, 10 }, device: "cpu:0");
523526

524527
var eval = seq.Forward(x);
525-
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.None);
528+
var loss = NN.LossFunction.MSE(NN.Reduction.None);
529+
var output = loss(eval, y);
526530

527531
seq.ZeroGrad();
528532

529-
loss.Backward();
533+
output.Backward();
530534

531535
foreach (var parm in seq.Parameters())
532536
{
@@ -658,19 +662,20 @@ public void TestTraining()
658662

659663
float learning_rate = 0.00004f;
660664
float prevLoss = float.MaxValue;
665+
var loss = NN.LossFunction.MSE(NN.Reduction.Sum);
661666

662667
for (int i = 0; i < 10; i++)
663668
{
664669
var eval = seq.Forward(x);
665-
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.Sum);
666-
var lossVal = loss.DataItem<float>();
670+
var output = loss(eval, y);
671+
var lossVal = output.DataItem<float>();
667672

668673
Assert.IsTrue(lossVal < prevLoss);
669674
prevLoss = lossVal;
670675

671676
seq.ZeroGrad();
672677

673-
loss.Backward();
678+
output.Backward();
674679

675680
using (var noGrad = new AutoGradMode(false))
676681
{
@@ -712,37 +717,6 @@ public void TestTrainingAdam()
712717
var x = FloatTensor.RandomN(new long[] { 64, 1000 }, device: "cpu:0");
713718
var y = FloatTensor.RandomN(new long[] { 64, 10 }, device: "cpu:0");
714719

715-
double learning_rate = 0.00004f;
716-
float prevLoss = float.MaxValue;
717-
var optimizer = NN.Optimizer.Adam(seq.Parameters(), learning_rate);
718-
719-
for (int i = 0; i < 10; i++)
720-
{
721-
var eval = seq.Forward(x);
722-
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.Sum);
723-
var lossVal = loss.DataItem<float>();
724-
725-
Assert.IsTrue(lossVal < prevLoss);
726-
prevLoss = lossVal;
727-
728-
optimizer.ZeroGrad();
729-
730-
loss.Backward();
731-
732-
optimizer.Step();
733-
}
734-
}
735-
736-
[TestMethod]
737-
public void TestTrainingAdam2()
738-
{
739-
var lin1 = NN.Module.Linear(1000, 100);
740-
var lin2 = NN.Module.Linear(100, 10);
741-
var seq = NN.Module.Sequential(lin1, NN.Module.Relu(), lin2);
742-
743-
var x = FloatTensor.RandomN(new long[] { 64, 1000 }, device: "cpu:0");
744-
var y = FloatTensor.RandomN(new long[] { 64, 10 }, device: "cpu:0");
745-
746720
double learning_rate = 0.00004f;
747721
float prevLoss = float.MaxValue;
748722
var optimizer = NN.Optimizer.Adam(seq.Parameters(), learning_rate);

TorchSharp/NN/LossFunction.cs

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,6 @@ public class LossFunction
1414
[DllImport("libTorchSharp")]
1515
extern static IntPtr THSNN_lossBCE(IntPtr srct, IntPtr trgt, IntPtr wgt, long reduction);
1616

17-
public static TorchTensor BCE(TorchTensor src, TorchTensor target, TorchTensor? weigths = null, Reduction reduction = Reduction.Mean)
18-
{
19-
return new TorchTensor(THSNN_lossBCE(src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction));
20-
}
21-
2217
public static Loss BCE(TorchTensor? weigths = null, Reduction reduction = Reduction.Mean)
2318
{
2419
return (TorchTensor src, TorchTensor target) => new TorchTensor(THSNN_lossBCE(src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction));
@@ -27,11 +22,6 @@ public static Loss BCE(TorchTensor? weigths = null, Reduction reduction = Reduct
2722
[DllImport("libTorchSharp")]
2823
extern static IntPtr THSNN_lossMSE(IntPtr srct, IntPtr trgt, long reduction);
2924

30-
public static TorchTensor MSE(TorchTensor src, TorchTensor target, Reduction reduction = Reduction.Mean)
31-
{
32-
return new TorchTensor(THSNN_lossMSE(src.Handle, target.Handle, (long)reduction));
33-
}
34-
3525
public static Loss MSE(Reduction reduction = Reduction.Mean)
3626
{
3727
return (TorchTensor src, TorchTensor target) => new TorchTensor(THSNN_lossMSE(src.Handle, target.Handle, (long)reduction));
@@ -40,11 +30,6 @@ public static Loss MSE(Reduction reduction = Reduction.Mean)
4030
[DllImport("libTorchSharp")]
4131
extern static IntPtr THSNN_lossNLL(IntPtr srct, IntPtr trgt, IntPtr wgt, long reduction);
4232

43-
public static TorchTensor NLL(TorchTensor src, TorchTensor target, TorchTensor? weigths = null, Reduction reduction = Reduction.Mean)
44-
{
45-
return new TorchTensor(THSNN_lossNLL(src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction));
46-
}
47-
4833
public static Loss NLL(TorchTensor? weigths = null, Reduction reduction = Reduction.Mean)
4934
{
5035
return (TorchTensor src, TorchTensor target) => new TorchTensor(THSNN_lossNLL(src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction));
@@ -53,11 +38,6 @@ public static Loss NLL(TorchTensor? weigths = null, Reduction reduction = Reduct
5338
[DllImport("libTorchSharp")]
5439
extern static IntPtr THSNN_lossPoissonNLL(IntPtr srct, IntPtr trgt, bool logInput, bool full, float eps, long reduction);
5540

56-
public static TorchTensor PoissonNLL(TorchTensor src, TorchTensor target, bool logInput = true, bool full = false, float eps = 1e-8f, Reduction reduction = Reduction.Mean)
57-
{
58-
return new TorchTensor(THSNN_lossPoissonNLL(src.Handle, target.Handle, logInput, full, eps, (long)reduction));
59-
}
60-
6141
public static Loss PoissonNLL(bool logInput = true, bool full = false, float eps = 1e-8f, Reduction reduction = Reduction.Mean)
6242
{
6343
return (TorchTensor src, TorchTensor target) => new TorchTensor(THSNN_lossPoissonNLL(src.Handle, target.Handle, logInput, full, eps, (long)reduction));

0 commit comments

Comments
 (0)