Skip to content

Commit f37cc62

Browse files
committed
Added the Poisson NLL loss function, along with the supporting tensor methods (Equal, Bmm, Mean, Mm) needed to implement and test it.
1 parent 1b946ba commit f37cc62

File tree

6 files changed

+299
-11
lines changed

6 files changed

+299
-11
lines changed

Examples/MNIST.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ private static void Test(
125125
foreach (var (data, target) in dataLoader)
126126
{
127127
using (var output = model.Forward(data))
128-
using (var loss = NN.LossFunction.NLL(output, target, NN.Reduction.Sum))
128+
using (var loss = NN.LossFunction.NLL(output, target, reduction: NN.Reduction.Sum))
129129
{
130130
testLoss += loss.Item;
131131

Test/TorchSharp/TorchSharp.cs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,19 @@ public void EvalLossSequence()
327327
Assert.IsNotNull(result);
328328
}
329329

330+
[TestMethod]
331+
public void TestPoissonNLLLoss()
332+
{
333+
using (FloatTensor input = FloatTensor.From(new float[] { 0.5f, 1.5f, 2.5f }))
334+
using (FloatTensor target = FloatTensor.From(new float[] { 1f, 2f, 3f }))
335+
{
336+
var componentWiseLoss = ((FloatTensor)input.Exp()) - target * input;
337+
Assert.IsTrue(componentWiseLoss.Equal(NN.LossFunction.PoissonNLL(input, target, reduction: NN.Reduction.None)));
338+
Assert.IsTrue(componentWiseLoss.Sum().Equal(NN.LossFunction.PoissonNLL(input, target, reduction: NN.Reduction.Sum)));
339+
Assert.IsTrue(componentWiseLoss.Mean().Equal(NN.LossFunction.PoissonNLL(input, target, reduction: NN.Reduction.Mean)));
340+
}
341+
}
342+
330343
[TestMethod]
331344
public void TestZeroGrad()
332345
{
@@ -563,13 +576,13 @@ public void TestMNISTLoader()
563576
{
564577
using (var train = Data.Loader.MNIST(@"E:/Source/Repos/LibTorchSharp/MNIST", 32))
565578
{
566-
var size = train.Size();
567-
568579
Assert.IsNotNull(train);
569-
Assert.IsNotNull(size);
570580

581+
var size = train.Size();
571582
int i = 0;
572583

584+
Assert.IsNotNull(size);
585+
573586
foreach (var (data, target) in train)
574587
{
575588
i++;

TorchSharp/NN/LossFunction.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@ public static ITorchTensor<float> NLL<T, U>(ITorchTensor<T> src, ITorchTensor<U>
3232
{
3333
return new FloatTensor(THSNN_lossNLL(src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction));
3434
}
35+
36+
[DllImport("libTorchSharp")]
37+
extern static IntPtr THSNN_lossPoissonNLL(IntPtr srct, IntPtr trgt, bool logInput, bool full, double eps, long reduction);
38+
39+
public static ITorchTensor<float> PoissonNLL<T, U>(ITorchTensor<T> src, ITorchTensor<U> target, bool logInput = true, bool full = false, double eps = 1e-8, Reduction reduction = Reduction.Mean)
40+
{
41+
return new FloatTensor(THSNN_lossPoissonNLL(src.Handle, target.Handle, logInput, full, eps, (long)reduction));
42+
}
3543
}
3644

3745
public enum Reduction : long

TorchSharp/Tensor/ITorchTensor.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ public interface ITorchTensor<T> : IDisposable
4040

4141
ITorchTensor<U> Eq<U>(ITorchTensor<U> target);
4242

43+
bool Equal<U>(ITorchTensor<U> target);
44+
4345
ITorchTensor<T> Add(ITorchTensor<T> target, int scalar);
4446

4547
void AddInPlace(ITorchTensor<T> target, int scalar);
@@ -50,10 +52,16 @@ public interface ITorchTensor<T> : IDisposable
5052

5153
ITorchTensor<T> Baddbmm(ITorchTensor<T> batch2, ITorchTensor<T> mat, float beta, float alpha);
5254

55+
ITorchTensor<T> Bmm(ITorchTensor<T> batch2);
56+
5357
ITorchTensor<T> Exp();
5458

5559
ITorchTensor<T> MatMul(ITorchTensor<T> target);
5660

61+
ITorchTensor<T> Mean();
62+
63+
ITorchTensor<T> Mm(ITorchTensor<T> target);
64+
5765
ITorchTensor<T> Mul(ITorchTensor<T> target);
5866

5967
ITorchTensor<T> Mul(T scalar);

0 commit comments

Comments (0)