@@ -9,14 +9,21 @@ namespace TorchSharp.NN
9
9
/// </summary>
10
10
public class LossFunction
11
11
{
12
+ public delegate TorchTensor Loss ( TorchTensor source , TorchTensor target ) ;
13
+
12
14
[ DllImport ( "libTorchSharp" ) ]
13
15
extern static IntPtr THSNN_lossBCE ( IntPtr srct , IntPtr trgt , IntPtr wgt , long reduction ) ;
14
16
15
- public static TorchTensor BCE < T , U > ( TorchTensor src , TorchTensor target , TorchTensor ? weigths = null , Reduction reduction = Reduction . Mean )
17
+ public static TorchTensor BCE ( TorchTensor src , TorchTensor target , TorchTensor ? weigths = null , Reduction reduction = Reduction . Mean )
16
18
{
17
19
return new TorchTensor ( THSNN_lossBCE ( src . Handle , target . Handle , weigths ? . Handle ?? IntPtr . Zero , ( long ) reduction ) ) ;
18
20
}
19
21
22
+ public static Loss BCE ( TorchTensor ? weigths = null , Reduction reduction = Reduction . Mean )
23
+ {
24
+ return ( TorchTensor src , TorchTensor target ) => new TorchTensor ( THSNN_lossBCE ( src . Handle , target . Handle , weigths ? . Handle ?? IntPtr . Zero , ( long ) reduction ) ) ;
25
+ }
26
+
20
27
[ DllImport ( "libTorchSharp" ) ]
21
28
extern static IntPtr THSNN_lossMSE ( IntPtr srct , IntPtr trgt , long reduction ) ;
22
29
@@ -25,6 +32,11 @@ public static TorchTensor MSE(TorchTensor src, TorchTensor target, Reduction red
25
32
return new TorchTensor ( THSNN_lossMSE ( src . Handle , target . Handle , ( long ) reduction ) ) ;
26
33
}
27
34
35
+ public static Loss MSE ( Reduction reduction = Reduction . Mean )
36
+ {
37
+ return ( TorchTensor src , TorchTensor target ) => new TorchTensor ( THSNN_lossMSE ( src . Handle , target . Handle , ( long ) reduction ) ) ;
38
+ }
39
+
28
40
[ DllImport ( "libTorchSharp" ) ]
29
41
extern static IntPtr THSNN_lossNLL ( IntPtr srct , IntPtr trgt , IntPtr wgt , long reduction ) ;
30
42
@@ -33,13 +45,23 @@ public static TorchTensor NLL(TorchTensor src, TorchTensor target, TorchTensor?
33
45
return new TorchTensor ( THSNN_lossNLL ( src . Handle , target . Handle , weigths ? . Handle ?? IntPtr . Zero , ( long ) reduction ) ) ;
34
46
}
35
47
48
+ public static Loss NLL ( TorchTensor ? weigths = null , Reduction reduction = Reduction . Mean )
49
+ {
50
+ return ( TorchTensor src , TorchTensor target ) => new TorchTensor ( THSNN_lossNLL ( src . Handle , target . Handle , weigths ? . Handle ?? IntPtr . Zero , ( long ) reduction ) ) ;
51
+ }
52
+
36
53
[ DllImport ( "libTorchSharp" ) ]
37
54
extern static IntPtr THSNN_lossPoissonNLL ( IntPtr srct , IntPtr trgt , bool logInput , bool full , float eps , long reduction ) ;
38
55
39
56
public static TorchTensor PoissonNLL ( TorchTensor src , TorchTensor target , bool logInput = true , bool full = false , float eps = 1e-8f , Reduction reduction = Reduction . Mean )
40
57
{
41
58
return new TorchTensor ( THSNN_lossPoissonNLL ( src . Handle , target . Handle , logInput , full , eps , ( long ) reduction ) ) ;
42
59
}
60
+
61
+ public static Loss PoissonNLL ( bool logInput = true , bool full = false , float eps = 1e-8f , Reduction reduction = Reduction . Mean )
62
+ {
63
+ return ( TorchTensor src , TorchTensor target ) => new TorchTensor ( THSNN_lossPoissonNLL ( src . Handle , target . Handle , logInput , full , eps , ( long ) reduction ) ) ;
64
+ }
43
65
}
44
66
45
67
public enum Reduction : long
0 commit comments