@@ -12,36 +12,36 @@ namespace Modules
12
12
/// <summary>
13
13
/// Computes the pairwise distance between vectors using the p-norm.
14
14
/// </summary>
15
- public sealed class PairwiseDistance : torch . nn . Module < Tensor , Tensor , Tensor >
15
+ public sealed class PairwiseDistance : ParamLessModule < Tensor , Tensor , Tensor >
16
16
{
17
- internal PairwiseDistance ( IntPtr handle , IntPtr boxedHandle ) : base ( handle , boxedHandle )
17
+ public double norm { get ; set ; }
18
+ public double eps { get ; set ; }
19
+ public bool keepdim { get ; set ; }
20
+
21
+ internal PairwiseDistance (
22
+ double p = 2.0 , double eps = 1e-6 , bool keepdim = false )
23
+ : base ( nameof ( PairwiseDistance ) )
18
24
{
25
+ this . norm = p ;
26
+ this . eps = eps ;
27
+ this . keepdim = keepdim ;
19
28
}
20
29
21
30
public override Tensor forward ( Tensor input1 , Tensor input2 )
22
31
{
23
- var res = THSNN_PairwiseDistance_forward ( handle , input1 . Handle , input2 . Handle ) ;
24
- if ( res == IntPtr . Zero ) { torch . CheckForErrors ( ) ; }
25
- return new Tensor ( res ) ;
32
+ return nn . functional . pairwise_distance (
33
+ input1 , input2 , norm , eps , keepdim ) ;
26
34
}
27
-
28
- // Rather than spending cycles only to discover that this module has neither
29
- // parameters nor buffers, just shortcut the move completely.
30
- protected internal override nn . Module _to ( Device device , ScalarType dtype ) => this ;
31
- protected internal override nn . Module _to ( DeviceType deviceType , int deviceIndex = - 1 ) => this ;
32
- protected internal override nn . Module _to ( ScalarType dtype ) => this ;
33
35
}
34
36
}
35
37
36
38
public static partial class torch
37
39
{
38
40
public static partial class nn
39
41
{
40
- public static PairwiseDistance PairwiseDistance ( double p = 2.0 , double eps = 1e-6 , bool keep_dim = false )
42
+ public static PairwiseDistance PairwiseDistance ( double p = 2.0 , double eps = 1e-6 , bool keepdim = false )
41
43
{
42
- var handle = THSNN_PairwiseDistance_ctor ( p , eps , keep_dim , out var boxedHandle ) ;
43
- if ( handle == IntPtr . Zero ) { torch . CheckForErrors ( ) ; }
44
- return new PairwiseDistance ( handle , boxedHandle ) ;
44
+ return new PairwiseDistance ( p , eps , keepdim ) ;
45
45
}
46
46
47
47
public static partial class functional
@@ -53,13 +53,13 @@ public static partial class functional
53
53
/// <param name="input2">(N, D) or (D), same shape as the Input1</param>
54
54
/// <param name="p">The norm degree. Default: 2</param>
55
55
/// <param name="eps">Small value to avoid division by zero.</param>
56
- /// <param name="keep_dim ">Determines whether or not to keep the vector dimension.</param>
56
+ /// <param name="keepdim ">Determines whether or not to keep the vector dimension.</param>
57
57
/// <returns></returns>
58
- public static Tensor pairwise_distance ( Tensor input1 , Tensor input2 , double p = 2.0 , double eps = 1e-6 , bool keep_dim = false )
58
+ public static Tensor pairwise_distance ( Tensor input1 , Tensor input2 , double p = 2.0 , double eps = 1e-6 , bool keepdim = false )
59
59
{
60
- using ( var f = nn . PairwiseDistance ( p , eps , keep_dim ) ) {
61
- return f . call ( input1 , input2 ) ;
62
- }
60
+ var res = THSNN_pairwise_distance ( input1 . Handle , input2 . Handle , p , eps , keepdim ) ;
61
+ if ( res == IntPtr . Zero ) { torch . CheckForErrors ( ) ; }
62
+ return new Tensor ( res ) ;
63
63
}
64
64
}
65
65
}
0 commit comments