Skip to content

Commit 44712c1

Browse files
authored
Merge pull request #6 from interesaaat/LibTorchSharpFirstTest
Added some tensor operations and loss functions
2 parents 0d6871a + e80b858 commit 44712c1

File tree

3 files changed

+351
-19
lines changed

3 files changed

+351
-19
lines changed

TorchSharp/NN/LossFunction.cs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@ namespace TorchSharp.NN
99
/// </summary>
1010
public class LossFunction
1111
{
12+
[DllImport("libTorchSharp")]
13+
extern static IntPtr THSNN_lossBCE(IntPtr srct, IntPtr trgt, IntPtr wgt, long reduction);
14+
15+
public static ITorchTensor<float> BCE<T, U>(ITorchTensor<T> src, ITorchTensor<U> target, ITorchTensor<U> weigths = null, Reduction reduction = Reduction.Mean)
16+
{
17+
return new FloatTensor(THSNN_lossBCE(src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction));
18+
}
19+
1220
[DllImport("libTorchSharp")]
1321
extern static IntPtr THSNN_lossMSE(IntPtr srct, IntPtr trgt, long reduction);
1422

@@ -18,11 +26,11 @@ public static ITorchTensor<float> MSE<T>(ITorchTensor<T> src, ITorchTensor<T> ta
1826
}
1927

2028
[DllImport("libTorchSharp")]
21-
extern static IntPtr THSNN_lossNLL(IntPtr srct, IntPtr trgt, long reduction);
29+
extern static IntPtr THSNN_lossNLL(IntPtr srct, IntPtr trgt, IntPtr wgt, long reduction);
2230

23-
public static ITorchTensor<float> NLL<T, U>(ITorchTensor<T> src, ITorchTensor<U> target, Reduction reduction = Reduction.Mean)
31+
public static ITorchTensor<float> NLL<T, U>(ITorchTensor<T> src, ITorchTensor<U> target, ITorchTensor<U> weigths = null, Reduction reduction = Reduction.Mean)
2432
{
25-
return new FloatTensor(THSNN_lossNLL(src.Handle, target.Handle, (long)reduction));
33+
return new FloatTensor(THSNN_lossNLL(src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction));
2634
}
2735
}
2836

0 commit comments

Comments
 (0)