Skip to content

Commit b9b9bd0

Browse files
AhmedZeroalinpahontu2912
authored andcommitted
Refactor channel determination logic in TorchSharp.
1 parent dd475df commit b9b9bd0

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

src/TorchVision/Functional.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,17 @@ private static bool AssertTensorImage(Tensor img)
3737
return true;
3838
}
3939

40+
public static long GetImageNumChannels(Tensor img)
41+
{
42+
AssertTensorImage(img);
43+
var ndim_ = img.ndim;
44+
return ndim_ switch {
45+
2 => 1,
46+
> 2 => img.shape[ndim_ - 3],
47+
_ => throw new ArgumentException($"Input ndim should be 2 or more. Got {ndim_}"),
48+
};
49+
}
50+
4051
/// <summary>
4152
/// Get the image dimensions
4253
/// </summary>

src/TorchVision/Normalize.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ internal Normalize(double[] means, double[] stdevs,bool inplace = false)
2323

2424
public Tensor call(Tensor input)
2525
{
26-
var expectedChannels = (input.ndim == 4) ? input.size(1) : input.size(0);
26+
var expectedChannels = transforms.functional.GetImageNumChannels(input);
2727
if (expectedChannels != means.Length)
2828
throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations");
2929
return transforms.functional.normalize(input, means, stdevs, inplace);

0 commit comments

Comments
 (0)