Commit f94ddf5

Niklas Gustafsson committed · 1 parent bf73b9f

Adding nn.functional.normalize()
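
For context, normalize() performs Lp normalization over one dimension: each slice v along `dim` is scaled by 1 / max(‖v‖_p, eps). The managed signature added in this commit defaults to p = 2.0, dim = 1, eps = 1e-12. A minimal usage sketch of the new API (input values and shapes are hypothetical, not taken from this commit):

    using TorchSharp;
    using static TorchSharp.torch;

    var x = randn(3, 4);                                    // hypothetical 3x4 input
    var unit = nn.functional.normalize(x);                  // defaults: p = 2.0, dim = 1 -> each row gets unit L2 norm
    var l1 = nn.functional.normalize(x, p: 1.0, dim: 1);    // L1 normalization along the same dimension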

File tree: 6 files changed, +52 −10 lines

RELEASENOTES.md

Lines changed: 6 additions & 2 deletions

@@ -10,8 +10,12 @@ The argument defaults for `torch.diagonal()` and `Tensor.diagonal()` arguments h
 
 __Bug Fixes__:
 
-#1400 There may be an error in torchvision.transforms.GaussianBlur
-#1402 diagonal() has incorrect default
+#1400 There may be an error in torchvision.transforms.GaussianBlur<br/>
+#1402 diagonal() has incorrect default<br/>
+
+__API Changes__:
+
+#1382: Add support for torch.nn.functional.normalize<br/>
 
 # NuGet Version 0.103.1
 
TorchSharp.sln

Lines changed: 0 additions & 2 deletions

@@ -113,7 +113,6 @@ Global
         {CAD9DB7F-3223-3324-884D-FA2381593DA7}.Release|Any CPU.ActiveCfg = Release|x64
         {CAD9DB7F-3223-3324-884D-FA2381593DA7}.Release|x64.ActiveCfg = Release|x64
         {BB811429-0DF1-3D22-B664-09C2F5A9E0AB}.Debug|Any CPU.ActiveCfg = Debug|x64
-        {BB811429-0DF1-3D22-B664-09C2F5A9E0AB}.Debug|Any CPU.Build.0 = Debug|x64
         {BB811429-0DF1-3D22-B664-09C2F5A9E0AB}.Debug|x64.ActiveCfg = Debug|x64
         {BB811429-0DF1-3D22-B664-09C2F5A9E0AB}.Release|Any CPU.ActiveCfg = Release|x64
         {BB811429-0DF1-3D22-B664-09C2F5A9E0AB}.Release|x64.ActiveCfg = Release|x64
@@ -150,7 +149,6 @@ Global
         {95493944-D1AE-414E-964B-B58AEAE672E5}.Release|x64.ActiveCfg = Release|Any CPU
         {95493944-D1AE-414E-964B-B58AEAE672E5}.Release|x64.Build.0 = Release|Any CPU
         {6D3CE8AA-F369-4D2D-BDA7-9F89D6BE1B2E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
-        {6D3CE8AA-F369-4D2D-BDA7-9F89D6BE1B2E}.Debug|Any CPU.Build.0 = Debug|Any CPU
         {6D3CE8AA-F369-4D2D-BDA7-9F89D6BE1B2E}.Debug|x64.ActiveCfg = Debug|Any CPU
         {6D3CE8AA-F369-4D2D-BDA7-9F89D6BE1B2E}.Debug|x64.Build.0 = Debug|Any CPU
         {6D3CE8AA-F369-4D2D-BDA7-9F89D6BE1B2E}.Release|Any CPU.ActiveCfg = Release|Any CPU

src/Native/LibTorchSharp/THSNN.h

Lines changed: 1 addition & 0 deletions

@@ -289,6 +289,7 @@ EXPORT_API(void) THSNN_GroupNorm_set_weight(const NNModule module, const Ten
 EXPORT_API(NNModule) THSNN_LocalResponseNorm_ctor(const int64_t size, const double alpha, const double beta, const double k, NNAnyModule* outAsAnyModule);
 EXPORT_API(Tensor) THSNN_LocalResponseNorm_forward(const NNModule module, const Tensor tensor);
 
+EXPORT_API(Tensor) THSNN_normalize(const Tensor input, const double p, const int64_t dim, const double eps);
 EXPORT_API(Tensor) THSNN_batch_norm(const Tensor input, const Tensor running_mean, const Tensor running_var, const Tensor weight, const Tensor bias, const bool training, const double momentum, const double eps);
 EXPORT_API(Tensor) THSNN_group_norm(const Tensor input, int64_t num_groups, const Tensor weight, const Tensor bias, const double eps);
 EXPORT_API(Tensor) THSNN_instance_norm(const Tensor input, const Tensor running_mean, const Tensor running_var, const Tensor weight, const Tensor bias, const bool use_input_stats, const double momentum, const double eps);
src/TorchSharp/NN/Normalization/Functional.cs

Lines changed: 11 additions & 5 deletions

@@ -2,6 +2,7 @@
 using System;
 using static TorchSharp.PInvoke.NativeMethods;
 
+#nullable enable
 namespace TorchSharp
 {
     public static partial class torch
@@ -19,13 +20,18 @@ public static partial class functional
             /// <param name="eps">small value to avoid division by zero</param>
             public static Tensor normalize(Tensor input, double p = 2.0, long dim = 1L, double eps = 1e-12)
             {
-                return null;
+                var res = THSNN_normalize(
+                    input.Handle,
+                    p, dim, eps);
+                if (res == IntPtr.Zero)
+                    torch.CheckForErrors();
+                return new Tensor(res);
             }
 
             /// <summary>
             /// Applies Batch Normalization for each channel across a batch of data.
             /// </summary>
-            public static Tensor batch_norm(Tensor input, Tensor running_mean, Tensor running_var, Tensor weight = null, Tensor bias = null, bool training = false, double momentum = 0.1, double eps = 1e-5)
+            public static Tensor batch_norm(Tensor input, Tensor running_mean, Tensor running_var, Tensor? weight = null, Tensor? bias = null, bool training = false, double momentum = 0.1, double eps = 1e-5)
             {
                 var res = THSNN_batch_norm(
                     input.Handle,
@@ -43,7 +49,7 @@ public static Tensor batch_norm(Tensor input, Tensor running_mean, Tensor runnin
             /// <summary>
             /// Applies Group Normalization for last certain number of dimensions.
             /// </summary>
-            public static Tensor group_norm(Tensor input, long num_groups, Tensor weight = null, Tensor bias = null, double eps = 1e-5)
+            public static Tensor group_norm(Tensor input, long num_groups, Tensor? weight = null, Tensor? bias = null, double eps = 1e-5)
             {
                 var res = THSNN_group_norm(
                     input.Handle,
@@ -59,7 +65,7 @@ public static Tensor group_norm(Tensor input, long num_groups, Tensor weight = n
             /// <summary>
             /// Applies Instance Normalization for each channel in each data sample in a batch.
            /// </summary>
-            public static Tensor instance_norm(Tensor input, Tensor running_mean = null, Tensor running_var = null, Tensor weight = null, Tensor bias = null, bool use_input_stats = true, double momentum = 0.1, double eps = 1e-5)
+            public static Tensor instance_norm(Tensor input, Tensor? running_mean = null, Tensor? running_var = null, Tensor? weight = null, Tensor? bias = null, bool use_input_stats = true, double momentum = 0.1, double eps = 1e-5)
             {
                 var res = THSNN_instance_norm(
                     input.Handle,
@@ -77,7 +83,7 @@ public static Tensor instance_norm(Tensor input, Tensor running_mean = null, Ten
             /// <summary>
             /// Applies Layer Normalization for last certain number of dimensions.
             /// </summary>
-            public static Tensor layer_norm(Tensor input, long[] normalized_shape, Tensor weight = null, Tensor bias = null, double eps = 1e-5)
+            public static Tensor layer_norm(Tensor input, long[] normalized_shape, Tensor? weight = null, Tensor? bias = null, double eps = 1e-5)
             {
                 IntPtr res;
                 unsafe {
src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs

Lines changed: 3 additions & 0 deletions

@@ -851,6 +851,9 @@ internal static extern IntPtr THSNN_custom_module(
         [DllImport("LibTorchSharp")]
         internal static extern IntPtr THSNN_Unflatten_ctor(long dim, IntPtr shape, long shape_len, out IntPtr pBoxedModule);
 
+        [DllImport("LibTorchSharp")]
+        internal static extern IntPtr THSNN_normalize(IntPtr input, double p, long dim, double eps);
+
         [DllImport("LibTorchSharp")]
         internal static extern IntPtr THSNN_batch_norm(IntPtr input, IntPtr running_mean, IntPtr running_var, IntPtr weight, IntPtr bias, [MarshalAs(UnmanagedType.U1)] bool training, double momentum, double eps);
 
test/TorchSharpTest/NN.cs

Lines changed: 31 additions & 1 deletion

@@ -1,4 +1,4 @@
-// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
 using System;
 using System.Linq;
 using System.Runtime.InteropServices;
@@ -4817,6 +4817,36 @@ private Tensor NormalizeTensor(Tensor x, Tensor x_mean, Tensor x_var, double eps
             return (x - x_mean) / torch.sqrt(eps + x_var);
         }
 
+        [Fact]
+        public void TestNormalizeFunc()
+        {
+            foreach (var device in TestUtils.AvailableDevices()) {
+                var x = torch.from_array(new double[]
+                    { -1.0786, 0.3455, 1.2929, 0.5030,
+                      -0.2930, 1.0420, -0.1082, -0.2943,
+                      -0.3989, -0.8311, 0.7103, -1.5878,
+                      0.6331, 1.0106, 0.5128, -2.2565,
+                      1.2044, -0.6916, -0.1242, 0.6808,
+                      0.1672, 0.1105, -1.7364, 0.0669
+                    }).reshape(3,2,4);
+                var y = torch.nn.functional.normalize(x);
+                Assert.Equal(x.shape, y.shape);
+                Assert.Equal(x.device_type, y.device_type);
+
+                var expected = torch.from_array(new double[]
+                    { -0.9650, 0.3147, 0.9965, 0.8631,
+                      -0.2621, 0.9492, -0.0834, -0.5050,
+                      -0.5331, -0.6352, 0.8108, -0.5755,
+                      0.8460, 0.7724, 0.5853, -0.8178,
+                      0.9905, -0.9875, -0.0713, 0.9952,
+                      0.1375, 0.1577, -0.9975, 0.0978
+                    }).reshape(3, 2, 4);
+
+
+                Assert.True(y.allclose(expected, rtol: 0.005, atol: 0.005));
+            }
+        }
+
         [Fact]
         public void TestBatchNormFunc()
         {
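
As a quick hand-check of the expected values in the test above: with the defaults p = 2.0 and dim = 1, each pair of elements along dimension 1 of the (3, 2, 4) input is divided by its L2 norm. For the first slot, x[0,0,0] = -1.0786 and x[0,1,0] = -0.2930 (a sketch reusing the test's literals):

    using System;

    // L2 norm of the pair (-1.0786, -0.2930) taken along dim 1:
    var norm = Math.Sqrt(1.0786 * 1.0786 + 0.2930 * 0.2930);   // ≈ 1.1177
    Console.WriteLine(-1.0786 / norm);                          // ≈ -0.9650, matches expected[0,0,0]
    Console.WriteLine(-0.2930 / norm);                          // ≈ -0.2621, matches expected[0,1,0]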
