From c4b12bfaa7b2e467e265c0101903ef2d70a2ed59 Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Wed, 10 Sep 2025 13:04:45 -0700
Subject: [PATCH] Support for all quantized linear ops (#14078)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/14078

Continued support for reference implementations of all custom Cadence ops: this change adds the generic, per-tensor, and dtype-specialized (asym8sxasym8s_asym8s, asym8uxasym8u_asym8u) variants of quantized_linear.

Reviewed By: hsharma35

Differential Revision: D81940978
---
 backends/cadence/aot/ref_implementations.py  | 107 ++++++++-
 .../aot/tests/test_ref_implementations.py    | 205 +++++++++++++-----
 2 files changed, 251 insertions(+), 61 deletions(-)

diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 4d5e3beac26..b1f43d12890 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -7,7 +7,7 @@
 
 # pyre-strict
 
-from typing import Callable, Optional
+from typing import Callable
 
 import torch
 
@@ -193,17 +193,15 @@ def quantized_add(
     )
 
 
-@impl(m, "quantized_linear")
-def quantized_linear(
+def quantized_linear_common(
     src: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
     in_zero_point: int,
-    weight_zero_point: torch.Tensor,
-    out_multiplier: torch.Tensor,
-    out_shift: torch.Tensor,
+    weight_zero_point: torch.Tensor | int,
+    out_multiplier: torch.Tensor | int,
+    out_shift: int,
     out_zero_point: int,
-    offset: Optional[torch.Tensor],
 ) -> torch.Tensor:
     """
     Quantized linear (transposed matmul) operation.
@@ -219,7 +217,7 @@ def quantized_linear(
     - out_zero_point (int): The quantized mapping of zero for the output
     - offset (Tensor): Unused
     """
-    out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
+    out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)
    out_scale_inv = 1 / out_scale
 
     N, K = weight.shape
@@ -235,7 +233,9 @@ def quantized_linear(
     )
 
     out = torch.nn.functional.linear(
-        src - in_zero_point, weight - weight_zero_point, bias
+        (src - in_zero_point).float(),
+        (weight - weight_zero_point).float(),
+        bias.float(),
     )
     return quantize_per_tensor(
         out,
@@ -247,6 +247,95 @@ def quantized_linear(
     ).reshape(*leading_dims, N)
 
 
+def quantized_linear_variant(
+    per_tensor: bool,
+    src_dtype: torch.dtype | None = None,
+    weight_dtype: torch.dtype | None = None,
+) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
+
+    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+        def variant(
+            src: torch.Tensor,
+            weight: torch.Tensor,
+            bias: torch.Tensor,
+            in_zero_point: int,
+            weight_zero_point: torch.Tensor | int,
+            out_multiplier: torch.Tensor | int,
+            out_shift: torch.Tensor | int,
+            out_zero_point: int,
+            offset: torch.Tensor | None = None,
+        ) -> torch.Tensor:
+            if src_dtype and src.dtype != src_dtype:
+                raise ValueError(
+                    f"src dtype must be {src_dtype}. Got {src.dtype} instead"
+                )
+            if weight_dtype and weight.dtype != weight_dtype:
+                raise ValueError(
+                    f"weight dtype must be {weight_dtype}. Got {weight.dtype} instead"
+                )
+            if bias.dtype != torch.int32:
+                raise ValueError(
+                    f"bias dtype must be torch.int32. Got {bias.dtype} instead"
+                )
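
Note: the arithmetic in quantized_linear_common is easier to see outside the
diff. Below is a minimal standalone sketch of that math with made-up values
(illustrative only; the real op also clamps and casts inside
quantize_per_tensor, which is elided here):

    import torch

    # Hypothetical quantized operands: src is (1, K), weight is (N, K).
    src = torch.tensor([[1, 2]], dtype=torch.int8)
    weight = torch.tensor([[3, 4]], dtype=torch.int8)
    bias = torch.tensor([0], dtype=torch.int32)
    in_zero_point = weight_zero_point = out_zero_point = 0
    out_multiplier, out_shift = 1073741824, 0  # 1073741824 == 0.5 * 2^31

    # Remove the zero points, then run the matmul in float, as the op does.
    acc = torch.nn.functional.linear(
        (src - in_zero_point).float(),
        (weight - weight_zero_point).float(),
        bias.float(),
    )

    # Requantize. Note the op's convention of negating out_multiplier, so
    # out_scale = -(out_multiplier / 2^31) * 2^out_shift.
    out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)
    out = torch.round(acc / out_scale) + out_zero_point
    print(out)  # tensor([[-22.]]): acc is 1*3 + 2*4 == 11, and 11 / -0.5 == -22
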
Got {bias.dtype} instead" + ) + + if per_tensor: + assert isinstance(weight_zero_point, int) + assert isinstance(out_multiplier, int) + assert isinstance(out_shift, int) + return quantized_linear_common( + src, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + ) + else: + assert isinstance(out_shift, torch.Tensor) + if out_shift.numel() != 1: + raise ValueError("out_shift must be a scalar") + + if out_shift.dtype != torch.int64: + raise ValueError("out_shift must be an int64") + + return quantized_linear_common( + src, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + int(out_shift.item()), + out_zero_point, + ) + + return variant + + return decorator + + +@impl(m, "quantized_linear") +@quantized_linear_variant(False) +def quantized_linear() -> torch.Tensor: ... + + +@impl(m, "quantized_linear.per_tensor") +@quantized_linear_variant(True) +def quantized_linear_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_linear_asym8sxasym8s_asym8s.per_tensor") +@quantized_linear_variant(True, torch.int8, torch.int8) +def quantized_linear_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_linear_asym8uxasym8u_asym8u.per_tensor") +@quantized_linear_variant(True, torch.uint8, torch.uint8) +def quantized_linear_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ... + + @impl(m, "quantized_layer_norm.per_tensor") def quantized_layer_norm_per_tensor( input_tensor: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index ebd9aac48b1..2cc82fa5523 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -141,59 +141,139 @@ def test_quantized_add( @expand( [ # Test case 1: 1x2 input, 1x2 weight (1 output feature) - ( - torch.Size([1, 2]), # src_shape: 1 sample, 2 input features - torch.Size([1, 2]), # weight_shape: 1 output feature, 2 input features - 0, # in_zero_point - torch.tensor([0, 0], dtype=torch.int8), # weight_zero_point - torch.tensor( - [1073741824], dtype=torch.int32 - ), # out_multiplier (0.5 * 2^31) - torch.tensor([0], dtype=torch.int8), # out_shift - 0, # out_zero_point - torch.tensor([[-2]], dtype=torch.int8), # expected_output - ), + *[ + ( + torch.Size([1, 2]), # src_shape: 1 sample, 2 input features + torch.Size( + [1, 2] + ), # weight_shape: 1 output feature, 2 input features + 0, # in_zero_point + torch.tensor([0, 0], dtype=dtype), # weight_zero_point + torch.tensor( + [1073741824], dtype=torch.int32 + ), # out_multiplier (0.5 * 2^31) + torch.tensor([0], dtype=torch.int64), # out_shift + 0, # out_zero_point + torch.tensor([[-2]], dtype=dtype), # expected_output + per_tensor, + ) + for (per_tensor, dtype) in ( + (False, torch.int8), + (True, torch.int8), + (True, torch.uint8), + ) + ], # Test case 2: 1x3 input, 2x3 weight (2 output features) - ( - torch.Size([1, 3]), # src_shape: 1 sample, 3 input features - torch.Size([2, 3]), # weight_shape: 2 output features, 3 input features - 0, # in_zero_point - torch.tensor([0, 0, 0], dtype=torch.int8), # weight_zero_point - torch.tensor( - [1073741824], dtype=torch.int32 - ), # out_multiplier (0.5 * 2^31) - torch.tensor([0], dtype=torch.int8), # out_shift - 0, # out_zero_point - torch.tensor([[-10, -30]], dtype=torch.int8), # expected_output - ), + *[ + ( + torch.Size([1, 3]), # src_shape: 1 sample, 3 input features + torch.Size( + [2, 3] + ), # weight_shape: 2 output 
@@ -206,6 +286,7 @@ def test_quantized_linear(
     out_shift: torch.Tensor,
     out_zero_point: int,
     expected_output: torch.Tensor,
+    per_tensor: bool,
 ) -> None:
     src = (
         torch.arange(np.prod(src_shape))
         .reshape(src_shape)
         .to(expected_output.dtype)
     )
     weight = (
         torch.arange(np.prod(weight_shape), 0, -1)
         .reshape(weight_shape)
         .to(expected_output.dtype)
     )
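
Note: the hunk below routes each test case to the matching op. The selection
is ordinary Python 3.10 structural pattern matching on the dtype; a minimal
standalone sketch of the same logic (the helper name is illustrative):

    import torch

    def pick_linear_op_name(dtype: torch.dtype) -> str:
        # Mirrors the dtype-based dispatch used in the test below.
        match dtype:
            case torch.int8:
                return "quantized_linear_asym8sxasym8s_asym8s.per_tensor"
            case torch.uint8:
                return "quantized_linear_asym8uxasym8u_asym8u.per_tensor"
            case _:
                return "quantized_linear.per_tensor"

    assert pick_linear_op_name(torch.int8).startswith("quantized_linear_asym8s")
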
@@ -217,8 +298,28 @@ def test_quantized_linear(
-    bias = torch.arange(weight_shape[0]).to(expected_output.dtype)
-    output = torch.ops.cadence.quantized_linear(
+    bias = torch.arange(weight_shape[0]).to(torch.int32)
+    if per_tensor:
+        weight_zero_point = weight_zero_point[0]
+        out_multiplier = out_multiplier[0]
+        out_shift = out_shift[0]
+
+    if per_tensor:
+        match expected_output.dtype:
+            case torch.int8:
+                linear_op = (
+                    torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor
+                )
+            case torch.uint8:
+                linear_op = (
+                    torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor
+                )
+            case _:
+                linear_op = torch.ops.cadence.quantized_linear.per_tensor
+    else:
+        linear_op = torch.ops.cadence.quantized_linear
+
+    output = linear_op(
         src,
         weight,
         bias,
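
Note: the dtype-specialized variants added in this patch validate their inputs
before dispatching to quantized_linear_common. A sketch of the expected failure
mode (illustrative only: it assumes the Cadence op library is imported so the
ops are registered, and that the argument layout follows the variant signature
above with offset left at its default):

    import torch

    src = torch.zeros(1, 2, dtype=torch.uint8)
    weight = torch.zeros(1, 2, dtype=torch.uint8)
    bias = torch.zeros(1, dtype=torch.int32)

    # The int8-specialized op rejects uint8 inputs with
    # "src dtype must be torch.int8. Got torch.uint8 instead".
    # (Depending on dispatch, the ValueError may surface as a RuntimeError.)
    try:
        torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor(
            src, weight, bias, 0, 0, 1073741824, 0, 0
        )
    except (ValueError, RuntimeError) as e:
        print(e)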