Skip to content

Commit b66289b

Browse files
committed
pairwise distance
1 parent 254470e commit b66289b

File tree

5 files changed

+25
-42
lines changed

5 files changed

+25
-42
lines changed

src/Native/LibTorchSharp/THSNN.cpp

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -538,22 +538,9 @@ Tensor THSNN_cosine_similarity(const Tensor input1, const Tensor input2, int64_t
538538
CATCH_TENSOR(torch::nn::functional::cosine_similarity(*input1, *input2, torch::nn::functional::CosineSimilarityFuncOptions().dim(dim).eps(eps)));
539539
}
540540

541-
NNModule THSNN_PairwiseDistance_ctor(double p, double eps, bool keep_dim, NNAnyModule* outAsAnyModule)
541+
Tensor THSNN_pairwise_distance(const Tensor input1, const Tensor input2, double p, double eps, bool keepdim)
542542
{
543-
CATCH_RETURN_NNModule(
544-
auto opts = torch::nn::PairwiseDistanceOptions()
545-
.p(p)
546-
.eps(eps)
547-
.keepdim(keep_dim);
548-
549-
res = create_module<torch::nn::PairwiseDistanceImpl>(opts, outAsAnyModule);
550-
);
551-
552-
}
553-
554-
Tensor THSNN_PairwiseDistance_forward(const NNModule module, const Tensor input1, const Tensor input2)
555-
{
556-
CATCH_TENSOR((*module)->as<torch::nn::PairwiseDistance>()->forward(*input1, *input2));
543+
CATCH_TENSOR(torch::nn::functional::pairwise_distance(*input1, *input2, torch::nn::functional::PairwiseDistanceFuncOptions().p(p).eps(eps).keepdim(keepdim)));
557544
}
558545

559546
NNModule THSNN_RNN_ctor(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const int64_t nonlinearity, const bool bias, const bool batchFirst, const double dropout, const bool bidirectional, NNAnyModule* outAsAnyModule)

src/Native/LibTorchSharp/THSNN.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,8 +267,7 @@ EXPORT_API(Tensor) THSNN_one_hot(const Tensor self, const int64_t num_classes);
267267

268268
EXPORT_API(Tensor) THSNN_cosine_similarity(const Tensor input1, const Tensor input2, int64_t dim, double eps);
269269

270-
EXPORT_API(NNModule) THSNN_PairwiseDistance_ctor(double p, double eps, bool keep_dim, NNAnyModule* outAsAnyModule);
271-
EXPORT_API(Tensor) THSNN_PairwiseDistance_forward(const NNModule module, const Tensor input1, const Tensor input2);
270+
EXPORT_API(Tensor) THSNN_pairwise_distance(const Tensor input1, const Tensor input2, double p, double eps, bool keepdim);
272271

273272
EXPORT_API(Tensor) THSNN_scaled_dot_product_attention(const Tensor query, const Tensor key, const Tensor value, const Tensor attention_mask, double p, bool casual);
274273

src/TorchSharp/NN/PairwiseDistance.cs

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,36 +12,36 @@ namespace Modules
1212
/// <summary>
1313
/// Computes the pairwise distance between vectors using the p-norm.
1414
/// </summary>
15-
public sealed class PairwiseDistance : torch.nn.Module<Tensor, Tensor, Tensor>
15+
public sealed class PairwiseDistance : ParamLessModule<Tensor, Tensor, Tensor>
1616
{
17-
internal PairwiseDistance(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle)
17+
public double norm { get; set; }
18+
public double eps { get; set; }
19+
public bool keepdim { get; set; }
20+
21+
internal PairwiseDistance(
22+
double p = 2.0, double eps = 1e-6, bool keepdim = false)
23+
: base(nameof(PairwiseDistance))
1824
{
25+
this.norm = p;
26+
this.eps = eps;
27+
this.keepdim = keepdim;
1928
}
2029

2130
public override Tensor forward(Tensor input1, Tensor input2)
2231
{
23-
var res = THSNN_PairwiseDistance_forward(handle, input1.Handle, input2.Handle);
24-
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
25-
return new Tensor(res);
32+
return nn.functional.pairwise_distance(
33+
input1, input2, norm, eps, keepdim);
2634
}
27-
28-
// Rather than spending cycles only to discover that this module has neither
29-
// parameters nor buffers, just shortcut the move completely.
30-
protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
31-
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
32-
protected internal override nn.Module _to(ScalarType dtype) => this;
3335
}
3436
}
3537

3638
public static partial class torch
3739
{
3840
public static partial class nn
3941
{
40-
public static PairwiseDistance PairwiseDistance(double p = 2.0, double eps = 1e-6, bool keep_dim = false)
42+
public static PairwiseDistance PairwiseDistance(double p = 2.0, double eps = 1e-6, bool keepdim = false)
4143
{
42-
var handle = THSNN_PairwiseDistance_ctor(p, eps, keep_dim, out var boxedHandle);
43-
if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
44-
return new PairwiseDistance(handle, boxedHandle);
44+
return new PairwiseDistance(p, eps, keepdim);
4545
}
4646

4747
public static partial class functional
@@ -53,13 +53,13 @@ public static partial class functional
5353
/// <param name="input2">(N, D) or (D), same shape as the Input1</param>
5454
/// <param name="p">The norm degree. Default: 2</param>
5555
/// <param name="eps">Small value to avoid division by zero.</param>
56-
/// <param name="keep_dim">Determines whether or not to keep the vector dimension.</param>
56+
/// <param name="keepdim">Determines whether or not to keep the vector dimension.</param>
5757
/// <returns></returns>
58-
public static Tensor pairwise_distance(Tensor input1, Tensor input2, double p = 2.0, double eps = 1e-6, bool keep_dim = false)
58+
public static Tensor pairwise_distance(Tensor input1, Tensor input2, double p = 2.0, double eps = 1e-6, bool keepdim = false)
5959
{
60-
using (var f = nn.PairwiseDistance(p, eps, keep_dim)) {
61-
return f.call(input1, input2);
62-
}
60+
var res = THSNN_pairwise_distance(input1.Handle, input2.Handle, p, eps, keepdim);
61+
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
62+
return new Tensor(res);
6363
}
6464
}
6565
}

src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -402,10 +402,7 @@ internal static extern IntPtr THSNN_custom_module(
402402
internal static extern IntPtr THSNN_one_hot(IntPtr self, long num_classes);
403403

404404
[DllImport("LibTorchSharp")]
405-
internal static extern IntPtr THSNN_PairwiseDistance_forward(torch.nn.Module.HType module, IntPtr input1, IntPtr input2);
406-
407-
[DllImport("LibTorchSharp")]
408-
internal static extern IntPtr THSNN_PairwiseDistance_ctor(double p, double eps, [MarshalAs(UnmanagedType.U1)] bool keep_dim, out IntPtr pBoxedModule);
405+
internal static extern IntPtr THSNN_pairwise_distance(IntPtr input1, IntPtr input2, double p, double eps, [MarshalAs(UnmanagedType.U1)] bool keepdim);
409406

410407
[DllImport("LibTorchSharp")]
411408
internal static extern IntPtr THSNN_pixel_unshuffle(IntPtr tensor, long downscale_factor);

test/TorchSharpTest/NN.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5567,7 +5567,7 @@ public void TestPairwiseDistance()
55675567
using (Tensor input1 = torch.rand(new long[] { 5, 12 }, device: device))
55685568
using (Tensor input2 = torch.randint(12, new long[] { 5, 12 }, torch.int64, device: device))
55695569

5570-
using (var module = PairwiseDistance(keep_dim: true)) {
5570+
using (var module = PairwiseDistance(keepdim: true)) {
55715571
var output = module.call(input1, input2);
55725572
Assert.Equal(device.type, output.device_type);
55735573
Assert.Equal(input1.shape[0], output.shape[0]);

0 commit comments

Comments
 (0)