Skip to content

Commit 1c4a162

Browse files
committed
Add the ability to create a loss function as a delegate. This allows to match the python api.
1 parent 48fea15 commit 1c4a162

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

Test/TorchSharp/TorchSharp.cs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,38 @@ public void TestTrainingAdam()
733733
}
734734
}
735735

736+
[TestMethod]
737+
public void TestTrainingAdam2()
738+
{
739+
var lin1 = NN.Module.Linear(1000, 100);
740+
var lin2 = NN.Module.Linear(100, 10);
741+
var seq = NN.Module.Sequential(lin1, NN.Module.Relu(), lin2);
742+
743+
var x = FloatTensor.RandomN(new long[] { 64, 1000 }, device: "cpu:0");
744+
var y = FloatTensor.RandomN(new long[] { 64, 10 }, device: "cpu:0");
745+
746+
double learning_rate = 0.00004f;
747+
float prevLoss = float.MaxValue;
748+
var optimizer = NN.Optimizer.Adam(seq.Parameters(), learning_rate);
749+
var loss = NN.LossFunction.MSE(NN.Reduction.Sum);
750+
751+
for (int i = 0; i < 10; i++)
752+
{
753+
var eval = seq.Forward(x);
754+
var output = loss(eval, y);
755+
var lossVal = output.DataItem<float>();
756+
757+
Assert.IsTrue(lossVal < prevLoss);
758+
prevLoss = lossVal;
759+
760+
optimizer.ZeroGrad();
761+
762+
output.Backward();
763+
764+
optimizer.Step();
765+
}
766+
}
767+
736768
[TestMethod]
737769
public void TestMNISTLoader()
738770
{

TorchSharp/NN/LossFunction.cs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,21 @@ namespace TorchSharp.NN
99
/// </summary>
1010
public class LossFunction
1111
{
12+
public delegate TorchTensor Loss(TorchTensor source, TorchTensor target);
13+
1214
[DllImport("libTorchSharp")]
1315
extern static IntPtr THSNN_lossBCE(IntPtr srct, IntPtr trgt, IntPtr wgt, long reduction);
1416

15-
public static TorchTensor BCE<T, U>(TorchTensor src, TorchTensor target, TorchTensor? weigths = null, Reduction reduction = Reduction.Mean)
17+
public static TorchTensor BCE(TorchTensor src, TorchTensor target, TorchTensor? weigths = null, Reduction reduction = Reduction.Mean)
1618
{
1719
return new TorchTensor(THSNN_lossBCE(src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction));
1820
}
1921

22+
public static Loss BCE(TorchTensor? weigths = null, Reduction reduction = Reduction.Mean)
23+
{
24+
return (TorchTensor src, TorchTensor target) => new TorchTensor(THSNN_lossBCE(src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction));
25+
}
26+
2027
[DllImport("libTorchSharp")]
2128
extern static IntPtr THSNN_lossMSE(IntPtr srct, IntPtr trgt, long reduction);
2229

@@ -25,6 +32,11 @@ public static TorchTensor MSE(TorchTensor src, TorchTensor target, Reduction red
2532
return new TorchTensor(THSNN_lossMSE(src.Handle, target.Handle, (long)reduction));
2633
}
2734

35+
public static Loss MSE(Reduction reduction = Reduction.Mean)
36+
{
37+
return (TorchTensor src, TorchTensor target) => new TorchTensor(THSNN_lossMSE(src.Handle, target.Handle, (long)reduction));
38+
}
39+
2840
[DllImport("libTorchSharp")]
2941
extern static IntPtr THSNN_lossNLL(IntPtr srct, IntPtr trgt, IntPtr wgt, long reduction);
3042

@@ -33,13 +45,23 @@ public static TorchTensor NLL(TorchTensor src, TorchTensor target, TorchTensor?
3345
return new TorchTensor(THSNN_lossNLL(src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction));
3446
}
3547

48+
public static Loss NLL(TorchTensor? weigths = null, Reduction reduction = Reduction.Mean)
49+
{
50+
return (TorchTensor src, TorchTensor target) => new TorchTensor(THSNN_lossNLL(src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction));
51+
}
52+
3653
[DllImport("libTorchSharp")]
3754
extern static IntPtr THSNN_lossPoissonNLL(IntPtr srct, IntPtr trgt, bool logInput, bool full, float eps, long reduction);
3855

3956
public static TorchTensor PoissonNLL(TorchTensor src, TorchTensor target, bool logInput = true, bool full = false, float eps = 1e-8f, Reduction reduction = Reduction.Mean)
4057
{
4158
return new TorchTensor(THSNN_lossPoissonNLL(src.Handle, target.Handle, logInput, full, eps, (long)reduction));
4259
}
60+
61+
public static Loss PoissonNLL(bool logInput = true, bool full = false, float eps = 1e-8f, Reduction reduction = Reduction.Mean)
62+
{
63+
return (TorchTensor src, TorchTensor target) => new TorchTensor(THSNN_lossPoissonNLL(src.Handle, target.Handle, logInput, full, eps, (long)reduction));
64+
}
4365
}
4466

4567
public enum Reduction : long

0 commit comments

Comments
 (0)