From fbfe2904d60a89dee1bd48d249413490925d2ae4 Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Mon, 8 Sep 2025 11:03:42 -0700
Subject: [PATCH] Backend-agnostic quantized_conv_nhwc (channels last) (#13954)

Summary:
Ongoing work on providing Python backend-agnostic references for Cadence custom ops.

Reviewed By: hsharma35

Differential Revision: D81526626
---
 backends/cadence/aot/TARGETS                |   1 -
 backends/cadence/aot/ref_implementations.py | 119 +++-
 .../aot/tests/test_ref_implementations.py   | 606 ++++++++++--------
 3 files changed, 455 insertions(+), 271 deletions(-)

diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS
index eb0e17f9858..1a2c5a9709f 100644
--- a/backends/cadence/aot/TARGETS
+++ b/backends/cadence/aot/TARGETS
@@ -129,7 +129,6 @@ python_library(
     ],
     typing = True,
     deps = [
-        "fbcode//executorch/backends/cadence/aot:utils",
         "fbcode//caffe2:torch",
        "fbcode//executorch/exir:scalar_type",
     ],
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 05792a5cfa7..e0a3c8fbe9f 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -23,8 +23,6 @@
     ScalarType.QINT32: torch.qint32,
 }
 
-_Number = bool | int | float
-
 
 @impl(m, "quantize_per_tensor")
 def quantize_per_tensor(
@@ -298,8 +296,7 @@ def quantized_layer_norm_per_tensor(
     )
 
 
-@impl(m, "quantized_conv_nchw")
-def quantized_conv_nchw(
+def quantized_conv(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
     stride: tuple[int, int],
     padding: tuple[int, int],
     dilation: tuple[int, int],
     groups: int,
     in_zero_point: int,
     weight_zero_point: torch.Tensor,
     bias_scale: torch.Tensor,
     output_scale: float,
     output_zero_point: int,
     out_multiplier: torch.Tensor,
     out_shift: torch.Tensor,
 ) -> torch.Tensor:
@@ -374,6 +371,120 @@ def quantized_conv_nchw(
     )
 
 
+@impl(m, "quantized_conv_nchw")
+def quantized_conv_nchw(
+    input_tensor: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: tuple[int, int],
+    padding: tuple[int, int],
+    dilation: tuple[int, int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: torch.Tensor,
+    bias_scale: torch.Tensor,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Quantized convolution operation (NCHW/contiguous input layout).
+
+    Args:
+    - input_tensor (Tensor): The activations tensor
+    - weight (Tensor): The weight tensor
+    - bias (Tensor): The bias tensor
+    - stride (Tuple[int]): The stride of the convolution
+    - padding (Tuple[int]): The padding of the convolution
+    - dilation (Tuple[int]): The dilation of the convolution
+    - groups (int): The number of groups
+    - in_zero_point (int): The quantized mapping of zero for the input
+    - weight_zero_point (Tensor): The quantized mapping of zero for the weight
+    - bias_scale (Tensor): The quantized bias scale
+    - output_scale (float): The scale of the output
+    - output_zero_point (int): The zero point of the output
+    - out_multiplier (Tensor): Unused
+    - out_shift (Tensor): Unused
+    """
+    if not input_tensor.is_contiguous(memory_format=torch.contiguous_format):
+        raise ValueError("Input tensor must be in NCHW format")
+    return quantized_conv(
+        input_tensor,
+        weight,
+        bias,
+        stride,
+        padding,
+        dilation,
+        groups,
+        in_zero_point,
+        weight_zero_point,
+        bias_scale,
+        output_scale,
+        output_zero_point,
+        out_multiplier,
+        out_shift,
+    )
+
+
+@impl(m, "quantized_conv_nhwc")
+def quantized_conv_nhwc(
+    input_tensor: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: tuple[int, int],
+    padding: tuple[int, int],
+    dilation: tuple[int, int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: torch.Tensor,
+    bias_scale: torch.Tensor,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Quantized convolution operation (NHWC/channels-last input layout).
+
+    Args:
+    - input_tensor (Tensor): The activations tensor
+    - weight (Tensor): The weight tensor
+    - bias (Tensor): The bias tensor
+    - stride (Tuple[int]): The stride of the convolution
+    - padding (Tuple[int]): The padding of the convolution
+    - dilation (Tuple[int]): The dilation of the convolution
+    - groups (int): The number of groups
+    - in_zero_point (int): The quantized mapping of zero for the input
+    - weight_zero_point (Tensor): The quantized mapping of zero for the weight
+    - bias_scale (Tensor): The quantized bias scale
+    - output_scale (float): The scale of the output
+    - output_zero_point (int): The zero point of the output
+    - out_multiplier (Tensor): Unused
+    - out_shift (Tensor): Unused
+    """
+
+    if not input_tensor.is_contiguous(memory_format=torch.channels_last):
+        raise ValueError("Input tensor must be in NHWC format")
+
+    return quantized_conv(
+        input_tensor,
+        weight,
+        bias,
+        stride,
+        padding,
+        dilation,
+        groups,
+        in_zero_point,
+        weight_zero_point,
+        bias_scale,
+        output_scale,
+        output_zero_point,
+        out_multiplier,
+        out_shift,
+    )
+
+
 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index 37a250c70f7..d8b07131d3a 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -16,6 +16,7 @@
     quantize_per_tensor,
     quantized_add,
     quantized_conv_nchw,
+    quantized_conv_nhwc,
     quantized_layer_norm_per_tensor,
     quantized_linear,
 )
@@ -340,288 +341,347 @@ def test_quantized_layer_norm_per_tensor(
     @expand(
         [
             # Test case 1: Basic 2D convolution with int8
-            (
-                torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int8),  # input: 1x1x2x2
-                torch.tensor(
-                    [[[[1, 0], [0, 1]]]], dtype=torch.int8
-                ),  # weight: 1x1x2x2 (identity-like)
-                torch.tensor([0], dtype=torch.int8),  # bias
-                (1, 1),  # stride
-                (0, 0),  # padding
-                (1, 1),  # dilation
-                1,  # groups
-                0,  # in_zero_point
-                torch.tensor([0], dtype=torch.int8),  # weight_zero_point
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale
-                0.1,  # output_scale
-                0,  # output_zero_point
-                torch.tensor(
-                    [1073741824], dtype=torch.int32
-                ),  # out_multiplier (0.5 * 2^31)
-                torch.tensor([0], dtype=torch.int8),  # out_shift
-                torch.int8,  # dtype
-                torch.tensor(
-                    [[[[50]]]], dtype=torch.int8
-                ),  # expected_output: (1*1 + 4*1) / 0.1 = 50
-            ),
+            *[
+                (
+                    torch.tensor(
+                        [[[[1, 2], [3, 4]]]], dtype=torch.int8
+                    ),  # input: 1x1x2x2
+                    torch.tensor(
+                        [[[[1, 0], [0, 1]]]], dtype=torch.int8
+                    ),  # weight: 1x1x2x2 (identity-like)
+                    torch.tensor([0], dtype=torch.int8),  # bias
+                    (1, 1),  # stride
+                    (0, 0),  # padding
+                    (1, 1),  # dilation
+                    1,  # groups
+                    0,  # in_zero_point
+                    torch.tensor([0], dtype=torch.int8),  # weight_zero_point
+                    torch.tensor([1.0], dtype=torch.float32),  # bias_scale
+                    0.1,  # output_scale
+                    0,  # output_zero_point
+                    torch.tensor(
+                        [1073741824], dtype=torch.int32
+                    ),  # out_multiplier (0.5 * 2^31)
+                    torch.tensor([0], dtype=torch.int8),  # out_shift
+                    torch.int8,  # dtype
+                    torch.tensor(
+                        [[[[50]]]], dtype=torch.int8
+                    ),  # expected_output: (1*1 + 4*1) / 0.1 = 50
+                    memory_format,
+                )
+                for memory_format in [torch.contiguous_format, torch.channels_last]
+            ],
             # Test case 2: 2D convolution with stride and padding
-            (
-                torch.tensor(
-                    [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.int8
-                ),  # input: 1x1x3x3
-                torch.tensor(
-                    [[[[1, 1], [1, 1]]]], dtype=torch.int8
-                ),  # weight: 1x1x2x2 (sum filter)
-                torch.tensor([0], dtype=torch.int8),  # bias
-                (1, 1),  # stride
-                (0, 0),  # padding
-                (1, 1),  # dilation
-                1,  # groups
-                0,  # in_zero_point
-                torch.tensor([0], dtype=torch.int8),  # weight_zero_point
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale
-                0.25,  # output_scale
-                0,  # output_zero_point
-                typing.cast(None, torch.Tensor),
-                typing.cast(None, torch.Tensor),
-                torch.int8,  # dtype
-                torch.tensor(
-                    [[[[48, 64], [96, 112]]]], dtype=torch.int8
-                ),  # expected_output: convolution results with output_scale=0.25
-            ),
+            *[
+                (
+                    torch.tensor(
+                        [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.int8
+                    ),  # input: 1x1x3x3
+                    torch.tensor(
+                        [[[[1, 1], [1, 1]]]], dtype=torch.int8
+                    ),  # weight: 1x1x2x2 (sum filter)
+                    torch.tensor([0], dtype=torch.int8),  # bias
+                    (1, 1),  # stride
+                    (0, 0),  # padding
+                    (1, 1),  # dilation
+                    1,  # groups
+                    0,  # in_zero_point
+                    torch.tensor([0], dtype=torch.int8),  # weight_zero_point
+                    torch.tensor([1.0], dtype=torch.float32),  # bias_scale
+                    0.25,  # output_scale
+                    0,  # output_zero_point
+                    typing.cast(None, torch.Tensor),
+                    typing.cast(None, torch.Tensor),
+                    torch.int8,  # dtype
+                    torch.tensor(
+                        [[[[48, 64], [96, 112]]]], dtype=torch.int8
+                    ),  # expected_output: convolution results with output_scale=0.25
+                    memory_format,
+                )
+                for memory_format in [torch.contiguous_format, torch.channels_last]
+            ],
             # Test case 3: uint8 with non-zero zero points
-            (
-                torch.tensor(
-                    [[[[130, 132], [134, 136]]]], dtype=torch.uint8
-                ),  # input: 1x1x2x2
-                torch.tensor(
-                    [[[[129, 128], [128, 129]]]], dtype=torch.uint8
-                ),  # weight: 1x1x2x2 (values close to zero_point)
-                torch.tensor([10], dtype=torch.uint8),  # bias
-                (1, 1),  # stride
-                (0, 0),  # padding
-                (1, 1),  # dilation
-                1,  # groups
-                128,  # in_zero_point
-                torch.tensor([128], dtype=torch.uint8),  # weight_zero_point
-                torch.tensor([0.1], dtype=torch.float32),  # bias_scale
-                0.1,  # output_scale
-                128,  # output_zero_point
-                typing.cast(None, torch.Tensor),
-                typing.cast(None, torch.Tensor),
-                torch.uint8,  # dtype
-                torch.tensor(
-                    [[[[238]]]], dtype=torch.uint8
-                ),  # (130 - 128) + (134 - 128) = 10
-                # + bias -> 10 + 1 = 11
-                # round(11 / 0.1 + 128) = 238
-            ),
+            *[
+                (
+                    torch.tensor(
+                        [[[[130, 132], [134, 136]]]], dtype=torch.uint8
+                    ),  # input: 1x1x2x2
+                    torch.tensor(
+                        [[[[129, 128], [128, 129]]]], dtype=torch.uint8
+                    ),  # weight: 1x1x2x2 (values close to zero_point)
+                    torch.tensor([10], dtype=torch.uint8),  # bias
+                    (1, 1),  # stride
+                    (0, 0),  # padding
+                    (1, 1),  # dilation
+                    1,  # groups
+                    128,  # in_zero_point
+                    torch.tensor([128], dtype=torch.uint8),  # weight_zero_point
+                    torch.tensor([0.1], dtype=torch.float32),  # bias_scale
+                    0.1,  # output_scale
+                    128,  # output_zero_point
+                    typing.cast(None, torch.Tensor),
+                    typing.cast(None, torch.Tensor),
+                    torch.uint8,  # dtype
+                    torch.tensor(
+                        [[[[238]]]], dtype=torch.uint8
+                    ),  # (130 - 128) + (136 - 128) = 10
+                    # + bias -> 10 + 1 = 11
+                    # round(11 / 0.1 + 128) = 238
+                    memory_format,
+                )
+                for memory_format in [torch.contiguous_format, torch.channels_last]
+            ],
             # Test case 4: 1D convolution (3D input tensor)
-            (
-                torch.tensor(
-                    [[[1, 2, 3, 4]]], dtype=torch.int8
-                ),  # input: 1x1x4 (N, C, W)
-                torch.tensor(
-                    [[[1, 1]]], dtype=torch.int8
-                ),  # weight: 1x1x2 (OC, IC, KW)
-                torch.tensor([0], dtype=torch.int8),  # bias
-                (1, 1),  # stride (padding for 2D, actual stride is stride[1])
-                (0, 0),  # padding (padding for 2D, actual padding is padding[1])
-                (1, 1),  # dilation (padding for 2D, actual dilation is dilation[1])
-                1,  # groups
-                0,  # in_zero_point
-                torch.tensor([0], dtype=torch.int8),  # weight_zero_point
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale
-                0.5,  # output_scale
-                0,  # output_zero_point
-                typing.cast(None, torch.Tensor),
-                typing.cast(None, torch.Tensor),
-                torch.int8,  # dtype
-                torch.tensor(
-                    [[[6, 10, 14]]], dtype=torch.int8
-                ),  # expected_output: [1+2, 2+3, 3+4] / 0.5 = [6, 10, 14]
-            ),
+            *[
+                (
+                    torch.tensor(
+                        [[[1, 2, 3, 4]]], dtype=torch.int8
+                    ),  # input: 1x1x4 (N, C, W)
+                    torch.tensor(
+                        [[[1, 1]]], dtype=torch.int8
+                    ),  # weight: 1x1x2 (OC, IC, KW)
+                    torch.tensor([0], dtype=torch.int8),  # bias
+                    (1, 1),  # stride (2D params; the 1D conv uses stride[1])
+                    (0, 0),  # padding (2D params; the 1D conv uses padding[1])
+                    (1, 1),  # dilation (2D params; the 1D conv uses dilation[1])
+                    1,  # groups
+                    0,  # in_zero_point
+                    torch.tensor([0], dtype=torch.int8),  # weight_zero_point
+                    torch.tensor([1.0], dtype=torch.float32),  # bias_scale
+                    0.5,  # output_scale
+                    0,  # output_zero_point
+                    typing.cast(None, torch.Tensor),
+                    typing.cast(None, torch.Tensor),
+                    torch.int8,  # dtype
+                    torch.tensor(
+                        [[[6, 10, 14]]], dtype=torch.int8
+                    ),  # expected_output: [1+2, 2+3, 3+4] / 0.5 = [6, 10, 14]
+                    memory_format,
+                )
+                for memory_format in [torch.contiguous_format]
+            ],
             # Test case 5: Multiple output channels
-            (
-                torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int8),  # input: 1x1x2x2
-                torch.tensor(
-                    [
-                        [[[1, 0], [0, 1]]],  # first output channel
-                        [[[0, 1], [1, 0]]],  # second output channel
-                    ],
-                    dtype=torch.int8,
-                ),  # weight: 2x1x2x2
-                torch.tensor([0, 5], dtype=torch.int8),  # bias for each output channel
-                (1, 1),  # stride
-                (0, 0),  # padding
-                (1, 1),  # dilation
-                1,  # groups
-                0,  # in_zero_point
-                torch.tensor([0], dtype=torch.int8),  # weight_zero_point
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale
-                0.2,  # output_scale
-                0,  # output_zero_point
-                typing.cast(None, torch.Tensor),
-                typing.cast(None, torch.Tensor),
-                torch.int8,  # dtype
-                torch.tensor(
-                    [[[[25]], [[50]]]], dtype=torch.int8
-                ),  # expected_output: [5/0.2, 10/0.2] = [25, 50]
-            ),
+            *[
+                (
+                    torch.tensor(
+                        [[[[1, 2], [3, 4]]]], dtype=torch.int8
+                    ),  # input: 1x1x2x2
+                    torch.tensor(
+                        [
+                            [[[1, 0], [0, 1]]],  # first output channel
+                            [[[0, 1], [1, 0]]],  # second output channel
+                        ],
+                        dtype=torch.int8,
+                    ),  # weight: 2x1x2x2
+                    torch.tensor(
+                        [0, 5], dtype=torch.int32
+                    ),  # bias for each output channel
+                    (1, 1),  # stride
+                    (0, 0),  # padding
+                    (1, 1),  # dilation
+                    1,  # groups
+                    0,  # in_zero_point
+                    torch.tensor([0], dtype=torch.int8),  # weight_zero_point
+                    torch.tensor([1.0], dtype=torch.float32),  # bias_scale
+                    0.2,  # output_scale
+                    0,  # output_zero_point
+                    typing.cast(None, torch.Tensor),
+                    typing.cast(None, torch.Tensor),
+                    torch.int8,  # dtype
+                    torch.tensor(
+                        [[[[25]], [[50]]]], dtype=torch.int8
+                    ),  # expected_output: [5/0.2, 10/0.2] = [25, 50]
+                    memory_format,
+                )
+                for memory_format in [torch.contiguous_format, torch.channels_last]
+            ],
             # Test case 6: Multiple input channels
-            (
-                torch.tensor(
-                    [
-                        [
-                            [[1, 2], [3, 4]],  # first input channel
-                            [[5, 6], [7, 8]],
-                        ]  # second input channel
-                    ],
-                    dtype=torch.int16,
-                ),  # input: 1x2x2x2
-                torch.tensor(
-                    [
-                        [
-                            [[1, 0], [0, 1]],  # weights for first input channel
-                            [[0, 1], [1, 0]],
-                        ]  # weights for second input channel
-                    ],
-                    dtype=torch.int16,
-                ),  # weight: 1x2x2x2 (1 output channel, 2 input channels)
-                torch.tensor([0], dtype=torch.int16),  # bias
-                (1, 1),  # stride
-                (0, 0),  # padding
-                (1, 1),  # dilation
-                1,  # groups
-                0,  # in_zero_point
-                torch.tensor([0], dtype=torch.int16),  # weight_zero_point
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale
-                0.1,  # output_scale
-                0,  # output_zero_point
-                typing.cast(None, torch.Tensor),
-                typing.cast(None, torch.Tensor),
-                torch.int16,  # dtype
-                torch.tensor(
-                    [[[[180]]]], dtype=torch.int16
-                ),  # (1 + 4 + 6 + 7) / 0.1 = 180
-            ),
+            *[
+                (
+                    torch.tensor(
+                        [
+                            [
+                                [[1, 2], [3, 4]],  # first input channel
+                                [[5, 6], [7, 8]],
+                            ]  # second input channel
+                        ],
+                        dtype=torch.int16,
+                    ),  # input: 1x2x2x2
+                    torch.tensor(
+                        [
+                            [
+                                [[1, 0], [0, 1]],  # weights for first input channel
+                                [[0, 1], [1, 0]],
+                            ]  # weights for second input channel
+                        ],
+                        dtype=torch.int16,
+                    ),  # weight: 1x2x2x2 (1 output channel, 2 input channels)
+                    torch.tensor([0], dtype=torch.int16),  # bias
+                    (1, 1),  # stride
+                    (0, 0),  # padding
+                    (1, 1),  # dilation
+                    1,  # groups
+                    0,  # in_zero_point
+                    torch.tensor([0], dtype=torch.int16),  # weight_zero_point
+                    torch.tensor([1.0], dtype=torch.float32),  # bias_scale
+                    0.1,  # output_scale
+                    0,  # output_zero_point
+                    typing.cast(None, torch.Tensor),
+                    typing.cast(None, torch.Tensor),
+                    torch.int16,  # dtype
+                    torch.tensor(
+                        [[[[180]]]], dtype=torch.int16
+                    ),  # expected_output: (1 + 4 + 6 + 7) / 0.1 = 180
+                    memory_format,
+                )
+                for memory_format in [torch.contiguous_format, torch.channels_last]
+            ],
             # Test case 7: Multiple input and output channels
-            (
-                torch.tensor(
-                    [
-                        [
-                            [[1, 2], [3, 4]],  # first input channel
-                            [[2, 1], [4, 3]],
-                        ]  # second input channel
-                    ],
-                    dtype=torch.int16,
-                ),  # input: 1x2x2x2
-                torch.tensor(
-                    [
-                        [
-                            [
-                                [1, 1],
-                                [1, 1],
-                            ],  # first output channel, first input channel
-                            [[1, 1], [1, 1]],
-                        ],  # first output channel, second input channel
-                        [
-                            [
-                                [1, 0],
-                                [0, 1],
-                            ],  # second output channel, first input channel
-                            [[0, 1], [1, 0]],
-                        ],  # second output channel, second input channel
-                    ],
-                    dtype=torch.int16,
-                ),  # weight: 2x2x2x2 (2 output channels, 2 input channels)
-                torch.tensor([0, 0], dtype=torch.int16),  # bias for each output channel
-                (1, 1),  # stride
-                (0, 0),  # padding
-                (1, 1),  # dilation
-                1,  # groups
-                0,  # in_zero_point
-                torch.tensor(
-                    [0], dtype=torch.int16
-                ),  # weight_zero_point for each output channel
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale for each channel
-                0.05,  # output_scale
-                0,  # output_zero_point
-                typing.cast(None, torch.Tensor),
-                typing.cast(None, torch.Tensor),
-                torch.int16,  # dtype
-                torch.tensor([[[[400]], [[200]]]], dtype=torch.int16),
-            ),
+            *[
+                (
+                    torch.tensor(
+                        [
+                            [
+                                [[1, 2], [3, 4]],  # first input channel
+                                [[2, 1], [4, 3]],
+                            ]  # second input channel
+                        ],
+                        dtype=torch.int16,
+                    ),  # input: 1x2x2x2
+                    torch.tensor(
+                        [
+                            [
+                                [
+                                    [1, 1],
+                                    [1, 1],
+                                ],  # first output channel, first input channel
+                                [[1, 1], [1, 1]],
+                            ],  # first output channel, second input channel
+                            [
+                                [
+                                    [1, 0],
+                                    [0, 1],
+                                ],  # second output channel, first input channel
+                                [[0, 1], [1, 0]],
+                            ],  # second output channel, second input channel
+                        ],
+                        dtype=torch.int16,
+                    ),  # weight: 2x2x2x2 (2 output channels, 2 input channels)
+                    torch.tensor(
+                        [0, 0], dtype=torch.int32
+                    ),  # bias for each output channel
+                    (1, 1),  # stride
+                    (0, 0),  # padding
+                    (1, 1),  # dilation
+                    1,  # groups
+                    0,  # in_zero_point
+                    torch.tensor(
+                        [0], dtype=torch.int16
+                    ),  # weight_zero_point for each output channel
+                    torch.tensor(
+                        [1.0], dtype=torch.float32
+                    ),  # bias_scale for each channel
+                    0.05,  # output_scale
+                    0,  # output_zero_point
+                    typing.cast(None, torch.Tensor),
+                    typing.cast(None, torch.Tensor),
+                    torch.int16,  # dtype
+                    torch.tensor([[[[400]], [[200]]]], dtype=torch.int16),
+                    memory_format,
+                )
+                for memory_format in [torch.contiguous_format, torch.channels_last]
+            ],
             # Test case 8: Grouped convolution (groups=2)
-            (
-                torch.tensor(
-                    [
-                        [
-                            [[1, 2], [3, 4]],  # first input channel (group 1)
-                            [[5, 6], [7, 8]],
-                        ]  # second input channel (group 2)
-                    ],
-                    dtype=torch.int8,
-                ),  # input: 1x2x2x2
-                torch.tensor(
-                    [
-                        [
-                            [[1, 1], [1, 1]]
-                        ],  # first output channel (processes first input channel)
-                        [
-                            [[1, 0], [0, 1]]
-                        ],  # second output channel (processes second input channel)
-                    ],
-                    dtype=torch.int8,
-                ),  # weight: 2x1x2x2 (2 output channels, 1 input channel each due to groups=2)
-                torch.tensor([0, 0], dtype=torch.int8),  # bias for each output channel
-                (1, 1),  # stride
-                (0, 0),  # padding
-                (1, 1),  # dilation
-                2,  # groups (grouped convolution)
-                0,  # in_zero_point
-                torch.tensor(
-                    [0], dtype=torch.int8
-                ),  # weight_zero_point for each output channel
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale for each channel
-                0.2,  # output_scale
-                0,  # output_zero_point
-                typing.cast(None, torch.Tensor),
-                typing.cast(None, torch.Tensor),
-                torch.int8,  # dtype
-                torch.tensor(
-                    [[[[50]], [[65]]]], dtype=torch.int8
-                ),  # expected_output: [(1+2+3+4)/0.2, (5+8)/0.2] = [50, 65]
-            ),
+            *[
+                (
+                    torch.tensor(
+                        [
+                            [
+                                [[1, 2], [3, 4]],  # first input channel (group 1)
+                                [[5, 6], [7, 8]],
+                            ]  # second input channel (group 2)
+                        ],
+                        dtype=torch.int8,
+                    ),  # input: 1x2x2x2
+                    torch.tensor(
+                        [
+                            [
+                                [[1, 1], [1, 1]]
+                            ],  # first output channel (processes first input channel)
+                            [
+                                [[1, 0], [0, 1]]
+                            ],  # second output channel (processes second input channel)
+                        ],
+                        dtype=torch.int8,
+                    ),  # weight: 2x1x2x2 (2 output channels, 1 input channel each due to groups=2)
+                    torch.tensor(
+                        [0, 0], dtype=torch.int32
+                    ),  # bias for each output channel
+                    (1, 1),  # stride
+                    (0, 0),  # padding
+                    (1, 1),  # dilation
+                    2,  # groups (grouped convolution)
+                    0,  # in_zero_point
+                    torch.tensor(
+                        [0], dtype=torch.int8
+                    ),  # weight_zero_point for each output channel
+                    torch.tensor(
+                        [1.0], dtype=torch.float32
+                    ),  # bias_scale for each channel
+                    0.2,  # output_scale
+                    0,  # output_zero_point
+                    typing.cast(None, torch.Tensor),
+                    typing.cast(None, torch.Tensor),
+                    torch.int8,  # dtype
+                    torch.tensor(
+                        [[[[50]], [[65]]]], dtype=torch.int8
+                    ),  # expected_output: [(1+2+3+4)/0.2, (5+8)/0.2] = [50, 65]
+                    memory_format,
+                )
+                for memory_format in [torch.contiguous_format, torch.channels_last]
+            ],
             # Test case 9: Convolution with stride=2 and padding=1
-            (
-                torch.tensor(
-                    [[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]]],
-                    dtype=torch.int8,
-                ),  # input: 1x1x4x4
-                torch.tensor(
-                    [[[[1, 1], [1, 1]]]], dtype=torch.int8
-                ),  # weight: 1x1x2x2 (sum filter)
-                torch.tensor([0], dtype=torch.int8),  # bias
-                (2, 2),  # stride=2
-                (1, 1),  # padding=1
-                (1, 1),  # dilation
-                1,  # groups
-                0,  # in_zero_point
-                torch.tensor([0], dtype=torch.int8),  # weight_zero_point
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale
-                0.5,  # output_scale
-                0,  # output_zero_point
-                typing.cast(None, torch.Tensor),
-                typing.cast(None, torch.Tensor),
-                torch.int8,  # dtype
-                torch.tensor(
-                    [[[[2, 10, 8], [28, 68, 40], [26, 58, 32]]]], dtype=torch.int8
-                ),
-            ),
+            *[
+                (
+                    torch.tensor(
+                        [
+                            [
+                                [
+                                    [1, 2, 3, 4],
+                                    [5, 6, 7, 8],
+                                    [9, 10, 11, 12],
+                                    [13, 14, 15, 16],
+                                ]
+                            ]
+                        ],
+                        dtype=torch.int8,
+                    ),  # input: 1x1x4x4
+                    torch.tensor(
+                        [[[[1, 1], [1, 1]]]], dtype=torch.int8
+                    ),  # weight: 1x1x2x2 (sum filter)
+                    torch.tensor([0], dtype=torch.int8),  # bias
+                    (2, 2),  # stride=2
+                    (1, 1),  # padding=1
+                    (1, 1),  # dilation
+                    1,  # groups
+                    0,  # in_zero_point
+                    torch.tensor([0], dtype=torch.int8),  # weight_zero_point
+                    torch.tensor([1.0], dtype=torch.float32),  # bias_scale
+                    0.5,  # output_scale
+                    0,  # output_zero_point
+                    typing.cast(None, torch.Tensor),
+                    typing.cast(None, torch.Tensor),
+                    torch.int8,  # dtype
+                    torch.tensor(
+                        [[[[2, 10, 8], [28, 68, 40], [26, 58, 32]]]], dtype=torch.int8
+                    ),
+                    memory_format,
+                )
+                for memory_format in [torch.contiguous_format, torch.channels_last]
+            ],
         ]
     )
-    def test_quantized_conv_nchw(
+    def test_quantized_conv(
         self,
         input_tensor: torch.Tensor,
         weight: torch.Tensor,
@@ -639,8 +699,22 @@ def test_quantized_conv_nchw(
         out_shift: torch.Tensor,
         dtype: torch.dtype,
         expected_output: torch.Tensor,
+        memory_format: torch.memory_format,
     ) -> None:
-        output = quantized_conv_nchw(
+        assert memory_format in [torch.contiguous_format, torch.channels_last]
+
+        if len(input_tensor.shape) == 3 and memory_format == torch.channels_last:
+            self.fail("Channels last format is not supported for 3D input tensors")
+
+        input_tensor = input_tensor.to(memory_format=memory_format)
+
+        conv = (
+            quantized_conv_nchw
+            if memory_format == torch.contiguous_format
+            else quantized_conv_nhwc
+        )
+
+        output = conv(
             input_tensor,
             weight,
             bias,
@@ -655,7 +729,7 @@ def test_quantized_conv_nchw(
             output_zero_point,
             out_multiplier,
             out_shift,
-        )
+        ).to(memory_format=torch.contiguous_format)
 
         # Verify output properties
         self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
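
Usage sketch (illustrative only, not part of the applied diff): the snippet below drives the new NHWC reference op directly, mirroring test case 1 above. It assumes the ops are importable from executorch.backends.cadence.aot.ref_implementations, the module this patch edits; adjust the import to your checkout.

import torch

from executorch.backends.cadence.aot.ref_implementations import quantized_conv_nhwc

# 1x1x2x2 int8 activation, moved to channels-last so the NHWC layout guard passes.
input_tensor = torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int8).to(
    memory_format=torch.channels_last
)
weight = torch.tensor([[[[1, 0], [0, 1]]]], dtype=torch.int8)  # identity-like kernel
bias = torch.tensor([0], dtype=torch.int8)

output = quantized_conv_nhwc(
    input_tensor,
    weight,
    bias,
    (1, 1),  # stride
    (0, 0),  # padding
    (1, 1),  # dilation
    1,  # groups
    0,  # in_zero_point
    torch.tensor([0], dtype=torch.int8),  # weight_zero_point
    torch.tensor([1.0], dtype=torch.float32),  # bias_scale
    0.1,  # output_scale
    0,  # output_zero_point
    torch.tensor([1073741824], dtype=torch.int32),  # out_multiplier (unused here)
    torch.tensor([0], dtype=torch.int8),  # out_shift (unused here)
)
# The identity-like kernel picks input[0, 0, 0, 0] + input[0, 0, 1, 1] = 1 + 4 = 5;
# requantizing gives round(5 / 0.1) + 0 = 50, matching test case 1.
print(output)  # tensor([[[[50]]]], dtype=torch.int8)

The NCHW path is the identical call against quantized_conv_nchw with the input left in contiguous format; per this patch, both wrappers only validate the input's memory format and then dispatch to the shared quantized_conv reference.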