@@ -12,17 +12,24 @@ namespace Modules
12
12
/// <summary>
13
13
/// Computes the pairwise distance between vectors using the p-norm.
14
14
/// </summary>
15
- public sealed class PairwiseDistance : torch . nn . Module < Tensor , Tensor , Tensor >
15
+ public sealed class PairwiseDistance : ParamLessModule < Tensor , Tensor , Tensor >
16
16
{
17
- internal PairwiseDistance ( IntPtr handle , IntPtr boxedHandle ) : base ( handle , boxedHandle )
17
+ public double norm { get ; set ; }
18
+ public double eps { get ; set ; }
19
+ public bool keepdim { get ; set ; }
20
+
21
+ internal PairwiseDistance (
22
+ double p = 2.0 , double eps = 1e-6 , bool keepdim = false )
23
+ : base ( nameof ( PairwiseDistance ) )
18
24
{
25
+ this . norm = p ;
26
+ this . eps = eps ;
27
+ this . keepdim = keepdim ;
19
28
}
20
29
21
30
public override Tensor forward ( Tensor input1 , Tensor input2 )
22
31
{
23
- var res = THSNN_PairwiseDistance_forward ( handle , input1 . Handle , input2 . Handle ) ;
24
- if ( res == IntPtr . Zero ) { torch . CheckForErrors ( ) ; }
25
- return new Tensor ( res ) ;
32
+ return nn . functional . pairwise_distance ( input1 , input2 , norm , eps , keepdim ) ;
26
33
}
27
34
28
35
// Rather than spending cycles only to discover that this module has neither
@@ -37,11 +44,9 @@ public static partial class torch
37
44
{
38
45
public static partial class nn
39
46
{
40
- public static PairwiseDistance PairwiseDistance ( double p = 2.0 , double eps = 1e-6 , bool keep_dim = false )
47
+ public static PairwiseDistance PairwiseDistance ( double p = 2.0 , double eps = 1e-6 , bool keepdim = false )
41
48
{
42
- var handle = THSNN_PairwiseDistance_ctor ( p , eps , keep_dim , out var boxedHandle ) ;
43
- if ( handle == IntPtr . Zero ) { torch . CheckForErrors ( ) ; }
44
- return new PairwiseDistance ( handle , boxedHandle ) ;
49
+ return new PairwiseDistance ( p , eps , keepdim ) ;
45
50
}
46
51
47
52
public static partial class functional
@@ -53,13 +58,13 @@ public static partial class functional
53
58
/// <param name="input2">(N, D) or (D), same shape as the Input1</param>
54
59
/// <param name="p">The norm degree. Default: 2</param>
55
60
/// <param name="eps">Small value to avoid division by zero.</param>
56
- /// <param name="keep_dim ">Determines whether or not to keep the vector dimension.</param>
61
+ /// <param name="keepdim ">Determines whether or not to keep the vector dimension.</param>
57
62
/// <returns></returns>
58
- public static Tensor pairwise_distance ( Tensor input1 , Tensor input2 , double p = 2.0 , double eps = 1e-6 , bool keep_dim = false )
63
+ public static Tensor pairwise_distance ( Tensor input1 , Tensor input2 , double p = 2.0 , double eps = 1e-6 , bool keepdim = false )
59
64
{
60
- using ( var f = nn . PairwiseDistance ( p , eps , keep_dim ) ) {
61
- return f . call ( input1 , input2 ) ;
62
- }
65
+ var res = THSNN_pairwise_distance ( input1 . Handle , input2 . Handle , p , eps , keepdim ) ;
66
+ if ( res == IntPtr . Zero ) { torch . CheckForErrors ( ) ; }
67
+ return new Tensor ( res ) ;
63
68
}
64
69
}
65
70
}
0 commit comments