
Commit ee7a0f7

RahulC7 authored and facebook-github-bot committed
Enable 16-bit activations in Cadence Quantizer for fully_connected and linear
Summary:

# Context

We currently support only 8-bit activations for most operators. We would like to add generic 16-bit-activation support for the following ops:

- quantized_fully_connected
- quantized_linear
- quantized_conv (all flavors)
- quantized_matmul

# This Diff

Here, we add support for `quantized_linear` and `quantized_fully_connected`. We need to do the following:

1. Allow 16-bit activations in `quantized_fully_connected_out.cpp` and `quantized_linear_out.cpp`.
2. Allow 16-bit activations in `ref_implementations.py`, so that tests can run with 16-bit activations to validate that the quantization is correct.
3. Add a quantizer (`CadenceWith16BitLinearActivationsQuantizer`) to verify that this works, and add a unit test.

Differential Revision: D84284794
1 parent e26670b · commit ee7a0f7

File tree

2 files changed: +11 −1 lines changed


backends/cadence/aot/quantizer/quantizer.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -338,3 +338,13 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
         quantizers = get_cadence_default_quantizers()
         quantizers.append(CadenceAtenQuantizer(SoftmaxPattern(), qconfig_A16))
         super().__init__(quantizers)
+
+class CadenceWith16BitLinearActivationsQuantizer(CadenceQuantizer):
+    """
+    Quantizer with 16-bit activations for specific operations
+    """
+    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+        quantizers = []
+        # Add 16-bit quantizers for LinearPattern
+        quantizers.append(CadenceAtenQuantizer(LinearPattern(), qconfig_A16))
+        super().__init__(quantizers)
```
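For illustration (not part of this commit): a minimal sketch of how the new quantizer might be exercised end to end, assuming the standard torch.ao PT2E prepare/convert flow. The toy module, shapes, and export entry point below are assumptions of mine; the commit's actual unit test is not shown on this page, and the export API varies across PyTorch releases.

```python
# Hypothetical usage sketch (not from this diff): drive the new quantizer
# through the torch.ao PT2E flow on a toy linear model.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceWith16BitLinearActivationsQuantizer,
)


class TinyLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyLinear().eval()
inputs = (torch.randn(4, 16),)

# Export to an ATen-level graph, annotate linear ops with the 16-bit
# activation qconfig, run one calibration batch, then fold in the
# quantize/dequantize ops.
graph_module = torch.export.export(model, inputs).module()
quantizer = CadenceWith16BitLinearActivationsQuantizer()
prepared = prepare_pt2e(graph_module, quantizer)
prepared(*inputs)  # calibration pass
converted = convert_pt2e(prepared)
```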

backends/cadence/aot/ref_implementations.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -261,7 +261,7 @@ def quantized_linear_common(
     src = src.view(-1, K)
 
     dtype = src.dtype
-    supported_dtypes = [torch.int8, torch.uint8, torch.int32]
+    supported_dtypes = [torch.int8, torch.uint8, torch.int16, torch.int32]
     if dtype not in supported_dtypes:
         raise ValueError(
             f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_dtypes}"
```
