From 94682d5044cdd873665ea3d5dd4685926f32d25d Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Wed, 10 Sep 2025 20:05:51 -0700
Subject: [PATCH] Add quantized fully connected ops (#14079)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/14079

Quantized fully connected ops are just aliases for quantized_linear, so this change creates all of the aliases.

Reviewed By: hsharma35

Differential Revision: D81942767
---
 backends/cadence/aot/ops_registrations.py   |  4 ++
 backends/cadence/aot/ref_implementations.py | 33 ++++++++++--
 .../aot/tests/test_ref_implementations.py   | 53 +++++++++++--------
 .../aot/tests/test_type_dispatch_passes.py  |  6 +--
 4 files changed, 67 insertions(+), 29 deletions(-)

diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index 507562526c5..35b4cbf3902 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -1771,6 +1771,7 @@ def quantized_fully_connected_meta(
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
+    assert src.shape[0] == 1
     out_size = list(src.size())
     weight_size = list(weight.size())
     assert len(weight_size) == 2
@@ -1793,6 +1794,7 @@ def quantized_fully_connected_per_tensor_meta(
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
+    assert src.shape[0] == 1
     out_size = list(src.size())
     weight_size = list(weight.size())
     assert len(weight_size) == 2
@@ -1815,6 +1817,7 @@ def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_meta(
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
+    assert src.shape[0] == 1
     out_size = list(src.size())
     weight_size = list(weight.size())
     assert len(weight_size) == 2
@@ -1837,6 +1840,7 @@ def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_meta(
     # src comes in shape [leading_dims, in_dim]
     # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
+    assert src.shape[0] == 1
     out_size = list(src.size())
     weight_size = list(weight.size())
     assert len(weight_size) == 2
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index b1f43d12890..f496b23a82b 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -249,6 +249,7 @@ def quantized_linear_common(

 def quantized_linear_variant(
     per_tensor: bool,
+    fully_connected: bool,
     src_dtype: torch.dtype | None = None,
     weight_dtype: torch.dtype | None = None,
 ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
@@ -265,6 +266,10 @@ def variant(
         out_zero_point: int,
         offset: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        if fully_connected and src.shape[0] != 1:
+            raise ValueError(
+                "Fully connected quantized linear only supports batch size of 1"
+            )
         if src_dtype and src.dtype != src_dtype:
             raise ValueError(
                 f"src dtype must be {src_dtype}. Got {src.dtype} instead"
             )
@@ -317,25 +322,45 @@ def variant(


 @impl(m, "quantized_linear")
-@quantized_linear_variant(False)
+@quantized_linear_variant(False, False)
 def quantized_linear() -> torch.Tensor: ...


 @impl(m, "quantized_linear.per_tensor")
-@quantized_linear_variant(True)
+@quantized_linear_variant(True, False)
 def quantized_linear_per_tensor() -> torch.Tensor: ...
@impl(m, "quantized_linear_asym8sxasym8s_asym8s.per_tensor") -@quantized_linear_variant(True, torch.int8, torch.int8) +@quantized_linear_variant(True, False, torch.int8, torch.int8) def quantized_linear_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ... @impl(m, "quantized_linear_asym8uxasym8u_asym8u.per_tensor") -@quantized_linear_variant(True, torch.uint8, torch.uint8) +@quantized_linear_variant(True, False, torch.uint8, torch.uint8) def quantized_linear_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ... +@impl(m, "quantized_fully_connected") +@quantized_linear_variant(False, True) +def quantized_fully_connected() -> torch.Tensor: ... + + +@impl(m, "quantized_fully_connected.per_tensor") +@quantized_linear_variant(True, True) +def quantized_fully_connected_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor") +@quantized_linear_variant(True, True, torch.int8, torch.int8) +def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor") +@quantized_linear_variant(True, True, torch.uint8, torch.uint8) +def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ... + + @impl(m, "quantized_layer_norm.per_tensor") def quantized_layer_norm_per_tensor( input_tensor: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 2cc82fa5523..a08327a7646 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -307,36 +307,45 @@ def test_quantized_linear( if per_tensor: match expected_output.dtype: case torch.int8: - linear_op = ( - torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor + linear_ops = ( + torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor, + torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor, ) case torch.uint8: - linear_op = ( - torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor + linear_ops = ( + torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor, + torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor, ) case _: - linear_op = torch.ops.cadence.quantized_linear.per_tensor + linear_ops = ( + torch.ops.cadence.quantized_linear.per_tensor, + torch.ops.cadence.quantized_fully_connected.per_tensor, + ) else: - linear_op = torch.ops.cadence.quantized_linear + linear_ops = ( + torch.ops.cadence.quantized_linear, + torch.ops.cadence.quantized_fully_connected, + ) - output = linear_op( - src, - weight, - bias, - in_zero_point, - weight_zero_point, - out_multiplier, - out_shift, - out_zero_point, - typing.cast(torch.Tensor, None), - ) + for linear_op in linear_ops: + output = linear_op( + src, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + typing.cast(torch.Tensor, None), + ) - self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch") + self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch") - self.assertTrue( - torch.equal(output, expected_output), - f"Values don't match: got {output}, expected {expected_output}", - ) + self.assertTrue( + torch.equal(output, expected_output), + f"Values don't match: got {output}, expected {expected_output}", + ) @expand( [ diff --git a/backends/cadence/aot/tests/test_type_dispatch_passes.py 
diff --git a/backends/cadence/aot/tests/test_type_dispatch_passes.py b/backends/cadence/aot/tests/test_type_dispatch_passes.py
index f180c138ca4..52904aecb41 100644
--- a/backends/cadence/aot/tests/test_type_dispatch_passes.py
+++ b/backends/cadence/aot/tests/test_type_dispatch_passes.py
@@ -20,7 +20,7 @@ class TestTypeDispatchPasses(unittest.TestCase):

     def test_int8_dispatch_quantized_fully_connected(self) -> None:
         """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant"""
-        x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
+        x = torch.randint(-128, 127, (1, 3), dtype=torch.int8)
         w = torch.randint(-128, 127, (4, 3), dtype=torch.int8)
         b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
         gm = single_op_builder(
@@ -46,7 +46,7 @@ def test_int8_dispatch_quantized_fully_connected(self) -> None:

     def test_uint8_dispatch_quantized_fully_connected(self) -> None:
         """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant"""
-        x = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
+        x = torch.randint(0, 255, (1, 3), dtype=torch.uint8)
         w = torch.randint(0, 255, (4, 3), dtype=torch.uint8)
         b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
         gm = single_op_builder(
@@ -124,7 +124,7 @@ def test_uint8_quantized_linear_dispatch(self) -> None:

     def test_mixed_types_error(self) -> None:
         """Test mixed int8/uint8 inputs should raise RuntimeError"""
-        x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
+        x = torch.randint(-128, 127, (1, 3), dtype=torch.int8)
         w = torch.randint(0, 255, (4, 3), dtype=torch.uint8)
         b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
         gm = single_op_builder(
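
Reviewer note: the whole change hangs off the quantized_linear_variant decorator factory, so a minimal, self-contained sketch of that aliasing pattern follows. It is an illustration under stated assumptions, not the ExecuTorch code: quantized_linear_common is a simplified stand-in (the real reference implementations also take zero points, multipliers, and shifts, and register through @impl). Only the up-front shape/dtype validation and the shared-body aliasing mirror the patch.

from typing import Callable

import torch


def quantized_linear_common(
    src: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    # Stand-in for the shared reference kernel: a plain float linear here.
    return torch.nn.functional.linear(src.float(), weight.float(), bias.float())


def quantized_linear_variant(
    fully_connected: bool,
    src_dtype: torch.dtype | None = None,
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
    # Decorator factory: every op variant shares one body and differs only in
    # the validation applied up front, which is how the aliases stay in sync.
    def decorator(_stub: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
        def variant(
            src: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
        ) -> torch.Tensor:
            if fully_connected and src.shape[0] != 1:
                raise ValueError(
                    "Fully connected quantized linear only supports batch size of 1"
                )
            if src_dtype is not None and src.dtype != src_dtype:
                raise ValueError(
                    f"src dtype must be {src_dtype}. Got {src.dtype} instead"
                )
            return quantized_linear_common(src, weight, bias)

        return variant

    return decorator


@quantized_linear_variant(False, torch.int8)
def quantized_linear() -> torch.Tensor: ...


@quantized_linear_variant(True, torch.int8)
def quantized_fully_connected() -> torch.Tensor: ...


src = torch.randint(-128, 127, (1, 3), dtype=torch.int8)
w = torch.randint(-128, 127, (4, 3), dtype=torch.int8)
b = torch.zeros(4)
# The fully connected alias agrees with quantized_linear for batch size 1 ...
assert torch.equal(quantized_linear(src, w, b), quantized_fully_connected(src, w, b))
# ... and rejects anything larger, matching the new shape check in the patch.
try:
    quantized_fully_connected(torch.zeros(2, 3, dtype=torch.int8), w, b)
except ValueError as e:
    print(e)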