@@ -9,6 +9,14 @@ namespace TorchSharp.NN
9
9
/// </summary>
10
10
public class LossFunction
11
11
{
12
+ [ DllImport ( "libTorchSharp" ) ]
13
+ extern static IntPtr THSNN_lossBCE ( IntPtr srct , IntPtr trgt , IntPtr wgt , long reduction ) ;
14
+
15
+ public static ITorchTensor < float > BCE < T , U > ( ITorchTensor < T > src , ITorchTensor < U > target , ITorchTensor < U > weigths = null , Reduction reduction = Reduction . Mean )
16
+ {
17
+ return new FloatTensor ( THSNN_lossBCE ( src . Handle , target . Handle , weigths ? . Handle ?? IntPtr . Zero , ( long ) reduction ) ) ;
18
+ }
19
+
12
20
[ DllImport ( "libTorchSharp" ) ]
13
21
extern static IntPtr THSNN_lossMSE ( IntPtr srct , IntPtr trgt , long reduction ) ;
14
22
@@ -18,11 +26,11 @@ public static ITorchTensor<float> MSE<T>(ITorchTensor<T> src, ITorchTensor<T> ta
18
26
}
19
27
20
28
[ DllImport ( "libTorchSharp" ) ]
21
- extern static IntPtr THSNN_lossNLL ( IntPtr srct , IntPtr trgt , long reduction ) ;
29
+ extern static IntPtr THSNN_lossNLL ( IntPtr srct , IntPtr trgt , IntPtr wgt , long reduction ) ;
22
30
23
- public static ITorchTensor < float > NLL < T , U > ( ITorchTensor < T > src , ITorchTensor < U > target , Reduction reduction = Reduction . Mean )
31
+ public static ITorchTensor < float > NLL < T , U > ( ITorchTensor < T > src , ITorchTensor < U > target , ITorchTensor < U > weigths = null , Reduction reduction = Reduction . Mean )
24
32
{
25
- return new FloatTensor ( THSNN_lossNLL ( src . Handle , target . Handle , ( long ) reduction ) ) ;
33
+ return new FloatTensor ( THSNN_lossNLL ( src . Handle , target . Handle , weigths ? . Handle ?? IntPtr . Zero , ( long ) reduction ) ) ;
26
34
}
27
35
}
28
36
0 commit comments