
Commit 24c9ce4

Ninja91 authored and facebook-github-bot committed
Add 16A8W quantization configuration utility for ARM backend (#13175)
Summary:
Pull Request resolved: #13175

This diff implements a 16A8W (16-bit activations, 8-bit weights) quantization configuration utility for the ExecuTorch ARM backend, following the feedback from D79746479.

## Key Changes

**1. New Quantization Configuration Function**

- Add `get_symmetric_a16w8_quantization_config()` in `fbcode/executorch/backends/arm/quantizer/arm_quantizer.py`
- Provides 16-bit activations with HistogramObserver (better precision than 8A8W)
- Maintains 8-bit weights with MinMaxObserver/PerChannelMinMaxObserver (memory efficient)
- **Technically supported by TOSA through the [EXT-INT16 extension/profile](https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d)**

## Benefits

- **Better Precision**: 16-bit activations provide higher precision than 8-bit, which is useful for carrying precision through recurrent neural nets (a quick scale comparison is sketched below the commit metadata).

Reviewed By: 3l1

Differential Revision: D79763381
1 parent 07b0883 commit 24c9ce4
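As a rough illustration of the precision claim (not part of the commit; the 6.0 activation range is an arbitrary assumption), symmetric quantization maps the observed absolute maximum to the integer maximum, so widening activations from int8 to int16 makes the quantization step roughly 258x finer:

```python
import torch

# Hypothetical calibrated activation range; any representative value works.
max_abs_activation = 6.0

# Symmetric quantization: scale = max|x| / qmax, zero-point fixed at 0.
scale_int8 = max_abs_activation / torch.iinfo(torch.int8).max    # 6.0 / 127   ~= 0.0472
scale_int16 = max_abs_activation / torch.iinfo(torch.int16).max  # 6.0 / 32767 ~= 0.000183

print(f"int8 step:  {scale_int8:.6f}")
print(f"int16 step: {scale_int16:.6f}")
print(f"resolution gain: {scale_int8 / scale_int16:.0f}x")  # ~258x finer steps
```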

File tree

1 file changed: +105 −0 lines changed


backends/arm/quantizer/arm_quantizer.py

Lines changed: 105 additions & 0 deletions
@@ -145,6 +145,111 @@ def get_symmetric_quantization_config(
     return quantization_config
 
 
+@functools.lru_cache
+def get_symmetric_a16w8_quantization_config(
+    is_per_channel: bool = True,
+    is_qat: bool = False,
+    is_dynamic: bool = False,
+    weight_qmin: int = -127,
+    weight_qmax: int = 127,
+):
+    """
+    16A8W quantization config: 16-bit activations, 8-bit weights.
+
+    This configuration provides better accuracy than 8A8W while maintaining
+    reasonable memory usage through 8-bit weights.
+
+    Args:
+        is_per_channel: Whether to use per-channel quantization for weights
+        is_qat: Whether this is for Quantization Aware Training
+        is_dynamic: Whether to use dynamic quantization
+        weight_qmin: Minimum quantization value for weights
+        weight_qmax: Maximum quantization value for weights
+
+    Returns:
+        QuantizationConfig with 16-bit activations and 8-bit weights
+    """
+    extra_args: Dict[str, Any] = {"eps": 2**-12}
+
+    # Setup observer/fake-quant for 16-bit activations
+    if is_qat:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = FakeQuantize
+            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
+                averaging_constant=1
+            )
+            extra_args["observer"] = dynamic_quant_observer
+        else:
+            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
+    else:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
+        else:
+            # HistogramObserver works well for 16-bit range
+            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
+
+    # 16-bit activation quantization spec
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.int16,
+        quant_min=torch.iinfo(torch.int16).min,  # -32768
+        quant_max=torch.iinfo(torch.int16).max,  # 32767
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=is_dynamic,
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            **extra_args,
+        ),
+    )
+
+    # Setup quantization config for weights (same as 8A8W - use 8-bit weights)
+    weight_qscheme = (
+        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
+    )
+    weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = (
+        MinMaxObserver
+    )
+    # Determine the right observer/fake-quant constructor
+    if is_qat:
+        # Set plain fake-quant with true min/max
+        weight_observer_or_fake_quant_ctr = FakeQuantize
+    else:
+        # PTQ: set min/max observer
+        weight_observer_or_fake_quant_ctr = (
+            PerChannelMinMaxObserver if is_per_channel else MinMaxObserver
+        )
+
+    weight_extra_args = {"eps": 2**-12}
+
+    # 8-bit weight quantization spec (keep weights at 8-bit for memory efficiency)
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=weight_qmin,
+        quant_max=weight_qmax,
+        qscheme=weight_qscheme,
+        ch_axis=0,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
+            **weight_extra_args
+        ),
+    )
+
+    bias_quantization_spec = None
+    if is_dynamic:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,  # 16-bit input activations
+            None,
+            weight_quantization_spec,  # 8-bit weights
+            bias_quantization_spec,
+        )
+    else:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,  # 16-bit input activations
+            act_quantization_spec,  # 16-bit output activations
+            weight_quantization_spec,  # 8-bit weights
+            bias_quantization_spec,
+        )
+    return quantization_config
+
+
 NodeFilterType = Callable[[Node], bool]
 """Type for a Node Filter used by annotators. A Node filter is a function that takes
 a Node and returns whether the node should be annotated or not.
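For orientation, a minimal usage sketch follows. It is not part of the commit: `TOSAQuantizer`, the `compile_spec` it is constructed with, and the `prepare_pt2e`/`convert_pt2e` flow are assumptions about the surrounding ExecuTorch/PyTorch APIs rather than something this diff adds; only `get_symmetric_a16w8_quantization_config` comes from the change above.

```python
# Usage sketch only. Assumptions (not shown in this diff): TOSAQuantizer is the
# quantizer class exported by arm_quantizer.py, its constructor takes an ARM/TOSA
# compile spec, and the standard torch.ao PT2E prepare/convert flow is used.
import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    TOSAQuantizer,
    get_symmetric_a16w8_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


model = TinyModel().eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

quantizer = TOSAQuantizer(compile_spec)  # compile_spec: assumed ARM/TOSA compile spec, built elsewhere
quantizer.set_global(get_symmetric_a16w8_quantization_config(is_per_channel=True))

exported = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)  # insert 16-bit act / 8-bit weight observers
prepared(*example_inputs)                     # calibrate on representative data
quantized = convert_pt2e(prepared)            # produce the 16A8W-annotated graph
```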
