diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py
index d4af074c475..786b7d6cdf2 100644
--- a/backends/cadence/aot/quantizer/quantizer.py
+++ b/backends/cadence/aot/quantizer/quantizer.py
@@ -342,3 +342,16 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
             quantizers = get_cadence_default_quantizers()
         quantizers.append(CadenceAtenQuantizer(SoftmaxPattern(), qconfig_A16))
         super().__init__(quantizers)
+
+
+class CadenceWith16BitLinearActivationsQuantizer(CadenceQuantizer):
+    """
+    Quantizer including A16 fully_connected
+    """
+
+    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+        if quantizers is None:
+            quantizers = []
+        # Add 16-bit quantizers for LinearPattern
+        quantizers.append(CadenceAtenQuantizer(LinearPattern(), qconfig_A16))
+        super().__init__(quantizers)
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index ed9bb438a9e..b91f585fb16 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -261,7 +261,7 @@ def quantized_linear_common(
         src = src.view(-1, K)
 
     dtype = src.dtype
-    supported_dtypes = [torch.int8, torch.uint8, torch.int32]
+    supported_dtypes = [torch.int8, torch.uint8, torch.int16, torch.int32]
     if dtype not in supported_dtypes:
         raise ValueError(
             f"Unsupported dtype to quantize to {dtype}. Supported dtypes must be one of {supported_dtypes}"
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index 259752f3893..6aa091147c7 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -183,6 +183,8 @@ def test_quantized_add(
                 (False, torch.int8),
                 (True, torch.int8),
                 (True, torch.uint8),
+                (True, torch.int16),
+                (False, torch.int16),
             )
         ],
         # Test case 2: 1x3 input, 2x3 weight (2 output features)
@@ -207,6 +209,8 @@ def test_quantized_add(
             for (per_tensor, dtype) in (
                 (False, torch.int8),
                 (True, torch.int8),
+                (False, torch.int16),
+                (True, torch.int16),
             )
         ],
         *[
@@ -256,6 +260,8 @@ def test_quantized_add(
             for (per_tensor, dtype) in (
                 (False, torch.int8),
                 (True, torch.int8),
+                (False, torch.int16),
+                (True, torch.int16),
             )
         ],
         # Test case 4: Non-zero zero points
@@ -280,6 +286,8 @@ def test_quantized_add(
             for (per_tensor, dtype) in (
                 (False, torch.int8),
                 (True, torch.int8),
+                (False, torch.int16),
+                (True, torch.int16),
                 # (True, torch.uint8),
             )
         ],
@@ -302,7 +310,10 @@ def test_quantized_add(
                 False,
                 False,
             )
-            for dtype in (torch.int8,)
+            for dtype in (
+                torch.int8,
+                torch.int16,
+            )
         ],
         # Test case 6: Non-zero out_shift (shift=1)
         *[
@@ -325,7 +336,12 @@ def test_quantized_add(
                 False,
                 False,
             )
-            for (per_tensor, dtype) in ((False, torch.int8), (True, torch.int8))
+            for (per_tensor, dtype) in (
+                (False, torch.int8),
+                (True, torch.int8),
+                (False, torch.int16),
+                (True, torch.int16),
+            )
         ],
         *[
             (
@@ -348,7 +364,7 @@ def test_quantized_add(
                 transposed_matmul,
             )
             for (matmul, transposed_matmul) in ((True, False), (True, True))
-            for (per_tensor, dtype) in ((True, torch.int8),)
+            for (per_tensor, dtype) in ((True, torch.int8), (True, torch.int16))
         ],
         *[
             (
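
The new quantizer follows the same composition pattern as the A16 softmax quantizer directly above it in the first hunk: a pattern-specific CadenceAtenQuantizer built from a graph pattern (here LinearPattern) and the 16-bit activation config qconfig_A16, wrapped into a CadenceQuantizer. As a hedged illustration of how it might be exercised (not part of the diff), the sketch below drives the quantizer through the standard PT2E prepare/convert flow; the TinyLinear module and calibration input are hypothetical, prepare_pt2e/convert_pt2e are the stock torch.ao PT2E entry points, and the export call may be torch.export.export on older PyTorch releases.

# Hedged usage sketch; illustration only, not part of the diff above.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceWith16BitLinearActivationsQuantizer,
)


class TinyLinear(torch.nn.Module):
    """Hypothetical toy model with a single linear layer."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(3, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


model = TinyLinear().eval()
example_inputs = (torch.randn(1, 3),)

# Export to an ATen graph module for PT2E quantization.
graph_module = torch.export.export_for_training(model, example_inputs).module()

# Annotate linear nodes with the A16 config, calibrate, then convert.
quantizer = CadenceWith16BitLinearActivationsQuantizer()
prepared = prepare_pt2e(graph_module, quantizer)
prepared(*example_inputs)  # one calibration pass to populate observers
converted = convert_pt2e(prepared)

The second hunk makes the reference implementation accept the resulting 16-bit activations by adding torch.int16 to supported_dtypes in quantized_linear_common, and the test hunks extend the existing (per_tensor, dtype) parametrizations so each quantized-linear test case also runs with torch.int16.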