Skip to content

Commit 28ce621

Browse files
AhmedZeroalinpahontu2912
authored andcommitted
Add validation checks in Normalize constructor and call method
1 parent d22c585 commit 28ce621

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

src/TorchVision/Normalize.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ internal Normalize(double[] means, double[] stdevs,bool inplace = false)
1313
{
1414
if (means is null) throw new ArgumentNullException(nameof(means));
1515
if (stdevs is null) throw new ArgumentNullException(nameof(stdevs));
16+
if (means.Length != stdevs.Length)
17+
throw new ArgumentException($"{nameof(means)} and {nameof(stdevs)} must be the same length in call to Normalize");
1618
this.means = means;
1719
this.stdevs = stdevs;
1820
this.inplace = inplace;
@@ -21,6 +23,9 @@ internal Normalize(double[] means, double[] stdevs,bool inplace = false)
2123

2224
public Tensor call(Tensor input)
2325
{
26+
var expectedChannels = (input.shape.Length == 4) ? input.size(1) : input.size(0);
27+
if (expectedChannels != means.Length)
28+
throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations");
2429
return transforms.functional.normalize(input, means, stdevs, inplace);
2530
}
2631

0 commit comments

Comments
 (0)