
Commit 926312e

Enable 16-bit activations in Cadence Quantizer for fully_connected and linear

Differential Revision: D84284794
Pull Request resolved: #15010
Parent: c9339e2

File tree: 3 files changed (+33, −4 lines)

backends/cadence/aot/quantizer/quantizer.py
Lines changed: 13 additions & 0 deletions

@@ -342,3 +342,16 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
             quantizers = get_cadence_default_quantizers()
         quantizers.append(CadenceAtenQuantizer(SoftmaxPattern(), qconfig_A16))
         super().__init__(quantizers)
+
+
+class CadenceWith16BitLinearActivationsQuantizer(CadenceQuantizer):
+    """
+    Quantizer including A16 fully_connected
+    """
+
+    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+        if quantizers is None:
+            quantizers = []
+            # Add 16-bit quantizers for LinearPattern
+            quantizers.append(CadenceAtenQuantizer(LinearPattern(), qconfig_A16))
+        super().__init__(quantizers)
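For context, here is a minimal sketch of how the new quantizer class might be used in a PT2E quantization flow. The example model, the import paths, and the prepare_pt2e/convert_pt2e calls are assumptions for illustration and are not part of this commit.

import torch
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceWith16BitLinearActivationsQuantizer,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class TinyLinear(torch.nn.Module):
    # Hypothetical example model with a single linear layer.
    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


model = TinyLinear().eval()
example_inputs = (torch.randn(1, 16),)

# Export to an ATen graph, annotate linear ops with the quantizer,
# run a calibration pass, then convert to a quantized graph.
exported = torch.export.export(model, example_inputs).module()
quantizer = CadenceWith16BitLinearActivationsQuantizer()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)
quantized = convert_pt2e(prepared)

With this quantizer, the linear activations should be annotated with the 16-bit qconfig_A16 configuration rather than the default 8-bit one.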

backends/cadence/aot/ref_implementations.py
Lines changed: 1 addition & 1 deletion

@@ -261,7 +261,7 @@ def quantized_linear_common(
     src = src.view(-1, K)

     dtype = src.dtype
-    supported_dtypes = [torch.int8, torch.uint8, torch.int32]
+    supported_dtypes = [torch.int8, torch.uint8, torch.int16, torch.int32]
     if dtype not in supported_dtypes:
         raise ValueError(
             f"Unsupported dtype to quantize to {dtype}. Supported dtypes must be one of {supported_dtypes}"

backends/cadence/aot/tests/test_ref_implementations.py
Lines changed: 19 additions & 3 deletions

@@ -183,6 +183,8 @@ def test_quantized_add(
                 (False, torch.int8),
                 (True, torch.int8),
                 (True, torch.uint8),
+                (True, torch.int16),
+                (False, torch.int16),
             )
         ],
         # Test case 2: 1x3 input, 2x3 weight (2 output features)
@@ -207,6 +209,8 @@ def test_quantized_add(
             for (per_tensor, dtype) in (
                 (False, torch.int8),
                 (True, torch.int8),
+                (False, torch.int16),
+                (True, torch.int16),
             )
         ],
         *[
@@ -256,6 +260,8 @@ def test_quantized_add(
             for (per_tensor, dtype) in (
                 (False, torch.int8),
                 (True, torch.int8),
+                (False, torch.int16),
+                (True, torch.int16),
             )
         ],
         # Test case 4: Non-zero zero points
@@ -280,6 +286,8 @@ def test_quantized_add(
             for (per_tensor, dtype) in (
                 (False, torch.int8),
                 (True, torch.int8),
+                (False, torch.int16),
+                (True, torch.int16),
                 # (True, torch.uint8),
             )
         ],
@@ -302,7 +310,10 @@ def test_quantized_add(
                 False,
                 False,
             )
-            for dtype in (torch.int8,)
+            for dtype in (
+                torch.int8,
+                torch.int16,
+            )
         ],
         # Test case 6: Non-zero out_shift (shift=1)
         *[
@@ -325,7 +336,12 @@ def test_quantized_add(
                 False,
                 False,
             )
-            for (per_tensor, dtype) in ((False, torch.int8), (True, torch.int8))
+            for (per_tensor, dtype) in (
+                (False, torch.int8),
+                (True, torch.int8),
+                (False, torch.int16),
+                (True, torch.int16),
+            )
         ],
         *[
             (
@@ -348,7 +364,7 @@ def test_quantized_add(
                 transposed_matmul,
             )
             for (matmul, transposed_matmul) in ((True, False), (True, True))
-            for (per_tensor, dtype) in ((True, torch.int8),)
+            for (per_tensor, dtype) in ((True, torch.int8), (True, torch.int16))
         ],
         *[
             (
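For intuition on what the new int16 test cases exercise, here is a small self-contained sketch comparing 8-bit and 16-bit affine fake-quantization of an activation tensor. The helper below is illustrative and not taken from the test file.

import torch


def fake_quantize(x: torch.Tensor, scale: float, zero_point: int, dtype: torch.dtype) -> torch.Tensor:
    # Quantize to the integer dtype's range, then dequantize back to float.
    info = torch.iinfo(dtype)
    q = torch.clamp(torch.round(x / scale) + zero_point, info.min, info.max).to(dtype)
    return (q.to(torch.float32) - zero_point) * scale


x = torch.randn(4, 16)
err_int8 = (x - fake_quantize(x, x.abs().max().item() / 127, 0, torch.int8)).abs().max()
err_int16 = (x - fake_quantize(x, x.abs().max().item() / 32767, 0, torch.int16)).abs().max()
print(err_int8, err_int16)  # 16-bit activations retain far more precision than 8-bit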
