
Commit 4d1da11

pytorchbot and Ninja91 authored
Add 16A8W quantization configuration utility for ARM backend (#13728)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #13641 by @Ninja91
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/Ninja91/1/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/Ninja91/1/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/Ninja91/1/orig
@diff-train-skip-merge

Co-authored-by: Nitin Jain <[email protected]>
1 parent 185c96a commit 4d1da11

File tree

1 file changed: +80 −0 lines changed


backends/arm/quantizer/arm_quantizer.py

Lines changed: 80 additions & 0 deletions
@@ -145,6 +145,86 @@ def get_symmetric_quantization_config(
     return quantization_config
 
 
+@functools.lru_cache
+def get_symmetric_a16w8_quantization_config(
+    is_per_channel: bool = True,
+    is_qat: bool = False,
+    is_dynamic: bool = False,
+    weight_qmin: int = -127,
+    weight_qmax: int = 127,
+):
+    """
+    16A8W quantization config: 16-bit activations, 8-bit weights.
+
+    This configuration provides better accuracy than 8A8W while maintaining
+    reasonable memory usage through 8-bit weights.
+
+    Args:
+        is_per_channel: Whether to use per-channel quantization for weights
+        is_qat: Whether this is for Quantization Aware Training
+        is_dynamic: Whether to use dynamic quantization
+        weight_qmin: Minimum quantization value for weights
+        weight_qmax: Maximum quantization value for weights
+
+    Returns:
+        QuantizationConfig with 16-bit activations and 8-bit weights
+    """
+    extra_args: Dict[str, Any] = {"eps": 2**-12}
+
+    # Setup observer/fake-quant for 16-bit activations
+    if is_qat:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = FakeQuantize
+            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
+                averaging_constant=1
+            )
+            extra_args["observer"] = dynamic_quant_observer
+        else:
+            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
+    else:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
+        else:
+            # HistogramObserver works well for 16-bit range
+            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
+
+    # 16-bit activation quantization spec
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.int16,
+        quant_min=torch.iinfo(torch.int16).min,  # -32768
+        quant_max=torch.iinfo(torch.int16).max,  # 32767
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=is_dynamic,
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            **extra_args,
+        ),
+    )
+
+    # Instead of reconstructing quantization_config, just clone and update as needed:
+    # take the config from get_symmetric_quantization_config and update the activation spec.
+    base_config = get_symmetric_quantization_config(
+        is_per_channel=is_per_channel,
+        is_qat=is_qat,
+        is_dynamic=is_dynamic,
+    )
+    # Replace activation quantization spec with 16-bit version
+    if is_dynamic:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,  # 16-bit input activations
+            None,
+            base_config.weight,  # 8-bit weights from base config
+            None,
+        )
+    else:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,  # 16-bit input activations
+            act_quantization_spec,  # 16-bit output activations
+            base_config.weight,  # 8-bit weights from base config
+            None,
+        )
+    return quantization_config
+
+
 NodeFilterType = Callable[[Node], bool]
 """Type for a Node Filter used by annotators. A Node filter is a function that takes
 a Node and returns whether the node should be annotated or not.
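For context, below is a minimal usage sketch of the new utility. It is not part of this commit: the import path mirrors the file location in the diff, the `input_activation`/`weight` field names are assumed from the positional arguments passed to `QuantizationConfig` above, and the quantizer/`prepare_pt2e` step is only referenced in comments because the exact class to use depends on the existing ARM backend flow.

```python
# Hypothetical usage sketch (not part of this diff).
import torch

# Assumed import path, matching backends/arm/quantizer/arm_quantizer.py.
from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_a16w8_quantization_config,
)

# Build the 16A8W config: 16-bit activations, 8-bit (per-channel) weights.
config = get_symmetric_a16w8_quantization_config(is_per_channel=True)

# Assumed QuantizationConfig field names (input_activation, output_activation,
# weight, bias) based on the positional arguments used in the new function.
assert config.input_activation.dtype == torch.int16  # 16-bit activations
assert config.weight.dtype == torch.int8             # 8-bit weights from base config

# From here, the config would be installed on a quantizer (e.g. via set_global)
# and used with the usual prepare_pt2e / calibrate / convert_pt2e flow, as in
# the existing 8A8W examples for the ARM backend.
```

The design point the sketch illustrates is the one stated in the code comments: rather than rebuilding the whole configuration, the function reuses the 8-bit weight spec from get_symmetric_quantization_config and only swaps in a 16-bit activation spec.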
