Skip to content

Commit 97fa580

Browse files
authored
Merge pull request #23 from interesaaat/LibTorchSharpFirstTest
Add basic error handling logic
2 parents 6fafef7 + d677818 commit 97fa580

File tree

4 files changed

+44
-6
lines changed

4 files changed

+44
-6
lines changed

Test/TorchSharp/TorchSharp.cs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using System;
3+
using System.Diagnostics;
34
using System.Linq;
5+
using System.Runtime.InteropServices;
46
using TorchSharp.JIT;
57
using TorchSharp.NN;
68
using TorchSharp.Tensor;
@@ -164,7 +166,7 @@ public void TestSparse()
164166
Assert.IsTrue(sparse.IsSparse);
165167
Assert.IsFalse(i.IsSparse);
166168
Assert.IsFalse(v.IsSparse);
167-
CollectionAssert.AreEqual(sparse.Indeces.Data<long>().ToArray(), new long[] { 0, 1, 1, 2, 0, 2 });
169+
CollectionAssert.AreEqual(sparse.Indices.Data<long>().ToArray(), new long[] { 0, 1, 1, 2, 0, 2 });
168170
CollectionAssert.AreEqual(sparse.Values.Data<float>().ToArray(), new float[] { 3, 4, 5 });
169171
}
170172
}
@@ -461,13 +463,24 @@ public void TestPoissonNLLLoss2()
461463
}
462464
}
463465

466+
# if DEBUG
467+
[TestMethod]
468+
public void TestErrorHandling()
469+
{
470+
using (TorchTensor input = FloatTensor.From(new float[] { 0.5f, 1.5f}))
471+
using (TorchTensor target = FloatTensor.From(new float[] { 1f, 2f, 3f }))
472+
{
473+
Assert.ThrowsException<ExternalException>(() => NN.LossFunction.PoissonNLL()(input, target));
474+
}
475+
}
476+
#endif
477+
464478
[TestMethod]
465479
public void TestZeroGrad()
466480
{
467481
var lin1 = NN.Module.Linear(1000, 100);
468482
var lin2 = NN.Module.Linear(100, 10);
469483
var seq = NN.Module.Sequential(lin1, NN.Module.Relu(), lin2);
470-
471484
seq.ZeroGrad();
472485
}
473486

TorchSharp/NN/LossFunction.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,16 @@ public static Loss NLL(TorchTensor? weigths = null, Reduction reduction = Reduct
3636
}
3737

3838
[DllImport("libTorchSharp")]
39-
extern static IntPtr THSNN_lossPoissonNLL(IntPtr srct, IntPtr trgt, bool logInput, bool full, float eps, long reduction);
39+
extern static IntPtr THSNN_loss_poisson_nll(IntPtr srct, IntPtr trgt, bool logInput, bool full, float eps, long reduction);
4040

4141
public static Loss PoissonNLL(bool logInput = true, bool full = false, float eps = 1e-8f, Reduction reduction = Reduction.Mean)
4242
{
43-
return (TorchTensor src, TorchTensor target) => new TorchTensor(THSNN_lossPoissonNLL(src.Handle, target.Handle, logInput, full, eps, (long)reduction));
43+
return (TorchTensor src, TorchTensor target) =>
44+
{
45+
var tptr = THSNN_loss_poisson_nll(src.Handle, target.Handle, logInput, full, eps, (long)reduction);
46+
Torch.AssertNoErrors();
47+
return new TorchTensor(tptr);
48+
};
4449
}
4550
}
4651

TorchSharp/Tensor/TorchTensorTyped.generated.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -876,7 +876,9 @@ static public TorchTensor Random(long[] size, string device = "cpu", bool requir
876876
{
877877
fixed (long* psizes = size)
878878
{
879-
return new TorchTensor (THSTensor_rand ((IntPtr)psizes, size.Length, (sbyte)ATenScalarMapping.Float, device, requiresGrad));
879+
var tptr = THSTensor_rand((IntPtr)psizes, size.Length, (sbyte)ATenScalarMapping.Float, device, requiresGrad);
880+
Torch.AssertNoErrors();
881+
return new TorchTensor (tptr);
880882
}
881883
}
882884
}

TorchSharp/Torch.cs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
using System.Runtime.InteropServices;
1+
using System;
2+
using System.Diagnostics;
3+
using System.Runtime.InteropServices;
24

35
namespace TorchSharp
46
{
7+
using Debug = System.Diagnostics.Debug;
8+
59
public static class Torch
610
{
711
[DllImport("libTorchSharp")]
@@ -19,5 +23,19 @@ public static bool IsCudaAvailable()
1923
{
2024
return THSTorch_isCudaAvailable();
2125
}
26+
27+
[DllImport("libTorchSharp")]
28+
extern static IntPtr THSTorch_get_and_reset_last_err();
29+
30+
[Conditional("DEBUG")]
31+
internal static void AssertNoErrors()
32+
{
33+
var error = THSTorch_get_and_reset_last_err();
34+
35+
if (error != IntPtr.Zero)
36+
{
37+
throw new ExternalException(Marshal.PtrToStringAnsi(error));
38+
}
39+
}
2240
}
2341
}

0 commit comments

Comments
 (0)