Skip to content

Commit b2bb7e8

Browse files
Merge pull request #1405 from NiklasGustafsson/missing
Adding `torch.nn.functional.normalize`
2 parents 6c93cc6 + 9b782a3 commit b2bb7e8

File tree

7 files changed

+85
-21
lines changed

7 files changed

+85
-21
lines changed

RELEASENOTES.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,12 @@ The argument defaults for `torch.diagonal()` and `Tensor.diagonal()` arguments h
1010

1111
__Bug Fixes__:
1212

13-
#1400 There may be an error in torchvision.transforms.GaussianBlur
14-
#1402 diagonal() has incorrect default
13+
#1400 There may be an error in torchvision.transforms.GaussianBlur<br/>
14+
#1402 diagonal() has incorrect default<br/>
15+
16+
__API Changes__:
17+
18+
#1382: Add support for torch.nn.functional.normalize<br/>
1519

1620
# NuGet Version 0.103.1
1721

TorchSharp.sln

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "TorchSharp", "TorchSharp",
3434
pkg\TorchSharp\TorchSharp.symbols.nupkgproj = pkg\TorchSharp\TorchSharp.symbols.nupkgproj
3535
EndProjectSection
3636
EndProject
37-
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LibTorchSharp", "bin\obj\x64.Debug\Native\LibTorchSharp\LibTorchSharp.vcxproj", "{2B359162-062E-3C52-91D3-027A8542A58C}"
37+
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LibTorchSharp", "bin\obj\x64.Debug\Native\LibTorchSharp\LibTorchSharp.vcxproj", "{CAD9DB7F-3223-3324-884D-FA2381593DA7}"
3838
EndProject
39-
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LibTorchSharp", "bin\obj\x64.Release\Native\LibTorchSharp\LibTorchSharp.vcxproj", "{E4C0DBEE-0815-311B-9065-137BB50BD793}"
39+
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LibTorchSharp", "bin\obj\x64.Release\Native\LibTorchSharp\LibTorchSharp.vcxproj", "{BB811429-0DF1-3D22-B664-09C2F5A9E0AB}"
4040
EndProject
4141
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Native-Debug", "Native-Debug", "{CF2C1A9E-3A8A-4329-8A6E-7880C15AAC3D}"
4242
ProjectSection(SolutionItems) = preProject
@@ -66,9 +66,9 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution
6666
azure-pipelines.yml = azure-pipelines.yml
6767
build\BranchInfo.props = build\BranchInfo.props
6868
DEVGUIDE.md = DEVGUIDE.md
69+
global.json = global.json
6970
README.md = README.md
7071
RELEASENOTES.md = RELEASENOTES.md
71-
global.json = global.json
7272
EndProjectSection
7373
EndProject
7474
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TorchVision", "src\TorchVision\TorchVision.csproj", "{DCF01EE5-6431-4115-85E0-1FC4C3DE86A2}"
@@ -107,14 +107,14 @@ Global
107107
{42B45168-476D-4BFA-87B8-81A34E6295CD}.Release|Any CPU.Build.0 = Release|Any CPU
108108
{42B45168-476D-4BFA-87B8-81A34E6295CD}.Release|x64.ActiveCfg = Release|Any CPU
109109
{42B45168-476D-4BFA-87B8-81A34E6295CD}.Release|x64.Build.0 = Release|Any CPU
110-
{2B359162-062E-3C52-91D3-027A8542A58C}.Debug|Any CPU.ActiveCfg = Debug|x64
111-
{2B359162-062E-3C52-91D3-027A8542A58C}.Debug|x64.ActiveCfg = Debug|x64
112-
{2B359162-062E-3C52-91D3-027A8542A58C}.Release|Any CPU.ActiveCfg = Release|x64
113-
{2B359162-062E-3C52-91D3-027A8542A58C}.Release|x64.ActiveCfg = Release|x64
114-
{E4C0DBEE-0815-311B-9065-137BB50BD793}.Debug|Any CPU.ActiveCfg = Debug|x64
115-
{E4C0DBEE-0815-311B-9065-137BB50BD793}.Debug|x64.ActiveCfg = Debug|x64
116-
{E4C0DBEE-0815-311B-9065-137BB50BD793}.Release|Any CPU.ActiveCfg = Release|x64
117-
{E4C0DBEE-0815-311B-9065-137BB50BD793}.Release|x64.ActiveCfg = Release|x64
110+
{CAD9DB7F-3223-3324-884D-FA2381593DA7}.Debug|Any CPU.ActiveCfg = Debug|x64
111+
{CAD9DB7F-3223-3324-884D-FA2381593DA7}.Debug|x64.ActiveCfg = Debug|x64
112+
{CAD9DB7F-3223-3324-884D-FA2381593DA7}.Release|Any CPU.ActiveCfg = Release|x64
113+
{CAD9DB7F-3223-3324-884D-FA2381593DA7}.Release|x64.ActiveCfg = Release|x64
114+
{BB811429-0DF1-3D22-B664-09C2F5A9E0AB}.Debug|Any CPU.ActiveCfg = Debug|x64
115+
{BB811429-0DF1-3D22-B664-09C2F5A9E0AB}.Debug|x64.ActiveCfg = Debug|x64
116+
{BB811429-0DF1-3D22-B664-09C2F5A9E0AB}.Release|Any CPU.ActiveCfg = Release|x64
117+
{BB811429-0DF1-3D22-B664-09C2F5A9E0AB}.Release|x64.ActiveCfg = Release|x64
118118
{DD652544-711E-4029-83FF-DA4A9600E6E7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
119119
{DD652544-711E-4029-83FF-DA4A9600E6E7}.Debug|Any CPU.Build.0 = Debug|Any CPU
120120
{DD652544-711E-4029-83FF-DA4A9600E6E7}.Debug|x64.ActiveCfg = Debug|Any CPU
@@ -148,7 +148,6 @@ Global
148148
{95493944-D1AE-414E-964B-B58AEAE672E5}.Release|x64.ActiveCfg = Release|Any CPU
149149
{95493944-D1AE-414E-964B-B58AEAE672E5}.Release|x64.Build.0 = Release|Any CPU
150150
{6D3CE8AA-F369-4D2D-BDA7-9F89D6BE1B2E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
151-
{6D3CE8AA-F369-4D2D-BDA7-9F89D6BE1B2E}.Debug|Any CPU.Build.0 = Debug|Any CPU
152151
{6D3CE8AA-F369-4D2D-BDA7-9F89D6BE1B2E}.Debug|x64.ActiveCfg = Debug|Any CPU
153152
{6D3CE8AA-F369-4D2D-BDA7-9F89D6BE1B2E}.Debug|x64.Build.0 = Debug|Any CPU
154153
{6D3CE8AA-F369-4D2D-BDA7-9F89D6BE1B2E}.Release|Any CPU.ActiveCfg = Release|Any CPU
@@ -181,8 +180,8 @@ Global
181180
{6C323B05-9028-4B09-911C-3C03AE058BEE} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
182181
{42B45168-476D-4BFA-87B8-81A34E6295CD} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
183182
{567456AD-B026-4CB6-B98D-4FC930C90223} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
184-
{2B359162-062E-3C52-91D3-027A8542A58C} = {CF2C1A9E-3A8A-4329-8A6E-7880C15AAC3D}
185-
{E4C0DBEE-0815-311B-9065-137BB50BD793} = {4DB9E84D-324C-408F-87A6-246E86205540}
183+
{CAD9DB7F-3223-3324-884D-FA2381593DA7} = {CF2C1A9E-3A8A-4329-8A6E-7880C15AAC3D}
184+
{BB811429-0DF1-3D22-B664-09C2F5A9E0AB} = {4DB9E84D-324C-408F-87A6-246E86205540}
186185
{CF2C1A9E-3A8A-4329-8A6E-7880C15AAC3D} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
187186
{D8C60CD8-8429-45F2-A755-47B6CD10FDF8} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
188187
{4DB9E84D-324C-408F-87A6-246E86205540} = {CF2C1A9E-3A8A-4329-8A6E-7880C15AAC3D}

src/Native/LibTorchSharp/THSNN.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ EXPORT_API(void) THSNN_GroupNorm_set_weight(const NNModule module, const Ten
289289
EXPORT_API(NNModule) THSNN_LocalResponseNorm_ctor(const int64_t size, const double alpha, const double beta, const double k, NNAnyModule* outAsAnyModule);
290290
EXPORT_API(Tensor) THSNN_LocalResponseNorm_forward(const NNModule module, const Tensor tensor);
291291

292+
EXPORT_API(Tensor) THSNN_normalize(const Tensor input, const double p, const int64_t dim, const double eps);
292293
EXPORT_API(Tensor) THSNN_batch_norm(const Tensor input, const Tensor running_mean, const Tensor running_var, const Tensor weight, const Tensor bias, const bool training, const double momentum, const double eps);
293294
EXPORT_API(Tensor) THSNN_group_norm(const Tensor input, int64_t num_groups, const Tensor weight, const Tensor bias, const double eps);
294295
EXPORT_API(Tensor) THSNN_instance_norm(const Tensor input, const Tensor running_mean, const Tensor running_var, const Tensor weight, const Tensor bias, const bool use_input_stats, const double momentum, const double eps);

src/Native/LibTorchSharp/THSNormalization.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,15 @@ Tensor THSNN_batch_norm(const Tensor input, Tensor running_mean, const Tensor ru
607607
CATCH_TENSOR(torch::nn::functional::batch_norm(*input, *running_mean, *running_var, opts));
608608
}
609609

610+
Tensor THSNN_normalize(const Tensor input, const double p, const int64_t dim, const double eps)
611+
{
612+
auto opts = torch::nn::functional::NormalizeFuncOptions()
613+
.p(p)
614+
.dim(dim)
615+
.eps(eps);
616+
CATCH_TENSOR(torch::nn::functional::normalize(*input, opts));
617+
}
618+
610619
Tensor THSNN_group_norm(const Tensor input, const int64_t num_groups, const Tensor weight, const Tensor bias, const double eps)
611620
{
612621
auto opts = torch::nn::functional::GroupNormFuncOptions(num_groups)

src/TorchSharp/NN/Normalization/Functional.cs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System;
33
using static TorchSharp.PInvoke.NativeMethods;
44

5+
#nullable enable
56
namespace TorchSharp
67
{
78
public static partial class torch
@@ -10,10 +11,27 @@ public static partial class nn
1011
{
1112
public static partial class functional
1213
{
14+
/// <summary>
15+
/// Perform normalization of inputs over specified dimension.
16+
/// </summary>
17+
/// <param name="input">Input tensor of any shape.</param>
18+
/// <param name="p">the exponent value in the norm formulation</param>
19+
/// <param name="dim">the dimension to reduce</param>
20+
/// <param name="eps">small value to avoid division by zero</param>
21+
public static Tensor normalize(Tensor input, double p = 2.0, long dim = 1L, double eps = 1e-12)
22+
{
23+
var res = THSNN_normalize(
24+
input.Handle,
25+
p, dim, eps);
26+
if (res == IntPtr.Zero)
27+
torch.CheckForErrors();
28+
return new Tensor(res);
29+
}
30+
1331
/// <summary>
1432
/// Applies Batch Normalization for each channel across a batch of data.
1533
/// </summary>
16-
public static Tensor batch_norm(Tensor input, Tensor running_mean, Tensor running_var, Tensor weight = null, Tensor bias = null, bool training = false, double momentum = 0.1, double eps = 1e-5)
34+
public static Tensor batch_norm(Tensor input, Tensor running_mean, Tensor running_var, Tensor? weight = null, Tensor? bias = null, bool training = false, double momentum = 0.1, double eps = 1e-5)
1735
{
1836
var res = THSNN_batch_norm(
1937
input.Handle,
@@ -31,7 +49,7 @@ public static Tensor batch_norm(Tensor input, Tensor running_mean, Tensor runnin
3149
/// <summary>
3250
/// Applies Group Normalization for last certain number of dimensions.
3351
/// </summary>
34-
public static Tensor group_norm(Tensor input, long num_groups, Tensor weight = null, Tensor bias = null, double eps = 1e-5)
52+
public static Tensor group_norm(Tensor input, long num_groups, Tensor? weight = null, Tensor? bias = null, double eps = 1e-5)
3553
{
3654
var res = THSNN_group_norm(
3755
input.Handle,
@@ -47,7 +65,7 @@ public static Tensor group_norm(Tensor input, long num_groups, Tensor weight = n
4765
/// <summary>
4866
/// Applies Instance Normalization for each channel in each data sample in a batch.
4967
/// </summary>
50-
public static Tensor instance_norm(Tensor input, Tensor running_mean = null, Tensor running_var = null, Tensor weight = null, Tensor bias = null, bool use_input_stats = true, double momentum = 0.1, double eps = 1e-5)
68+
public static Tensor instance_norm(Tensor input, Tensor? running_mean = null, Tensor? running_var = null, Tensor? weight = null, Tensor? bias = null, bool use_input_stats = true, double momentum = 0.1, double eps = 1e-5)
5169
{
5270
var res = THSNN_instance_norm(
5371
input.Handle,
@@ -65,7 +83,7 @@ public static Tensor instance_norm(Tensor input, Tensor running_mean = null, Ten
6583
/// <summary>
6684
/// Applies Layer Normalization for last certain number of dimensions.
6785
/// </summary>
68-
public static Tensor layer_norm(Tensor input, long[] normalized_shape, Tensor weight = null, Tensor bias = null, double eps = 1e-5)
86+
public static Tensor layer_norm(Tensor input, long[] normalized_shape, Tensor? weight = null, Tensor? bias = null, double eps = 1e-5)
6987
{
7088
IntPtr res;
7189
unsafe {

src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,9 @@ internal static extern IntPtr THSNN_custom_module(
851851
[DllImport("LibTorchSharp")]
852852
internal static extern IntPtr THSNN_Unflatten_ctor(long dim, IntPtr shape, long shape_len, out IntPtr pBoxedModule);
853853

854+
[DllImport("LibTorchSharp")]
855+
internal static extern IntPtr THSNN_normalize(IntPtr input, double p, long dim, double eps);
856+
854857
[DllImport("LibTorchSharp")]
855858
internal static extern IntPtr THSNN_batch_norm(IntPtr input, IntPtr running_mean, IntPtr running_var, IntPtr weight, IntPtr bias, [MarshalAs(UnmanagedType.U1)] bool training, double momentum, double eps);
856859

test/TorchSharpTest/NN.cs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
1+
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
22
using System;
33
using System.Linq;
44
using System.Runtime.InteropServices;
@@ -4817,6 +4817,36 @@ private Tensor NormalizeTensor(Tensor x, Tensor x_mean, Tensor x_var, double eps
48174817
return (x - x_mean) / torch.sqrt(eps + x_var);
48184818
}
48194819

4820+
[Fact]
4821+
public void TestNormalizeFunc()
4822+
{
4823+
foreach (var device in TestUtils.AvailableDevices()) {
4824+
var x = torch.from_array(new double[]
4825+
{ -1.0786, 0.3455, 1.2929, 0.5030,
4826+
-0.2930, 1.0420, -0.1082, -0.2943,
4827+
-0.3989, -0.8311, 0.7103, -1.5878,
4828+
0.6331, 1.0106, 0.5128, -2.2565,
4829+
1.2044, -0.6916, -0.1242, 0.6808,
4830+
0.1672, 0.1105, -1.7364, 0.0669
4831+
}).reshape(3,2,4);
4832+
var y = torch.nn.functional.normalize(x);
4833+
Assert.Equal(x.shape, y.shape);
4834+
Assert.Equal(x.device_type, y.device_type);
4835+
4836+
var expected = torch.from_array(new double[]
4837+
{ -0.9650, 0.3147, 0.9965, 0.8631,
4838+
-0.2621, 0.9492, -0.0834, -0.5050,
4839+
-0.5331, -0.6352, 0.8108, -0.5755,
4840+
0.8460, 0.7724, 0.5853, -0.8178,
4841+
0.9905, -0.9875, -0.0713, 0.9952,
4842+
0.1375, 0.1577, -0.9975, 0.0978
4843+
}).reshape(3, 2, 4);
4844+
4845+
4846+
Assert.True(y.allclose(expected, rtol: 0.005, atol: 0.005));
4847+
}
4848+
}
4849+
48204850
[Fact]
48214851
public void TestBatchNormFunc()
48224852
{

0 commit comments

Comments
 (0)