Skip to content

Commit 78c0cee

Browse files
AhmedZero authored and alinpahontu2912 committed
Refactor Normalize method like pytorch.
1 parent 6ceda53 commit 78c0cee

File tree

2 files changed

+45
-35
lines changed

2 files changed

+45
-35
lines changed

src/TorchVision/Functional.cs

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,19 @@ public static partial class transforms
2424
{
2525
public static partial class functional
2626
{
27+
28+
/// <summary>
/// Reports whether a tensor has enough dimensions to be treated as an image,
/// i.e. whether it has at least two dimensions.
/// </summary>
/// <param name="img">The tensor to inspect.</param>
/// <returns>True when <paramref name="img"/> has two or more dimensions.</returns>
private static bool IsTensorImage(Tensor img) => img.ndim >= 2;
32+
33+
/// <summary>
/// Validates that a tensor passes the <see cref="IsTensorImage"/> check.
/// </summary>
/// <param name="img">The tensor to validate.</param>
/// <returns>Always true; a failed check is reported by throwing instead.</returns>
/// <exception cref="ArgumentException">Thrown when the tensor is not an image tensor.</exception>
private static bool AssertTensorImage(Tensor img)
{
    if (IsTensorImage(img))
        return true;

    throw new ArgumentException("Tensor is not a torch image.");
}
39+
2740
/// <summary>
2841
/// Get the image dimensions
2942
/// </summary>
@@ -533,20 +546,29 @@ public static Tensor invert(Tensor input)
533546
/// <param name="input">An image tensor.</param>
534547
/// <param name="means">Sequence of means for each channel.</param>
535548
/// <param name="stdevs">Sequence of standard deviations for each channel.</param>
536-
/// <param name="dtype">Bool to make this operation inplace.</param>
549+
/// <param name="inplace">If true, the input tensor is normalized in place; otherwise a clone of the input is normalized and the input is left untouched.</param>
537550
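/// <returns>The normalized tensor, computed as (input - mean) / stdev with the means and standard deviations broadcast over the trailing two dimensions.</returns>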
538-
public static Tensor normalize(Tensor input, double[] means, double[] stdevs, ScalarType dtype = ScalarType.Float32)
551+
/// <exception cref="ArgumentNullException">Thrown when <paramref name="means"/> or <paramref name="stdevs"/> is null.</exception>
/// <exception cref="ArgumentException">
/// Thrown when the input is not an image tensor, is not a floating-point tensor, has fewer than
/// three dimensions, or when any standard deviation is zero after conversion to the input's dtype.
/// </exception>
public static Tensor normalize(Tensor input, double[] means, double[] stdevs, bool inplace = false)
{
    // Fail with a descriptive exception instead of an opaque error from the tensor conversion below.
    if (means is null) throw new ArgumentNullException(nameof(means));
    if (stdevs is null) throw new ArgumentNullException(nameof(stdevs));

    // Tensors created below live in this scope; only the result is handed to the
    // caller's scope through MoveToOuterDisposeScope() on the return statement.
    using var _ = NewDisposeScope();

    AssertTensorImage(input);
    if (!input.is_floating_point())
        throw new ArgumentException($"Input tensor should be a float tensor. Got {input.dtype}.");
    if (input.ndim < 3)
        // Join the dimensions explicitly: interpolating the size array itself would print
        // its type name rather than the actual sizes.
        throw new ArgumentException($"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = [{string.Join(", ", input.shape)}]");

    // The subtraction and division below are in-place, so work on a copy unless
    // the caller explicitly asked for the input to be modified.
    if (!inplace)
        input = input.clone();

    var mean = as_tensor(means, dtype: input.dtype, device: input.device);
    var stdev = as_tensor(stdevs, dtype: input.dtype, device: input.device);

    // Checked after the dtype conversion because a small non-zero double can
    // round to zero in a narrower floating-point type.
    if (stdev.eq(0).any().ToBoolean())
        throw new ArgumentException($"std evaluated to zero after conversion to {input.dtype}, leading to division by zero.");

    // Reshape the 1-D statistics to (C, 1, 1) so they broadcast over the last two dimensions.
    if (mean.ndim == 1)
        mean = mean.view(-1, 1, 1);
    if (stdev.ndim == 1)
        stdev = stdev.view(-1, 1, 1);

    return input.sub_(mean).div_(stdev).MoveToOuterDisposeScope();
}
551573

552574
private static Tensor _pad(Tensor input, ReadOnlySpan<long> padding, double fill = 0, PaddingModes padding_mode = PaddingModes.Constant)

src/TorchVision/Normalize.cs

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,44 +9,33 @@ public static partial class torchvision
99
{
1010
internal class Normalize : ITransform, IDisposable
1111
{
12-
internal Normalize(double[] means, double[] stdevs, ScalarType dtype = ScalarType.Float32, torch.Device? device = null)
12+
/// <summary>
/// Creates a normalization transform from per-channel means and standard deviations.
/// </summary>
/// <param name="means">Sequence of means for each channel; must contain 1 or 3 values.</param>
/// <param name="stdevs">Sequence of standard deviations for each channel; must be the same length as <paramref name="means"/>.</param>
/// <param name="inplace">Value forwarded as the inplace argument of the functional normalize call.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="means"/> or <paramref name="stdevs"/> is null.</exception>
/// <exception cref="ArgumentException">Thrown when the two arrays differ in length or are not 1 or 3 long.</exception>
internal Normalize(double[] means, double[] stdevs, bool inplace = false)
{
    if (means is null) throw new ArgumentNullException(nameof(means));
    if (stdevs is null) throw new ArgumentNullException(nameof(stdevs));
    if (means.Length != stdevs.Length)
        throw new ArgumentException($"{nameof(means)} and {nameof(stdevs)} must be the same length in call to Normalize");
    if (means.Length != 1 && means.Length != 3)
        throw new ArgumentException($"Since they correspond to the number of channels in an image, {nameof(means)} and {nameof(stdevs)} must both be either 1 or 3 long");

    // Keep private copies: holding the caller's arrays would let later mutation of
    // those arrays silently change (or bypass the validation of) this transform.
    this.means = (double[])means.Clone();
    this.stdevs = (double[])stdevs.Clone();
    this.inplace = inplace;
}
3425

3526
/// <summary>
/// Applies the transform by delegating to the functional normalize implementation
/// with the means, standard deviations and inplace flag stored on this instance.
/// </summary>
/// <param name="input">The tensor to normalize.</param>
/// <returns>The tensor returned by the functional normalize call.</returns>
public Tensor call(Tensor input)
    => transforms.functional.normalize(input, means, stdevs, inplace);
4030

41-
private Tensor means;
42-
private Tensor stdevs;
31+
private readonly double[] means;
32+
private readonly double[] stdevs;
33+
private readonly bool inplace;
4334
bool disposedValue;
4435

4536
/// <summary>
/// Marks this instance as disposed. No resources are released here;
/// the <paramref name="disposing"/> flag is not consulted.
/// </summary>
/// <param name="disposing">Unused by this implementation.</param>
protected virtual void Dispose(bool disposing)
{
    // Nothing to do on a repeated call.
    if (disposedValue)
        return;

    disposedValue = true;
}
@@ -72,12 +61,11 @@ public static partial class transforms
7261
/// </summary>
7362
/// <param name="means">Sequence of means for each channel.</param>
7463
/// <param name="stdevs">Sequence of standard deviations for each channel.</param>
75-
/// <param name="dtype">Bool to make this operation inplace.</param>
76-
/// <param name="device">The device to place the output tensor on.</param>
64+
/// <param name="inplace">If true, the transform normalizes its input tensor in place instead of working on a copy.</param>
7765
/// <returns>A transform that normalizes tensors using the given per-channel means and standard deviations.</returns>
78-
static public ITransform Normalize(double[] means, double[] stdevs, ScalarType dtype = ScalarType.Float32, torch.Device? device = null)
66+
static public ITransform Normalize(double[] means, double[] stdevs, bool inplace = false)
    // Argument validation is performed by the Normalize constructor.
    => new Normalize(means, stdevs, inplace);
8270
}
8371
}

0 commit comments

Comments
 (0)