
Commit 4632410

Add 16A8W quantization configuration utility for ARM backend
This diff implements a 16A8W (16-bit activations, 8-bit weights) quantization configuration utility for the ExecuTorch ARM backend, following the feedback from D79746479.

## Key Changes

**1. New Quantization Configuration Function**
- Add `get_16a8w_quantization_config()` in `fbcode/executorch/backends/arm/quantizer/arm_quantizer.py`
- Provides 16-bit activations with HistogramObserver (better precision than 8A8W)
- Maintains 8-bit weights with MinMaxObserver/PerChannelMinMaxObserver (memory efficient)
- **Technically supported by TOSA through the [EXT-INT16 extension/profile](https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d)**

## Benefits
- **Better Precision**: 16-bit activations provide higher precision than 8-bit, which is useful for carrying precision through recurrent neural networks.

ghstack-source-id: 305991462
@exported-using-ghexport
@bypass-github-export-checks
@bypass-github-pytorch-ci-checks
@bypass-github-executorch-ci-checks

Differential Revision: [D81550512](https://our.internmc.facebook.com/intern/diff/D81550512/)

[ghstack-poisoned]
1 parent 02da205 commit 4632410
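
To make the "Better Precision" point concrete, below is a small illustrative sketch (not part of the commit). It assumes the standard per-tensor symmetric scale formula, scale = max|x| / quant_max, and a ±127 range for the 8-bit baseline; the 16-bit range matches the activation spec added in this diff.

```python
# Illustration: precision of 16-bit vs 8-bit activations under per-tensor
# symmetric quantization (scale = max|x| / quant_max). Assumes a +/-127 range
# for the 8-bit baseline; the 16-bit range matches the spec in this diff.
import torch

x = torch.randn(1024) * 3.0        # example activation tensor
max_abs = x.abs().max().item()

scale_int8 = max_abs / 127         # 8A8W activation step size
scale_int16 = max_abs / 32767      # 16A8W activation step size (quant_max = 32767)


def roundtrip(t, scale, qmin, qmax):
    """Quantize and immediately dequantize to measure the rounding error."""
    q = torch.clamp(torch.round(t / scale), qmin, qmax)
    return q * scale


err8 = (x - roundtrip(x, scale_int8, -127, 127)).abs().max().item()
err16 = (x - roundtrip(x, scale_int16, -32768, 32767)).abs().max().item()
print(f"step size: int8={scale_int8:.6f}  int16={scale_int16:.8f}")   # ~258x finer
print(f"max roundtrip error: int8={err8:.6f}  int16={err16:.8f}")
```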

File tree

1 file changed: +80 -0 lines changed


backends/arm/quantizer/arm_quantizer.py

Lines changed: 80 additions & 0 deletions
```diff
@@ -225,6 +225,86 @@ def get_symmetric_a16w8_quantization_config(
     return quantization_config
 
 
+@functools.lru_cache
+def get_symmetric_a16w8_quantization_config(
+    is_per_channel: bool = True,
+    is_qat: bool = False,
+    is_dynamic: bool = False,
+    weight_qmin: int = -127,
+    weight_qmax: int = 127,
+):
+    """
+    16A8W quantization config: 16-bit activations, 8-bit weights.
+
+    This configuration provides better accuracy than 8A8W while maintaining
+    reasonable memory usage through 8-bit weights.
+
+    Args:
+        is_per_channel: Whether to use per-channel quantization for weights
+        is_qat: Whether this is for Quantization Aware Training
+        is_dynamic: Whether to use dynamic quantization
+        weight_qmin: Minimum quantization value for weights
+        weight_qmax: Maximum quantization value for weights
+
+    Returns:
+        QuantizationConfig with 16-bit activations and 8-bit weights
+    """
+    extra_args: Dict[str, Any] = {"eps": 2**-12}
+
+    # Setup observer/fake-quant for 16-bit activations
+    if is_qat:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = FakeQuantize
+            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
+                averaging_constant=1
+            )
+            extra_args["observer"] = dynamic_quant_observer
+        else:
+            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
+    else:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
+        else:
+            # HistogramObserver works well for 16-bit range
+            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
+
+    # 16-bit activation quantization spec
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.int16,
+        quant_min=torch.iinfo(torch.int16).min,  # -32768
+        quant_max=torch.iinfo(torch.int16).max,  # 32767
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=is_dynamic,
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            **extra_args,
+        ),
+    )
+
+    # Instead of reconstructing quantization_config, just clone and update as needed:
+    # clone the config from get_symmetric_quantization_config and update the activation spec
+    base_config = get_symmetric_quantization_config(
+        is_per_channel=is_per_channel,
+        is_qat=is_qat,
+        is_dynamic=is_dynamic,
+    )
+    # Replace activation quantization spec with 16-bit version
+    if is_dynamic:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,  # 16-bit input activations
+            None,
+            base_config.weight,  # 8-bit weights from base config
+            None,
+        )
+    else:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,  # 16-bit input activations
+            act_quantization_spec,  # 16-bit output activations
+            base_config.weight,  # 8-bit weights from base config
+            None,
+        )
+    return quantization_config
+
+
 NodeFilterType = Callable[[Node], bool]
 """Type for a Node Filter used by annotators. A Node filter is a function that takes
 a Node and returns whether the node should be annotated or not.
```
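
For orientation, here is a minimal, hypothetical usage sketch of the new function (not part of this diff). The import path follows the file touched above; the quantizer plumbing at the end (a TOSAQuantizer-style class with `set_global()`, plus the PT2E prepare/convert calls) is an assumption modeled on the existing 8A8W flow rather than something this commit adds.

```python
# Hypothetical usage sketch (not part of this commit).
from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_a16w8_quantization_config,
)

# 16A8W config: 16-bit activations, 8-bit (per-channel) weights.
config = get_symmetric_a16w8_quantization_config(is_per_channel=True)

# The returned QuantizationConfig holds the 16-bit activation spec plus the
# 8-bit weight spec cloned from the symmetric 8A8W base config.
print(config.weight)

# Assumed plumbing (class and method names hypothetical for this sketch):
# quantizer = TOSAQuantizer(compile_spec)
# quantizer.set_global(config)
# prepared = prepare_pt2e(exported_program.module(), quantizer)
# converted = convert_pt2e(prepared)
```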
