
Commit 408cbcd

RahulC7 authored and facebook-github-bot committed
Enable 16-bit activations in Cadence Quantizer for fully_connected and linear (#15010)

Summary:

# Context

We currently only support 8-bit activations for most operators. We would like to add generic ops for 16-bit activations for the following ops:

- quantized_fully_connected
- quantized_linear
- quantized_conv (all flavors)
- quantized_matmul

# This Diff

Here, we add support for `quantized_linear` and `quantized_fully_connected`. We need to do the following:

1. Allow 16-bit activations in `quantized_fully_connected_out.cpp` and `quantized_linear_out.cpp`.
2. Allow 16-bit activations in `ref_implementations.py`, so tests can run with 16-bit activations to validate that the quantization is correct.
3. Add a quantizer (`CadenceWith16BitLinearActivationsQuantizer`) to check that this works, and create a unit test.

Reviewed By: hsharma35

Differential Revision: D84284794
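As a rough sketch of how the quantizer from step 3 might be exercised end to end, assuming the standard PT2E `prepare_pt2e`/`convert_pt2e` flow applies to it (the toy model, export entry point, and import path below are illustrative assumptions, not part of this commit):

```python
# Hypothetical usage sketch, not part of this commit: quantize a tiny
# linear model with the new A16 quantizer via the standard PT2E flow.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceWith16BitLinearActivationsQuantizer,
)

model = torch.nn.Sequential(torch.nn.Linear(6, 4)).eval()
example_inputs = (torch.randn(1, 6),)

# Export to an ATen graph, annotate linear nodes with the A16 config,
# run a calibration pass, then convert. The exact export entry point
# varies across PyTorch versions.
graph = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(graph, CadenceWith16BitLinearActivationsQuantizer())
prepared(*example_inputs)  # calibration
quantized = convert_pt2e(prepared)
```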
1 parent: b56f4e5 · Commit: 408cbcd

File tree: 3 files changed (+23, -4 lines)

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -342,3 +342,14 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
             quantizers = get_cadence_default_quantizers()
         quantizers.append(CadenceAtenQuantizer(SoftmaxPattern(), qconfig_A16))
         super().__init__(quantizers)
+
+class CadenceWith16BitLinearActivationsQuantizer(CadenceQuantizer):
+    """
+    Quantizer including A16 fully_connected
+    """
+    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+        if quantizers is None:
+            quantizers = []
+        # Add 16-bit quantizers for LinearPattern
+        quantizers.append(CadenceAtenQuantizer(LinearPattern(), qconfig_A16))
+        super().__init__(quantizers)
```
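`qconfig_A16` is defined elsewhere in quantizer.py and is not shown in this diff. For orientation, a 16-bit activation spec in the PT2E quantizer APIs generally pins the dtype and range along the lines of this sketch (the observer choice and exact values are assumptions, not taken from the source):

```python
# Illustrative sketch of an A16 activation spec; the actual qconfig_A16
# used above lives elsewhere in quantizer.py and may differ in detail.
import torch
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec

act_a16 = QuantizationSpec(
    dtype=torch.int16,
    quant_min=-(2**15),   # -32768
    quant_max=2**15 - 1,  # 32767
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=MinMaxObserver,
)
```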

backends/cadence/aot/ref_implementations.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -261,7 +261,7 @@ def quantized_linear_common(
     src = src.view(-1, K)
 
     dtype = src.dtype
-    supported_dtypes = [torch.int8, torch.uint8, torch.int32]
+    supported_dtypes = [torch.int8, torch.uint8, torch.int16, torch.int32]
     if dtype not in supported_dtypes:
         raise ValueError(
             f"Unsupported dtype to quantize to {dtype}. Supported dtypes must be one of {supported_dtypes}"
```

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 11 additions & 3 deletions
```diff
@@ -183,6 +183,8 @@ def test_quantized_add(
                 (False, torch.int8),
                 (True, torch.int8),
                 (True, torch.uint8),
+                (True, torch.int16),
+                (False, torch.int16),
             )
         ],
         # Test case 2: 1x3 input, 2x3 weight (2 output features)
@@ -207,6 +209,8 @@ def test_quantized_add(
             for (per_tensor, dtype) in (
                 (False, torch.int8),
                 (True, torch.int8),
+                (False, torch.int16),
+                (True, torch.int16),
             )
         ],
         *[
@@ -256,6 +260,8 @@ def test_quantized_add(
             for (per_tensor, dtype) in (
                 (False, torch.int8),
                 (True, torch.int8),
+                (False, torch.int16),
+                (True, torch.int16),
             )
         ],
         # Test case 4: Non-zero zero points
@@ -280,6 +286,8 @@ def test_quantized_add(
             for (per_tensor, dtype) in (
                 (False, torch.int8),
                 (True, torch.int8),
+                (False, torch.int16),
+                (True, torch.int16),
                 # (True, torch.uint8),
             )
         ],
@@ -302,7 +310,7 @@ def test_quantized_add(
                 False,
                 False,
             )
-            for dtype in (torch.int8,)
+            for dtype in (torch.int8, torch.int16,)
         ],
         # Test case 6: Non-zero out_shift (shift=1)
         *[
@@ -325,7 +333,7 @@ def test_quantized_add(
                 False,
                 False,
             )
-            for (per_tensor, dtype) in ((False, torch.int8), (True, torch.int8))
+            for (per_tensor, dtype) in ((False, torch.int8), (True, torch.int8), (False, torch.int16), (True, torch.int16))
         ],
         *[
             (
@@ -348,7 +356,7 @@ def test_quantized_add(
                 transposed_matmul,
             )
             for (matmul, transposed_matmul) in ((True, False), (True, True))
-            for (per_tensor, dtype) in ((True, torch.int8),)
+            for (per_tensor, dtype) in ((True, torch.int8), (True, torch.int16))
        ],
        *[
            (
```
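In the same spirit as these parametrized cases, a minimal standalone check for int16 activations quantizes float inputs, runs the integer math, dequantizes, and compares against `torch.nn.functional.linear`. This is a self-contained sketch of the validation idea, not the actual test harness:

```python
# Hypothetical round-trip check (not the real test): int16 activations
# through integer linear math should closely track float linear.
import torch

torch.manual_seed(0)
x = torch.randn(4, 6)
w = torch.randn(3, 6)
b = torch.randn(3)

# Per-tensor affine quantization: activations -> int16, weights -> int8.
x_scale = x.abs().max().item() / 32767.0
w_scale = w.abs().max().item() / 127.0
xq = torch.round(x / x_scale).clamp(-32768, 32767).to(torch.int16)
wq = torch.round(w / w_scale).clamp(-128, 127).to(torch.int8)
bq = torch.round(b / (x_scale * w_scale)).to(torch.int32)

# Integer matmul with an int32 accumulator, then dequantize.
acc = torch.mm(xq.to(torch.int32), wq.to(torch.int32).t()) + bq
y_deq = acc.to(torch.float32) * (x_scale * w_scale)

# int16 activations contribute little quantization error here; the int8
# weights dominate the tolerance.
assert torch.allclose(y_deq, torch.nn.functional.linear(x, w, b), atol=0.2)
```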
