From 8bfcd6d6a73d37cc56de0dda95fa1abfaf764af7 Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Tue, 9 Sep 2025 14:42:35 -0700
Subject: [PATCH] Add uint8/int8 specializations for conv per tensor (#14033)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/14033

Continued support for adding custom Cadence Python references

Reviewed By: hsharma35

Differential Revision: D81720359
---
 backends/cadence/aot/ops_registrations.py    |  60 ++++++++
 backends/cadence/aot/ref_implementations.py  | 141 +++++++++++++++++-
 .../aot/tests/test_ref_implementations.py    | 124 ++++++++++-----
 3 files changed, 285 insertions(+), 40 deletions(-)

diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index 68091e2d521..ce0fba47610 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -873,6 +873,11 @@ def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
     out_channels, _, *kernel_size = weight.shape
     in_size = input.shape
@@ -917,6 +922,11 @@ def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
     out_channels, _, *kernel_size = weight.shape
     in_size = input.shape
@@ -961,6 +971,11 @@ def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
     out_channels, *kernel_size, _ = weight.shape
     in_size = input.shape
@@ -1005,6 +1020,11 @@ def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
     out_channels, *kernel_size, _ = weight.shape
     in_size = input.shape
@@ -1049,6 +1069,11 @@ def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
     out_channels, _, *kernel_size = weight.shape
     in_size = input.shape
@@ -1093,6 +1118,11 @@ def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
     out_channels, _, *kernel_size = weight.shape
     in_size = input.shape
@@ -1137,6 +1167,11 @@ def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
     out_channels, *kernel_size, _ = weight.shape
     in_size = input.shape
@@ -1181,6 +1216,11 @@ def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
     out_channels, *kernel_size, _ = weight.shape
     in_size = input.shape
@@ -1225,6 +1265,11 @@ def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
     out_channels, _, *kernel_size = weight.shape
     in_size = input.shape
@@ -1269,6 +1314,11 @@ def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
     out_channels, _, *kernel_size = weight.shape
     in_size = input.shape
@@ -1313,6 +1363,11 @@ def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
     out_channels, *kernel_size, _ = weight.shape
     in_size = input.shape
@@ -1357,6 +1412,11 @@ def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
     out_channels, *kernel_size, _ = weight.shape
     in_size = input.shape
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 790341f8f5a..0cd55326b86 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -7,7 +7,7 @@

 # pyre-strict

-from typing import Optional
+from typing import Callable, Optional

 import torch

@@ -479,6 +479,145 @@ def quantized_conv_nhwc_per_tensor(
     )


+def quantized_conv_variant(
+    layout: str,
+    input_dtype: torch.dtype,
+    weight_dtype: torch.dtype,
+) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
+    """Create a quantized conv variant with type checking."""
+
+    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+        def variant(
+            input_tensor: torch.Tensor,
+            weight: torch.Tensor,
+            bias: torch.Tensor,
+            stride: tuple[int, int],
+            padding: tuple[int, int],
+            dilation: tuple[int, int],
+            groups: int,
+            in_zero_point: int,
+            weight_zero_point: int,
+            bias_scale: float,
+            output_scale: float,
+            output_zero_point: int,
+            out_multiplier: int,
+            out_shift: int,
+        ) -> torch.Tensor:
+            assert (
+                input_tensor.dtype == input_dtype
+            ), f"Expected input dtype {input_dtype}, got {input_tensor.dtype}"
+            assert (
+                weight.dtype == weight_dtype
+            ), f"Expected weight dtype {weight_dtype}, got {weight.dtype}"
+
+            assert (
+                bias.dtype == torch.int32
+            ), f"Expected bias dtype int32, got {bias.dtype}"
+
+            # Call the appropriate base function
+            match layout:
+                case "nchw":
+                    return quantized_conv_nchw_per_tensor(
+                        input_tensor,
+                        weight,
+                        bias,
+                        stride,
+                        padding,
+                        dilation,
+                        groups,
+                        in_zero_point,
+                        weight_zero_point,
+                        bias_scale,
+                        output_scale,
+                        output_zero_point,
+                        out_multiplier,
+                        out_shift,
+                    )
+                case "nhwc":
+                    return quantized_conv_nhwc_per_tensor(
+                        input_tensor,
+                        weight,
+                        bias,
+                        stride,
+                        padding,
+                        dilation,
+                        groups,
+                        in_zero_point,
+                        weight_zero_point,
+                        bias_scale,
+                        output_scale,
+                        output_zero_point,
+                        out_multiplier,
+                        out_shift,
+                    )
+                case _:
+                    raise ValueError(f"Unknown layout {layout}")
+
+        return variant
+
+    return decorator
+
+
+@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nchw", torch.int8, torch.int8)
+def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nchw", torch.uint8, torch.uint8)
+def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nhwc", torch.int8, torch.int8)
+def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
+def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nchw", torch.int8, torch.int8)
+def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nchw", torch.uint8, torch.uint8)
+def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nhwc", torch.int8, torch.int8)
+def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
+def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nchw", torch.int8, torch.int8)
+def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nchw", torch.uint8, torch.uint8)
+def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nhwc", torch.int8, torch.int8)
+def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
+def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
 @impl(m, "quantized_relu")
 def quantized_relu(
     X: torch.Tensor,
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index 54247e0b53b..4e2829a8460 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -15,7 +15,19 @@
     dequantize_per_tensor,
     quantize_per_tensor,
     quantized_add,
+    quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor,
+    quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor,
+    quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor,
+    quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor,
+    quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor,
+    quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor,
     quantized_conv_nchw_per_tensor,
+    quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor,
+    quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor,
+    quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor,
+    quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor,
+    quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor,
+    quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor,
     quantized_conv_nhwc_per_tensor,
     quantized_layer_norm_per_tensor,
     quantized_linear,
@@ -350,7 +362,7 @@ def test_quantized_layer_norm_per_tensor(
                 torch.tensor(
                     [[[[1, 0], [0, 1]]]], dtype=torch.int8
                 ),  # weight: 1x1x2x2 (identity-like)
-                torch.tensor([0], dtype=torch.int8),  # bias
+                torch.tensor([0], dtype=torch.int32),  # bias
                 (1, 1),  # stride
                 (0, 0),  # padding
                 (1, 1),  # dilation
@@ -381,7 +393,7 @@ def test_quantized_layer_norm_per_tensor(
                 torch.tensor(
                     [[[[1, 1], [1, 1]]]], dtype=torch.int8
                 ),  # weight: 1x1x2x2 (sum filter)
-                torch.tensor([0], dtype=torch.int8),  # bias
+                torch.tensor([0], dtype=torch.int32),  # bias
                 (1, 1),  # stride
                 (0, 0),  # padding
                 (1, 1),  # dilation
@@ -410,7 +422,7 @@ def test_quantized_layer_norm_per_tensor(
                 torch.tensor(
                     [[[[129, 128], [128, 129]]]], dtype=torch.uint8
                 ),  # weight: 1x1x2x2 (values close to zero_point)
-                torch.tensor([10], dtype=torch.uint8),  # bias
+                torch.tensor([10], dtype=torch.int32),  # bias
                 (1, 1),  # stride
                 (0, 0),  # padding
                 (1, 1),  # dilation
@@ -441,7 +453,7 @@ def test_quantized_layer_norm_per_tensor(
                 torch.tensor(
                     [[[1, 1]]], dtype=torch.int8
                 ),  # weight: 1x1x2 (OC, IC, KW)
-                torch.tensor([0], dtype=torch.int8),  # bias
+                torch.tensor([0], dtype=torch.int32),  # bias
                 (1, 1),  # stride (padding for 2D, actual stride is stride[1])
                 (0, 0),  # padding (padding for 2D, actual padding is padding[1])
                 (1, 1),  # dilation (padding for 2D, actual dilation is dilation[1])
@@ -517,7 +529,7 @@ def test_quantized_layer_norm_per_tensor(
                     ],
                     dtype=torch.int16,
                 ),  # weight: 1x2x2x2 (1 output channel, 2 input channels)
-                torch.tensor([0], dtype=torch.int16),  # bias
+                torch.tensor([0], dtype=torch.int32),  # bias
                 (1, 1),  # stride
                 (0, 0),  # padding
                 (1, 1),  # dilation
@@ -652,7 +664,7 @@ def test_quantized_layer_norm_per_tensor(
                 torch.tensor(
                     [[[[1, 1], [1, 1]]]], dtype=torch.int8
                 ),  # weight: 1x1x2x2 (sum filter)
-                torch.tensor([0], dtype=torch.int8),  # bias
+                torch.tensor([0], dtype=torch.int32),  # bias
                 (2, 2),  # stride=2
                 (1, 1),  # padding=1
                 (1, 1),  # dilation
@@ -701,42 +713,76 @@ def test_quantized_conv_per_tensor(

         input_tensor = input_tensor.to(memory_format=memory_format)

-        conv = (
-            quantized_conv_nchw_per_tensor
-            if memory_format == torch.contiguous_format
-            else quantized_conv_nhwc_per_tensor
-        )
+        convs = [
+            (
+                quantized_conv_nchw_per_tensor
+                if memory_format == torch.contiguous_format
+                else quantized_conv_nhwc_per_tensor
+            )
+        ]

-        output = conv(
-            input_tensor,
-            weight,
-            bias,
-            stride,
-            padding,
-            dilation,
-            groups,
-            in_zero_point,
-            weight_zero_point,
-            bias_scale,
-            output_scale,
-            output_zero_point,
-            out_multiplier,
-            out_shift,
-        ).to(memory_format=torch.contiguous_format)
+        optimized_convs = []
+        if input_tensor.dtype == torch.int8 and weight.dtype == torch.int8:
+            if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
+                optimized_convs = [
+                    quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor,
+                    quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor,
+                    quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor,
+                ]

-        # Verify output properties
-        self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
-        self.assertEqual(
-            output.shape,
-            expected_output.shape,
-            "Output shape should match expected shape",
-        )
+            else:
+                optimized_convs = [
+                    quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor,
+                    quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor,
+                    quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor,
+                ]
+        elif input_tensor.dtype == torch.uint8 and weight.dtype == torch.uint8:
+            if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
+                optimized_convs = [
+                    quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor,
+                    quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor,
+                    quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor,
+                ]

-        # Verify output matches expected values
-        self.assertTrue(
-            torch.equal(output, expected_output),
-            f"Output values don't match expected. Got {output}, expected {expected_output}",
-        )
+            else:
+                optimized_convs = [
+                    quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor,
+                    quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor,
+                    quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor,
+                ]
+
+        convs.extend(optimized_convs)
+        for conv in convs:
+            output = conv(
+                input_tensor,
+                weight,
+                bias,
+                stride,
+                padding,
+                dilation,
+                groups,
+                in_zero_point,
+                weight_zero_point,
+                bias_scale,
+                output_scale,
+                output_zero_point,
+                out_multiplier,
+                out_shift,
+            ).to(memory_format=torch.contiguous_format)
+
+            # Verify output properties
+            self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
+            self.assertEqual(
+                output.shape,
+                expected_output.shape,
+                "Output shape should match expected shape",
+            )
+
+            # Verify output matches expected values
+            self.assertTrue(
+                torch.equal(output, expected_output),
+                f"Output values don't match expected. Got {output}, expected {expected_output}",
+            )

     @expand(
         [
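A minimal sketch of how one of the new typed per-tensor variants behaves, assuming the reference ops are importable from executorch.backends.cadence.aot.ref_implementations; the import path, tensor contents, and quantization parameters below are illustrative assumptions rather than values taken from this patch. The decorator factory binds each registered name to the same type-checked wrapper, so the dilated and depthwise variants differ only in the name they are registered under.

# Illustrative sketch only: assumed import path and made-up values.
import torch

from executorch.backends.cadence.aot.ref_implementations import (
    quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor,
)

# NCHW int8 activation and weight with an int32 bias, the dtype combination
# the new asserts require for the asym8sxsym8s_asym8s variant.
input_tensor = torch.tensor([[[[10, 20], [30, 40]]]], dtype=torch.int8)  # 1x1x2x2
weight = torch.tensor([[[[1, 0], [0, 1]]]], dtype=torch.int8)            # 1x1x2x2
bias = torch.tensor([0], dtype=torch.int32)

output = quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor(
    input_tensor,
    weight,
    bias,
    (1, 1),  # stride
    (0, 0),  # padding
    (1, 1),  # dilation
    1,       # groups
    0,       # in_zero_point
    0,       # weight_zero_point
    1.0,     # bias_scale
    1.0,     # output_scale
    0,       # output_zero_point
    1,       # out_multiplier
    0,       # out_shift
)

# Passing a uint8 weight or an int8 bias to this int8 variant now fails fast
# with a dtype assertion from quantized_conv_variant instead of silently
# running the generic reference path.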