From 94ea57b3e8fff5051f8bf929ae983ded39ee1f3a Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Mon, 6 Oct 2025 13:48:01 -0700 Subject: [PATCH] Fixed assumption on out_shift for quantized linear (#14789) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/14789 out shift should be int32 Reviewed By: hsharma35 Differential Revision: D83875670 --- backends/cadence/aot/ref_implementations.py | 4 ++-- .../aot/tests/test_ref_implementations.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 2642340679e..ad1abb3ce4b 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -330,8 +330,8 @@ def variant( if out_shift.numel() != 1: raise ValueError("out_shift must be a scalar") - if out_shift.dtype != torch.int64: - raise ValueError("out_shift must be an int64") + if out_shift.dtype != torch.int32: + raise ValueError("out_shift must be an int32") _out_shift = int(out_shift.item()) _out_multiplier = int(out_multiplier[0].item()) diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index f78d2292e7b..d8a79454097 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -172,7 +172,7 @@ def test_quantized_add( torch.tensor( [1073741824], dtype=torch.int32 ), # out_multiplier (0.5 * 2^31) - torch.tensor([0], dtype=torch.int64), # out_shift + torch.tensor([0], dtype=torch.int32), # out_shift 0, # out_zero_point torch.tensor([[0]], dtype=dtype), # expected_output per_tensor, @@ -197,7 +197,7 @@ def test_quantized_add( torch.tensor( [1073741824], dtype=torch.int32 ), # out_multiplier (0.5 * 2^31) - torch.tensor([0], dtype=torch.int64), # out_shift + torch.tensor([0], dtype=torch.int32), # out_shift 0, # out_zero_point torch.tensor([[-2, -8]], dtype=dtype), # expected_output per_tensor, @@ -220,7 +220,7 @@ def test_quantized_add( torch.tensor( [1073741824], dtype=torch.int32 ), # out_multiplier (0.5 * 2^31) - torch.tensor([0], dtype=torch.int64), # out_shift + torch.tensor([0], dtype=torch.int32), # out_shift 0, # out_zero_point torch.tensor([[0, 0]], dtype=dtype), # expected_output per_tensor, @@ -244,7 +244,7 @@ def test_quantized_add( torch.tensor( [1073741824], dtype=torch.int32 ), # out_multiplier (0.5 * 2^31) - torch.tensor([0], dtype=torch.int64), # out_shift + torch.tensor([0], dtype=torch.int32), # out_shift 0, # out_zero_point torch.tensor( [[[0, -2, -4], [-2, -7, -12]]], dtype=dtype @@ -270,7 +270,7 @@ def test_quantized_add( torch.tensor( [268435456], dtype=torch.int32 ), # out_multiplier (1.0 * 2^31) - torch.tensor([0], dtype=torch.int64), # out_shift + torch.tensor([0], dtype=torch.int32), # out_shift 1, # out_zero_point torch.tensor([[1, 1]], dtype=dtype), # expected_output per_tensor, @@ -295,7 +295,7 @@ def test_quantized_add( torch.tensor( [268435456], dtype=torch.int32 ), # out_multiplier (1.0 * 2^31) - torch.tensor([0], dtype=torch.int64), # out_shift + torch.tensor([0], dtype=torch.int32), # out_shift 1, # out_zero_point torch.tensor([[1, 1]], dtype=dtype), # expected_output False, @@ -317,7 +317,7 @@ def test_quantized_add( [268435456], dtype=torch.int32 ), # out_multiplier (0.125 * 2^31) torch.tensor( - [1], dtype=torch.int64 + [1], dtype=torch.int32 ), # out_shift (shift=1, doubles the scale) 1, # out_zero_point torch.tensor([[1, 2]], dtype=dtype), # expected_output @@ -339,7 +339,7 @@ def test_quantized_add( [268435456], dtype=torch.int32 ), # out_multiplier (0.125 * 2^31) torch.tensor( - [1], dtype=torch.int64 + [1], dtype=torch.int32 ), # out_shift (shift=1, doubles the scale) 1, # out_zero_point torch.tensor([[1, 2]], dtype=dtype), # expected_output