diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index 030d10438fb..145dc16557a 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -53,16 +53,10 @@ def _validate_ref_impl_exists() -> None:
     # 1. be removed
     # 2. have a reference implementation added to ref_implementations.py
     _WARN_ONLY = {
-        "cadence::quantized_add",  # We should only support per_tensor variant, should remove
         "cadence::_softmax_f32_f32",
-        "cadence::requantize",  # We should only support per_tensor variant, should remove
         "cadence::quantized_softmax.per_tensor",
-        "cadence::quantized_conv2d_nchw",  # We should only support per_tensor variant, should remove
-        "cadence::quantized_relu",  # We should only support per_tensor variant, should remove
-        "cadence::quantized_conv2d_nhwc",  # We should only support per_tensor variant, should remove
         "cadence::quantized_softmax",
         "cadence::quantized_w8a32_gru",
-        "cadence::quantized_layer_norm",  # We should only support per_tensor variant, should remove
     }
 
     ref_impls = get_registered_ref_implementations()
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index b5523427a69..e284a7e639b 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -6,7 +6,7 @@
 
 # pyre-strict
 
-from typing import Callable
+from typing import Callable, Protocol, TypeVar
 
 import torch
 import torch.nn as nn
@@ -21,13 +21,15 @@
 # Registry to track all ops with reference implementations
 _REGISTERED_REF_IMPLEMENTATIONS: set[str] = set()
 
-_OUTPUTS_TYPE = torch.Tensor | tuple[torch.Tensor, ...]
+T = TypeVar("T", bound=Callable[..., torch.Tensor | tuple[torch.Tensor, ...]])
+
+
+class MyDecorator(Protocol):
+    def __call__(self, __f: T) -> T: ...
 
 
 # Custom impl wrapper that tracks registrations
-def impl_tracked(
-    lib: Library, op_name: str
-) -> Callable[[Callable[..., _OUTPUTS_TYPE]], Callable[..., _OUTPUTS_TYPE]]:
+def impl_tracked(lib: Library, op_name: str) -> MyDecorator:
     """Wrapper around impl that tracks registered ops."""
     _REGISTERED_REF_IMPLEMENTATIONS.add(op_name)
     return impl(lib, op_name)
@@ -314,7 +316,7 @@ def quantized_add_per_tensor(
     dequant_Y = Y_scale * (Y - Y_zero_point)
 
     # q_min/q_max are unused args
-    out = quantize_per_tensor(
+    return quantize_per_tensor(
         dequant_X + dequant_Y,
         out_scale,
         out_zero_point,
@@ -323,8 +325,28 @@
         dtype,
     )
 
-    assert isinstance(out, torch.Tensor)
-    return out
+
+@impl_tracked(m, "quantized_add")
+def quantized_add(
+    X: torch.Tensor,
+    X_scale: torch.Tensor,
+    X_zero_point: torch.Tensor,
+    Y: torch.Tensor,
+    Y_scale: torch.Tensor,
+    Y_zero_point: torch.Tensor,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    return quantized_add_per_tensor(
+        X,
+        float(X_scale.item()),
+        int(X_zero_point.item()),
+        Y,
+        float(Y_scale.item()),
+        int(Y_zero_point.item()),
+        out_scale,
+        out_zero_point,
+    )
 
 
 @impl_tracked(m, "quantized_add_asym8sxasym8s_asym8s.per_tensor")
@@ -343,11 +365,9 @@ def quantized_add_asym8sxasym8s_asym8s_per_tensor(
     if Y.dtype != torch.int8:
         raise ValueError("Y dtype must be torch.int8")
 
-    out = quantized_add_per_tensor(
+    return quantized_add_per_tensor(
         X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point
     )
-    assert isinstance(out, torch.Tensor)
-    return out
 
 
 @impl_tracked(m, "quantized_add_asym8uxasym8u_asym8u.per_tensor")
@@ -366,11 +386,9 @@ def quantized_add_asym8uxasym8u_asym8u_per_tensor(
     if Y.dtype != torch.uint8:
         raise ValueError("Y dtype must be torch.uint8")
 
-    out = quantized_add_per_tensor(
+    return quantized_add_per_tensor(
         X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point
     )
-    assert isinstance(out, torch.Tensor)
-    return out
 
 
 def quantized_linear_common(
@@ -416,16 +434,14 @@ def quantized_linear_common(
         (weight - weight_zero_point).float(),
         bias.float(),
     )
-    out = quantize_per_tensor(
+    return quantize_per_tensor(
         out,
         out_scale,
         out_zero_point,
         torch.iinfo(dtype).min,
         torch.iinfo(dtype).max,
         dtype,
-    )
-    assert isinstance(out, torch.Tensor)
-    return out.reshape(*leading_dims, N)
+    ).reshape(*leading_dims, N)
 
 
 def quantized_linear_variant(
@@ -587,7 +603,7 @@ def quantized_matmul(
         (X - X_zero_point).float(),
         (Y - Y_zero_point).float(),
     )
-    out = quantize_per_tensor(
+    return quantize_per_tensor(
         out,
         out_scale,
         out_zero_point,
@@ -595,8 +611,6 @@
         torch.iinfo(X.dtype).max,
         X.dtype,
     )
-    assert isinstance(out, torch.Tensor)
-    return out
 
 
 @impl_tracked(m, "quantized_matmul_asym8sxasym8s_asym8s")
@@ -616,7 +630,7 @@ def quantized_matmul_asym8sxasym8s_asym8s(
     if Y.dtype != torch.int8:
         raise ValueError("Y dtype must be torch.int8")
 
-    out = quantized_matmul(
+    return quantized_matmul(
         X,
         X_zero_point,
         Y,
@@ -627,8 +641,6 @@
         out_zero_point,
         transposed,
     )
-    assert isinstance(out, torch.Tensor)
-    return out
 
 
 @impl_tracked(m, "quantized_matmul_asym8uxasym8u_asym8u")
@@ -648,7 +660,7 @@ def quantized_matmul_asym8uxasym8u_asym8u(
     if Y.dtype != torch.uint8:
         raise ValueError("Y dtype must be torch.uint8")
 
-    out = quantized_matmul(
+    return quantized_matmul(
         X,
         X_zero_point,
         Y,
@@ -659,8 +671,6 @@
         out_zero_point,
         transposed,
     )
-    assert isinstance(out, torch.Tensor)
-    return out
@impl_tracked(m, "quantized_layer_norm.per_tensor") @@ -703,7 +713,7 @@ def quantized_layer_norm_per_tensor( float_input_tensor, normalized_shape, weight, bias, eps=eps ) - out = quantize_per_tensor( + return quantize_per_tensor( out, output_scale, output_zero_point, @@ -711,8 +721,31 @@ def quantized_layer_norm_per_tensor( torch.iinfo(input_tensor.dtype).max, input_tensor.dtype, ) - assert isinstance(out, torch.Tensor) - return out + + +@impl_tracked(m, "quantized_layer_norm") +def quantized_layer_norm( + input_tensor: torch.Tensor, + X_scale: torch.Tensor, + X_zero_point: torch.Tensor, + normalized_shape: list[int], + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + output_scale: float, + output_zero_point: int, +) -> torch.Tensor: + return quantized_layer_norm_per_tensor( + input_tensor, + float(X_scale.item()), + int(X_zero_point.item()), + normalized_shape, + weight, + bias, + eps, + output_scale, + output_zero_point, + ) def quantized_conv_per_tensor( @@ -774,7 +807,7 @@ def quantized_conv_per_tensor( else: raise ValueError("Input tensor must be 3D or 4D") - out = quantize_per_tensor( + return quantize_per_tensor( float_out, output_scale, output_zero_point, @@ -782,8 +815,6 @@ def quantized_conv_per_tensor( torch.iinfo(input_tensor.dtype).max, input_tensor.dtype, ) - assert isinstance(out, torch.Tensor) - return out @impl_tracked(m, "quantized_conv2d_nchw.per_tensor") @@ -842,6 +873,41 @@ def quantized_conv2d_nchw_per_tensor( ) +@impl_tracked(m, "quantized_conv2d_nchw") +def quantized_conv2d_nchw( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int, int], + padding: tuple[int, int], + dilation: tuple[int, int], + groups: int, + in_zero_point: int, + weight_zero_point: torch.Tensor, + bias_scale: torch.Tensor, + output_scale: float, + output_zero_point: int, + out_multiplier: torch.Tensor, + out_shift: torch.Tensor, +) -> torch.Tensor: + return quantized_conv2d_nchw_per_tensor( + input_tensor, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + int(weight_zero_point.item()), + float(bias_scale.item()), + output_scale, + output_zero_point, + int(out_multiplier.item()), + int(out_shift.item()), + ) + + @impl_tracked(m, "quantized_w8a32_conv") def quantized_w8a32_conv( src: torch.Tensor, @@ -993,6 +1059,41 @@ def quantized_conv2d_nhwc_per_tensor( return nchw_out.movedim(-3, -1).contiguous() +@impl_tracked(m, "quantized_conv2d_nhwc") +def quantized_conv2d_nhwc( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int, int], + padding: tuple[int, int], + dilation: tuple[int, int], + groups: int, + in_zero_point: int, + weight_zero_point: torch.Tensor, + bias_scale: torch.Tensor, + output_scale: float, + output_zero_point: int, + out_multiplier: torch.Tensor, + out_shift: torch.Tensor, +) -> torch.Tensor: + return quantized_conv2d_nhwc_per_tensor( + input_tensor, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + int(weight_zero_point.item()), + float(bias_scale.item()), + output_scale, + output_zero_point, + int(out_multiplier.item()), + int(out_shift.item()), + ) + + def quantized_conv_variant( layout: str, input_dtype: torch.dtype, @@ -1040,7 +1141,7 @@ def variant( # Call the appropriate base function match layout: case "nchw": - out = quantized_conv2d_nchw_per_tensor( + return quantized_conv2d_nchw_per_tensor( input_tensor, weight, bias, @@ -1057,7 +1158,7 @@ def variant( out_shift, ) case "nhwc": - out = 
+                return quantized_conv2d_nhwc_per_tensor(
                     input_tensor,
                     weight,
                     bias,
@@ -1076,9 +1177,6 @@
             case _:
                 raise ValueError(f"Unknown layout {layout}")
 
-        assert isinstance(out, torch.Tensor)
-        return out
-
         return variant
 
     return decorator
@@ -1409,6 +1507,19 @@ def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...
 def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
+@impl_tracked(m, "quantized_relu")
+def quantized_relu(
+    X: torch.Tensor,
+    X_zero_point: torch.Tensor,
+    out_zero_point: int,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
+) -> torch.Tensor:
+    return quantized_relu_per_tensor(
+        X, X_zero_point.item(), out_zero_point, out_multiplier.item(), out_shift.item()
+    )
+
+
 @impl_tracked(m, "requantize.per_tensor")
 def requantize_per_tensor(
     input: torch.Tensor,
@@ -1447,6 +1558,25 @@ def requantize_per_tensor(
     )
 
 
+@impl_tracked(m, "requantize")
+def requantize(
+    input_tensor: torch.Tensor,
+    in_scale: torch.Tensor,
+    in_zero_point: torch.Tensor,
+    out_scale: torch.Tensor,
+    out_zero_point: torch.Tensor,
+    dtype: ScalarType,
+) -> torch.Tensor:
+    return requantize_per_tensor(
+        input_tensor,
+        float(in_scale.item()),
+        int(in_zero_point.item()),
+        float(out_scale.item()),
+        int(out_zero_point.item()),
+        dtype,
+    )
+
+
 @impl_tracked(m, "rms_norm")
 def rms_norm(
     X: torch.Tensor,
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index cb4b26c59e1..02485b0ae09 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -15,6 +15,8 @@
 
 import torch
 from executorch.backends.cadence.aot.typing_stubs import expand
+from executorch.exir.scalar_type import ScalarType
+
 
 class TestRefImplementations(unittest.TestCase):
     @expand(
@@ -103,7 +105,7 @@ def test_dequantize_per_tensor(
             ("uint8", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.uint8),
         ]
     )
-    def test_quantized_add(
+    def test_quantized_add_per_tensor(
         self,
         name: str,
         X: int,
@@ -121,28 +123,12 @@
         Y_tensor = torch.tensor([Y], dtype=dtype)
         expected_output = torch.tensor([expected_value], dtype=dtype)
 
-        quantized_add = (
+        quantized_add_per_tensor = (
             torch.ops.cadence.quantized_add_asym8sxasym8s_asym8s.per_tensor
             if dtype == torch.int8
             else torch.ops.cadence.quantized_add_asym8uxasym8u_asym8u.per_tensor
         )
 
-        output = quantized_add(
-            X_tensor,
-            X_scale,
-            X_zero_point,
-            Y_tensor,
-            Y_scale,
-            Y_zero_point,
-            out_scale,
-            out_zero_point,
-        )
-
-        self.assertTrue(
-            torch.equal(output, expected_output),
-            f"Values don't match in {name}: got {output}, expected {expected_output}",
-        )
-
-        output = torch.ops.cadence.quantized_add(
+        output = quantized_add_per_tensor(
             X_tensor,
             X_scale,
             X_zero_point,
@@ -1412,7 +1398,7 @@ def test_quantized_w8a32_linear(
             ],
         ]
     )
-    def test_quantized_relu(
+    def test_quantized_relu_per_tensor(
        self,
         name: str,
         X: torch.Tensor,
@@ -1426,17 +1412,17 @@
 
         match dtype:
             case torch.int8:
-                quantized_relu = (
+                quantized_relu_per_tensor = (
                     torch.ops.cadence.quantized_relu_asym8s_asym8s.per_tensor
                 )
             case torch.uint8:
-                quantized_relu = (
+                quantized_relu_per_tensor = (
                     torch.ops.cadence.quantized_relu_asym8u_asym8u.per_tensor
                 )
             case _:
-                quantized_relu = torch.ops.cadence.quantized_relu_per_tensor
+                quantized_relu_per_tensor = torch.ops.cadence.quantized_relu_per_tensor
 
-        output = quantized_relu(
+        output = quantized_relu_per_tensor(
             X,
             X_zero_point,
             out_zero_point,
@@ -2758,3 +2744,144 @@ def test_linalg_svd_outputs_are_contiguous(
         self.assertTrue(U.dtype == dtype, "U dtype mismatch")
         self.assertTrue(S.dtype == dtype, "S dtype mismatch")
         self.assertTrue(Vh.dtype == dtype, "Vh dtype mismatch")
+
+    def test_quantized_add(self) -> None:
+        # Smoke test: quantized_add (default variant) is a thin wrapper around the per_tensor variant, so just check that it runs.
+        X = torch.tensor([[1, 2], [3, 4]], dtype=torch.int8)
+        X_scale = torch.tensor([0.1])
+        X_zero_point = torch.tensor([0])
+        Y = torch.tensor([[5, 6], [7, 8]], dtype=torch.int8)
+        Y_scale = torch.tensor([0.1])
+        Y_zero_point = torch.tensor([0])
+        out_scale = 0.1
+        out_zero_point = 0
+        torch.ops.cadence.quantized_add(
+            X,
+            X_scale,
+            X_zero_point,
+            Y,
+            Y_scale,
+            Y_zero_point,
+            out_scale,
+            out_zero_point,
+        )
+
+    def test_requantize(self) -> None:
+        # Smoke test: requantize (default variant) is a thin wrapper around the per_tensor variant, so just check that it runs.
+        input_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.int8)
+        in_scale = torch.tensor([0.1])
+        in_zero_point = torch.tensor([0])
+        out_scale_tensor = torch.tensor([0.2])
+        out_zero_point_tensor = torch.tensor([0])
+        torch.ops.cadence.requantize(
+            input_tensor,
+            in_scale,
+            in_zero_point,
+            out_scale_tensor,
+            out_zero_point_tensor,
+            ScalarType.CHAR,
+        )
+
+    def test_quantized_conv2d_nchw(self) -> None:
+        # Smoke test: quantized_conv2d_nchw (default variant) is a thin wrapper around the per_tensor variant, so just check that it runs.
+        input_conv = torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int8)
+        weight_conv = torch.tensor([[[[1, 0], [0, 1]]]], dtype=torch.int8)
+        bias_conv = torch.tensor([0], dtype=torch.int32)
+        stride = [1, 1]
+        padding = [0, 0]
+        dilation = [1, 1]
+        groups = 1
+        input_zero_point = 0
+        weight_zero_point = torch.tensor([0])
+        bias_scale = torch.tensor([1.0])
+        conv_out_scale = 0.1
+        conv_out_zero_point = 0
+        out_multiplier = torch.tensor([1073741824], dtype=torch.int32)
+        out_shift = torch.tensor([0], dtype=torch.int32)
+        torch.ops.cadence.quantized_conv2d_nchw(
+            input_conv,
+            weight_conv,
+            bias_conv,
+            stride,
+            padding,
+            dilation,
+            groups,
+            input_zero_point,
+            weight_zero_point,
+            bias_scale,
+            conv_out_scale,
+            conv_out_zero_point,
+            out_multiplier,
+            out_shift,
+        )
+
+    def test_quantized_relu(self) -> None:
+        # Smoke test: quantized_relu (default variant) is a thin wrapper around the per_tensor variant, so just check that it runs.
+        X_relu = torch.tensor([[-1, 0, 1, 3]], dtype=torch.int8)
+        X_zero_point_relu = torch.tensor([0])
+        relu_out_zero_point = 0
+        out_multiplier_relu = torch.tensor([1073741824], dtype=torch.int32)
+        out_shift_relu = torch.tensor([0], dtype=torch.int32)
+        torch.ops.cadence.quantized_relu(
+            X_relu,
+            X_zero_point_relu,
+            relu_out_zero_point,
+            out_multiplier_relu,
+            out_shift_relu,
+        )
+
+    def test_quantized_conv2d_nhwc(self) -> None:
+        # Smoke test: quantized_conv2d_nhwc (default variant) is a thin wrapper around the per_tensor variant, so just check that it runs.
+        stride = [1, 1]
+        padding = [0, 0]
+        dilation = [1, 1]
+        groups = 1
+        input_zero_point = 0
+        weight_zero_point = torch.tensor([0])
+        bias_scale = torch.tensor([1.0])
+        conv_out_scale = 0.1
+        conv_out_zero_point = 0
+        input_nhwc = torch.tensor([[[[1], [2]], [[3], [4]]]], dtype=torch.int8)
+        weight_nhwc = torch.tensor([[[[1], [0]], [[0], [1]]]], dtype=torch.int8)
+        bias_nhwc = torch.tensor([0], dtype=torch.int32)
+        out_multiplier = torch.tensor([1073741824], dtype=torch.int32)
+        out_shift = torch.tensor([0], dtype=torch.int32)
+        torch.ops.cadence.quantized_conv2d_nhwc(
+            input_nhwc,
+            weight_nhwc,
+            bias_nhwc,
+            stride,
+            padding,
+            dilation,
+            groups,
+            input_zero_point,
+            weight_zero_point,
+            bias_scale,
+            conv_out_scale,
+            conv_out_zero_point,
+            out_multiplier,
+            out_shift,
+        )
+
+    def test_quantized_layer_norm(self) -> None:
+        # Smoke test: quantized_layer_norm (default variant) is a thin wrapper around the per_tensor variant, so just check that it runs.
+        X_ln = torch.tensor([[-1, 1]], dtype=torch.int8)
+        X_scale_ln = torch.tensor([0.1])
+        X_zero_point_ln = torch.tensor([0])
+        normalized_shape = [2]
+        weight_ln = torch.tensor([1.0, 1.0])
+        bias_ln = torch.tensor([0.0, 0.0])
+        eps = 1e-5
+        output_scale = 0.1
+        output_zero_point = 0
+        torch.ops.cadence.quantized_layer_norm(
+            X_ln,
+            X_scale_ln,
+            X_zero_point_ln,
+            normalized_shape,
+            weight_ln,
+            bias_ln,
+            eps,
+            output_scale,
+            output_zero_point,
+        )
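
A note on the `impl_tracked` typing change above: returning a `Protocol` whose generic `__call__` maps `T` to `T` lets a type checker keep the decorated function's exact signature, which the old `Callable[[Callable[..., _OUTPUTS_TYPE]], Callable[..., _OUTPUTS_TYPE]]` annotation erased. A minimal, self-contained sketch of the pattern, independent of the torch `Library`/`impl` API (`track`, `TrackingDecorator`, and `_TRACKED` are illustrative names, not part of this diff):

from typing import Callable, Protocol, TypeVar

import torch

T = TypeVar("T", bound=Callable[..., torch.Tensor])


class TrackingDecorator(Protocol):
    # Generic __call__ returns exactly the type it receives, so the
    # decorated function's signature survives type checking.
    def __call__(self, __f: T) -> T: ...


_TRACKED: set[str] = set()


def track(op_name: str) -> TrackingDecorator:
    """Record op_name in the registry, then return the function unchanged."""
    _TRACKED.add(op_name)

    def inner(f: T) -> T:
        return f

    return inner


@track("my_op")
def my_op(x: torch.Tensor, scale: float) -> torch.Tensor:
    return x * scale


y = my_op(torch.ones(2), 2.0)  # OK: checker still sees (Tensor, float) -> Tensor
assert "my_op" in _TRACKED
# A bad call such as my_op(torch.ones(2), "oops") is now flagged; with the
# old Callable[..., ...] return type the signature was erased to "anything".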
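The new default-variant wrappers (quantized_add, requantize, quantized_conv2d_nchw/nhwc, quantized_relu, quantized_layer_norm) all share one shape: unwrap single-element scale/zero-point tensors with `.item()`, cast to the scalar types the per_tensor signature expects, and delegate. A rough sketch of that shape, assuming a stand-in kernel (`toy_add_per_tensor` is a simplified model, not the real cadence reference implementation; clamping to the dtype's quantization range is omitted):

import torch


def toy_add_per_tensor(
    X: torch.Tensor, X_scale: float, X_zero_point: int,
    Y: torch.Tensor, Y_scale: float, Y_zero_point: int,
    out_scale: float, out_zero_point: int,
) -> torch.Tensor:
    # Dequantize both inputs, add, then requantize to the input dtype.
    dequant = X_scale * (X - X_zero_point) + Y_scale * (Y - Y_zero_point)
    return torch.round(dequant / out_scale + out_zero_point).to(X.dtype)


def toy_add(
    X: torch.Tensor, X_scale: torch.Tensor, X_zero_point: torch.Tensor,
    Y: torch.Tensor, Y_scale: torch.Tensor, Y_zero_point: torch.Tensor,
    out_scale: float, out_zero_point: int,
) -> torch.Tensor:
    # .item() only succeeds on single-element tensors, which is the shape
    # the new smoke tests exercise for the default variants.
    return toy_add_per_tensor(
        X, float(X_scale.item()), int(X_zero_point.item()),
        Y, float(Y_scale.item()), int(Y_zero_point.item()),
        out_scale, out_zero_point,
    )


out = toy_add(
    torch.tensor([[1, 2], [3, 4]], dtype=torch.int8),
    torch.tensor([0.1]), torch.tensor([0]),
    torch.tensor([[5, 6], [7, 8]], dtype=torch.int8),
    torch.tensor([0.1]), torch.tensor([0]),
    0.1, 0,
)
assert out.dtype == torch.int8  # values 6..12, well inside the int8 range

Keeping the numerics in the per_tensor kernels and making the default variants pure adapters means each op's behavior lives in exactly one place, which is what lets the corresponding `_WARN_ONLY` entries in ops_registrations.py be deleted.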