Skip to content

Commit 4b4064f

Browse files
Merge branch 'munit' of https://github.com/yueyinqiu2/TorchSharp into yueyinqiu2-munit
2 parents b0819d6 + 95d4347 commit 4b4064f

File tree

5 files changed

+24
-36
lines changed

5 files changed

+24
-36
lines changed

src/Native/LibTorchSharp/THSNN.cpp

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -538,22 +538,9 @@ Tensor THSNN_cosine_similarity(const Tensor input1, const Tensor input2, int64_t
538538
CATCH_TENSOR(torch::nn::functional::cosine_similarity(*input1, *input2, torch::nn::functional::CosineSimilarityFuncOptions().dim(dim).eps(eps)));
539539
}
540540

541-
NNModule THSNN_PairwiseDistance_ctor(double p, double eps, bool keep_dim, NNAnyModule* outAsAnyModule)
541+
Tensor THSNN_pairwise_distance(const Tensor input1, const Tensor input2, double p, double eps, bool keepdim)
542542
{
543-
CATCH_RETURN_NNModule(
544-
auto opts = torch::nn::PairwiseDistanceOptions()
545-
.p(p)
546-
.eps(eps)
547-
.keepdim(keep_dim);
548-
549-
res = create_module<torch::nn::PairwiseDistanceImpl>(opts, outAsAnyModule);
550-
);
551-
552-
}
553-
554-
Tensor THSNN_PairwiseDistance_forward(const NNModule module, const Tensor input1, const Tensor input2)
555-
{
556-
CATCH_TENSOR((*module)->as<torch::nn::PairwiseDistance>()->forward(*input1, *input2));
543+
CATCH_TENSOR(torch::nn::functional::pairwise_distance(*input1, *input2, torch::nn::functional::PairwiseDistanceFuncOptions().p(p).eps(eps).keepdim(keepdim)));
557544
}
558545

559546
NNModule THSNN_RNN_ctor(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const int64_t nonlinearity, const bool bias, const bool batchFirst, const double dropout, const bool bidirectional, NNAnyModule* outAsAnyModule)

src/Native/LibTorchSharp/THSNN.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,7 @@ EXPORT_API(Tensor) THSNN_one_hot(const Tensor self, const int64_t num_classes);
269269

270270
EXPORT_API(Tensor) THSNN_cosine_similarity(const Tensor input1, const Tensor input2, int64_t dim, double eps);
271271

272-
EXPORT_API(NNModule) THSNN_PairwiseDistance_ctor(double p, double eps, bool keep_dim, NNAnyModule* outAsAnyModule);
273-
EXPORT_API(Tensor) THSNN_PairwiseDistance_forward(const NNModule module, const Tensor input1, const Tensor input2);
272+
EXPORT_API(Tensor) THSNN_pairwise_distance(const Tensor input1, const Tensor input2, double p, double eps, bool keepdim);
274273

275274
EXPORT_API(Tensor) THSNN_scaled_dot_product_attention(const Tensor query, const Tensor key, const Tensor value, const Tensor attention_mask, double p, bool casual);
276275

src/TorchSharp/NN/PairwiseDistance.cs

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,24 @@ namespace Modules
1212
/// <summary>
1313
/// Computes the pairwise distance between vectors using the p-norm.
1414
/// </summary>
15-
public sealed class PairwiseDistance : torch.nn.Module<Tensor, Tensor, Tensor>
15+
public sealed class PairwiseDistance : ParamLessModule<Tensor, Tensor, Tensor>
1616
{
17-
internal PairwiseDistance(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle)
17+
public double norm { get; set; }
18+
public double eps { get; set; }
19+
public bool keepdim { get; set; }
20+
21+
internal PairwiseDistance(
22+
double p = 2.0, double eps = 1e-6, bool keepdim = false)
23+
: base(nameof(PairwiseDistance))
1824
{
25+
this.norm = p;
26+
this.eps = eps;
27+
this.keepdim = keepdim;
1928
}
2029

2130
public override Tensor forward(Tensor input1, Tensor input2)
2231
{
23-
var res = THSNN_PairwiseDistance_forward(handle, input1.Handle, input2.Handle);
24-
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
25-
return new Tensor(res);
32+
return nn.functional.pairwise_distance(input1, input2, norm, eps, keepdim);
2633
}
2734

2835
// Rather than spending cycles only to discover that this module has neither
@@ -37,11 +44,9 @@ public static partial class torch
3744
{
3845
public static partial class nn
3946
{
40-
public static PairwiseDistance PairwiseDistance(double p = 2.0, double eps = 1e-6, bool keep_dim = false)
47+
public static PairwiseDistance PairwiseDistance(double p = 2.0, double eps = 1e-6, bool keepdim = false)
4148
{
42-
var handle = THSNN_PairwiseDistance_ctor(p, eps, keep_dim, out var boxedHandle);
43-
if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
44-
return new PairwiseDistance(handle, boxedHandle);
49+
return new PairwiseDistance(p, eps, keepdim);
4550
}
4651

4752
public static partial class functional
@@ -53,13 +58,13 @@ public static partial class functional
5358
/// <param name="input2">(N, D) or (D), same shape as the Input1</param>
5459
/// <param name="p">The norm degree. Default: 2</param>
5560
/// <param name="eps">Small value to avoid division by zero.</param>
56-
/// <param name="keep_dim">Determines whether or not to keep the vector dimension.</param>
61+
/// <param name="keepdim">Determines whether or not to keep the vector dimension.</param>
5762
/// <returns></returns>
58-
public static Tensor pairwise_distance(Tensor input1, Tensor input2, double p = 2.0, double eps = 1e-6, bool keep_dim = false)
63+
public static Tensor pairwise_distance(Tensor input1, Tensor input2, double p = 2.0, double eps = 1e-6, bool keepdim = false)
5964
{
60-
using (var f = nn.PairwiseDistance(p, eps, keep_dim)) {
61-
return f.call(input1, input2);
62-
}
65+
var res = THSNN_pairwise_distance(input1.Handle, input2.Handle, p, eps, keepdim);
66+
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
67+
return new Tensor(res);
6368
}
6469
}
6570
}

src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -402,10 +402,7 @@ internal static extern IntPtr THSNN_custom_module(
402402
internal static extern IntPtr THSNN_one_hot(IntPtr self, long num_classes);
403403

404404
[DllImport("LibTorchSharp")]
405-
internal static extern IntPtr THSNN_PairwiseDistance_forward(torch.nn.Module.HType module, IntPtr input1, IntPtr input2);
406-
407-
[DllImport("LibTorchSharp")]
408-
internal static extern IntPtr THSNN_PairwiseDistance_ctor(double p, double eps, [MarshalAs(UnmanagedType.U1)] bool keep_dim, out IntPtr pBoxedModule);
405+
internal static extern IntPtr THSNN_pairwise_distance(IntPtr input1, IntPtr input2, double p, double eps, [MarshalAs(UnmanagedType.U1)] bool keepdim);
409406

410407
[DllImport("LibTorchSharp")]
411408
internal static extern IntPtr THSNN_pixel_unshuffle(IntPtr tensor, long downscale_factor);

test/TorchSharpTest/NN.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5620,7 +5620,7 @@ public void TestPairwiseDistance()
56205620
using (Tensor input1 = torch.rand(new long[] { 5, 12 }, device: device))
56215621
using (Tensor input2 = torch.randint(12, new long[] { 5, 12 }, torch.int64, device: device))
56225622

5623-
using (var module = PairwiseDistance(keep_dim: true)) {
5623+
using (var module = PairwiseDistance(keepdim: true)) {
56245624
var output = module.call(input1, input2);
56255625
Assert.Equal(device.type, output.device_type);
56265626
Assert.Equal(input1.shape[0], output.shape[0]);

0 commit comments

Comments
 (0)