Skip to content

Commit e80b858

Browse files
committed
Added BCE loss.
1 parent f17b1c5 commit e80b858

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

TorchSharp/NN/LossFunction.cs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@ namespace TorchSharp.NN
99
/// </summary>
1010
public class LossFunction
1111
{
12+
[DllImport("libTorchSharp")]
13+
extern static IntPtr THSNN_lossBCE(IntPtr srct, IntPtr trgt, IntPtr wgt, long reduction);
14+
15+
public static ITorchTensor<float> BCE<T, U>(ITorchTensor<T> src, ITorchTensor<U> target, ITorchTensor<U> weigths = null, Reduction reduction = Reduction.Mean)
16+
{
17+
return new FloatTensor(THSNN_lossBCE(src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction));
18+
}
19+
1220
[DllImport("libTorchSharp")]
1321
extern static IntPtr THSNN_lossMSE(IntPtr srct, IntPtr trgt, long reduction);
1422

@@ -18,11 +26,11 @@ public static ITorchTensor<float> MSE<T>(ITorchTensor<T> src, ITorchTensor<T> ta
1826
}
1927

2028
[DllImport("libTorchSharp")]
21-
extern static IntPtr THSNN_lossNLL(IntPtr srct, IntPtr trgt, long reduction);
29+
extern static IntPtr THSNN_lossNLL(IntPtr srct, IntPtr trgt, IntPtr wgt, long reduction);
2230

23-
public static ITorchTensor<float> NLL<T, U>(ITorchTensor<T> src, ITorchTensor<U> target, Reduction reduction = Reduction.Mean)
31+
public static ITorchTensor<float> NLL<T, U>(ITorchTensor<T> src, ITorchTensor<U> target, ITorchTensor<U> weigths = null, Reduction reduction = Reduction.Mean)
2432
{
25-
return new FloatTensor(THSNN_lossNLL(src.Handle, target.Handle, (long)reduction));
33+
return new FloatTensor(THSNN_lossNLL(src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction));
2634
}
2735
}
2836

0 commit comments

Comments
 (0)