Skip to content

Commit ecf6b90

Browse files
authored
Merge pull request #18 from interesaaat/LibTorchSharpFirstTest
Add the ability to specify loss functions as delegates
2 parents f8230ab + 471053f commit ecf6b90

File tree

3 files changed

+46
-35
lines changed

3 files changed

+46
-35
lines changed

Examples/MNIST.cs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Collections.Generic;
33
using System.Diagnostics;
44
using TorchSharp.Tensor;
5+
using static TorchSharp.NN.LossFunction;
56

67
namespace TorchSharp.Examples
78
{
@@ -28,8 +29,8 @@ static void Main(string[] args)
2829

2930
for (var epoch = 1; epoch <= _epochs; epoch++)
3031
{
31-
Train(model, optimizer, train, epoch, _trainBatchSize, train.Size());
32-
Test(model, test, test.Size());
32+
Train(model, optimizer, NLL(), train, epoch, _trainBatchSize, train.Size());
33+
Test(model, NLL(reduction: NN.Reduction.Sum), test, test.Size());
3334
}
3435

3536
sw.Stop();
@@ -79,6 +80,7 @@ public override TorchTensor Forward(TorchTensor input)
7980
private static void Train(
8081
NN.Module model,
8182
NN.Optimizer optimizer,
83+
Loss loss,
8284
IEnumerable<(TorchTensor, TorchTensor)> dataLoader,
8385
int epoch,
8486
long batchSize,
@@ -92,16 +94,16 @@ private static void Train(
9294
{
9395
optimizer.ZeroGrad();
9496

95-
using (var output = model.Forward(data))
96-
using (var loss = NN.LossFunction.NLL(output, target))
97+
using (var prediction = model.Forward(data))
98+
using (var output = loss(prediction, target))
9799
{
98-
loss.Backward();
100+
output.Backward();
99101

100102
optimizer.Step();
101103

102104
if (batchId % _logInterval == 0)
103105
{
104-
Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {loss.DataItem<float>()}");
106+
Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {output.DataItem<float>()}");
105107
}
106108

107109
batchId++;
@@ -114,6 +116,7 @@ private static void Train(
114116

115117
private static void Test(
116118
NN.Module model,
119+
Loss loss,
117120
IEnumerable<(TorchTensor, TorchTensor)> dataLoader,
118121
long size)
119122
{
@@ -124,10 +127,10 @@ private static void Test(
124127

125128
foreach (var (data, target) in dataLoader)
126129
{
127-
using (var output = model.Forward(data))
128-
using (var loss = NN.LossFunction.NLL(output, target, reduction: NN.Reduction.Sum))
130+
using (var prediction = model.Forward(data))
131+
using (var output = loss(prediction, target))
129132
{
130-
testLoss += loss.DataItem<float>();
133+
testLoss += output.DataItem<float>();
131134

132135
var pred = output.Argmax(1);
133136

Test/TorchSharp/TorchSharp.cs

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -431,9 +431,10 @@ public void EvalLossSequence()
431431
var y = FloatTensor.RandomN(new long[] { 64, 10 }, device: "cpu:0");
432432

433433
var eval = seq.Forward(x);
434-
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.Sum);
434+
var loss = NN.LossFunction.MSE(NN.Reduction.Sum);
435+
var output = loss(eval, y);
435436

436-
var result = loss.DataItem<float>();
437+
var result = output.DataItem<float>();
437438
Assert.IsNotNull(result);
438439
}
439440

@@ -444,9 +445,9 @@ public void TestPoissonNLLLoss()
444445
using (TorchTensor target = FloatTensor.From(new float[] { 1f, 2f, 3f }))
445446
{
446447
var componentWiseLoss = ((TorchTensor)input.Exp()) - target * input;
447-
Assert.IsTrue(componentWiseLoss.Equal(NN.LossFunction.PoissonNLL(input, target, reduction: NN.Reduction.None)));
448-
Assert.IsTrue(componentWiseLoss.Sum().Equal(NN.LossFunction.PoissonNLL(input, target, reduction: NN.Reduction.Sum)));
449-
Assert.IsTrue(componentWiseLoss.Mean().Equal(NN.LossFunction.PoissonNLL(input, target, reduction: NN.Reduction.Mean)));
448+
Assert.IsTrue(componentWiseLoss.Equal(NN.LossFunction.PoissonNLL(reduction: NN.Reduction.None)(input, target)));
449+
Assert.IsTrue(componentWiseLoss.Sum().Equal(NN.LossFunction.PoissonNLL(reduction: NN.Reduction.Sum)(input, target)));
450+
Assert.IsTrue(componentWiseLoss.Mean().Equal(NN.LossFunction.PoissonNLL(reduction: NN.Reduction.Mean)(input, target)));
450451
}
451452
}
452453

@@ -456,7 +457,7 @@ public void TestPoissonNLLLoss2()
456457
using (TorchTensor input = FloatTensor.Random(new long[] { 5, 2 }))
457458
using (TorchTensor target = FloatTensor.Random(new long[] { 5, 2 }))
458459
{
459-
Assert.IsNotNull(NN.LossFunction.PoissonNLL(input, target, true, true));
460+
Assert.IsNotNull(NN.LossFunction.PoissonNLL(true, true)(input, target));
460461
}
461462
}
462463

@@ -481,11 +482,12 @@ public void TestBackward()
481482
var y = FloatTensor.RandomN(new long[] { 64, 10 }, device: "cpu:0");
482483

483484
var eval = seq.Forward(x);
484-
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.None);
485+
var loss = NN.LossFunction.MSE(NN.Reduction.None);
486+
var output = loss(eval, y);
485487

486488
seq.ZeroGrad();
487489

488-
loss.Backward();
490+
output.Backward();
489491
}
490492

491493
[TestMethod]
@@ -499,11 +501,12 @@ public void TestGettingParameters()
499501
var y = FloatTensor.RandomN(new long[] { 64, 10 }, device: "cpu:0");
500502

501503
var eval = seq.Forward(x);
502-
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.None);
504+
var loss = NN.LossFunction.MSE(NN.Reduction.None);
505+
var output = loss(eval, y);
503506

504507
seq.ZeroGrad();
505508

506-
loss.Backward();
509+
output.Backward();
507510

508511
foreach (var parm in seq.Parameters())
509512
{
@@ -522,11 +525,12 @@ public void TestGrad()
522525
var y = FloatTensor.RandomN(new long[] { 64, 10 }, device: "cpu:0");
523526

524527
var eval = seq.Forward(x);
525-
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.None);
528+
var loss = NN.LossFunction.MSE(NN.Reduction.None);
529+
var output = loss(eval, y);
526530

527531
seq.ZeroGrad();
528532

529-
loss.Backward();
533+
output.Backward();
530534

531535
foreach (var parm in seq.Parameters())
532536
{
@@ -658,19 +662,20 @@ public void TestTraining()
658662

659663
float learning_rate = 0.00004f;
660664
float prevLoss = float.MaxValue;
665+
var loss = NN.LossFunction.MSE(NN.Reduction.Sum);
661666

662667
for (int i = 0; i < 10; i++)
663668
{
664669
var eval = seq.Forward(x);
665-
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.Sum);
666-
var lossVal = loss.DataItem<float>();
670+
var output = loss(eval, y);
671+
var lossVal = output.DataItem<float>();
667672

668673
Assert.IsTrue(lossVal < prevLoss);
669674
prevLoss = lossVal;
670675

671676
seq.ZeroGrad();
672677

673-
loss.Backward();
678+
output.Backward();
674679

675680
using (var noGrad = new AutoGradMode(false))
676681
{
@@ -715,19 +720,20 @@ public void TestTrainingAdam()
715720
double learning_rate = 0.00004f;
716721
float prevLoss = float.MaxValue;
717722
var optimizer = NN.Optimizer.Adam(seq.Parameters(), learning_rate);
723+
var loss = NN.LossFunction.MSE(NN.Reduction.Sum);
718724

719725
for (int i = 0; i < 10; i++)
720726
{
721727
var eval = seq.Forward(x);
722-
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.Sum);
723-
var lossVal = loss.DataItem<float>();
728+
var output = loss(eval, y);
729+
var lossVal = output.DataItem<float>();
724730

725731
Assert.IsTrue(lossVal < prevLoss);
726732
prevLoss = lossVal;
727733

728734
optimizer.ZeroGrad();
729735

730-
loss.Backward();
736+
output.Backward();
731737

732738
optimizer.Step();
733739
}

TorchSharp/NN/LossFunction.cs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,36 +9,38 @@ namespace TorchSharp.NN
99
/// </summary>
1010
public class LossFunction
1111
{
12+
public delegate TorchTensor Loss(TorchTensor source, TorchTensor target);
13+
1214
[DllImport("libTorchSharp")]
1315
extern static IntPtr THSNN_lossBCE(IntPtr srct, IntPtr trgt, IntPtr wgt, long reduction);
1416

15-
public static TorchTensor BCE<T, U>(TorchTensor src, TorchTensor target, TorchTensor? weigths = null, Reduction reduction = Reduction.Mean)
17+
public static Loss BCE(TorchTensor? weigths = null, Reduction reduction = Reduction.Mean)
1618
{
17-
return new TorchTensor(THSNN_lossBCE(src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction));
19+
return (TorchTensor src, TorchTensor target) => new TorchTensor(THSNN_lossBCE(src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction));
1820
}
1921

2022
[DllImport("libTorchSharp")]
2123
extern static IntPtr THSNN_lossMSE(IntPtr srct, IntPtr trgt, long reduction);
2224

23-
public static TorchTensor MSE(TorchTensor src, TorchTensor target, Reduction reduction = Reduction.Mean)
25+
public static Loss MSE(Reduction reduction = Reduction.Mean)
2426
{
25-
return new TorchTensor(THSNN_lossMSE(src.Handle, target.Handle, (long)reduction));
27+
return (TorchTensor src, TorchTensor target) => new TorchTensor(THSNN_lossMSE(src.Handle, target.Handle, (long)reduction));
2628
}
2729

2830
[DllImport("libTorchSharp")]
2931
extern static IntPtr THSNN_lossNLL(IntPtr srct, IntPtr trgt, IntPtr wgt, long reduction);
3032

31-
public static TorchTensor NLL(TorchTensor src, TorchTensor target, TorchTensor? weigths = null, Reduction reduction = Reduction.Mean)
33+
public static Loss NLL(TorchTensor? weigths = null, Reduction reduction = Reduction.Mean)
3234
{
33-
return new TorchTensor(THSNN_lossNLL(src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction));
35+
return (TorchTensor src, TorchTensor target) => new TorchTensor(THSNN_lossNLL(src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction));
3436
}
3537

3638
[DllImport("libTorchSharp")]
3739
extern static IntPtr THSNN_lossPoissonNLL(IntPtr srct, IntPtr trgt, bool logInput, bool full, float eps, long reduction);
3840

39-
public static TorchTensor PoissonNLL(TorchTensor src, TorchTensor target, bool logInput = true, bool full = false, float eps = 1e-8f, Reduction reduction = Reduction.Mean)
41+
public static Loss PoissonNLL(bool logInput = true, bool full = false, float eps = 1e-8f, Reduction reduction = Reduction.Mean)
4042
{
41-
return new TorchTensor(THSNN_lossPoissonNLL(src.Handle, target.Handle, logInput, full, eps, (long)reduction));
43+
return (TorchTensor src, TorchTensor target) => new TorchTensor(THSNN_lossPoissonNLL(src.Handle, target.Handle, logInput, full, eps, (long)reduction));
4244
}
4345
}
4446

0 commit comments

Comments
 (0)