
Commit acd368f

Update on "Arm backend: Add 16A8W support and test for add operation"
Add 16A8W quantization support and comprehensive tests for the add operation in ExecutorTorch ARM backend targeting Ethos U55 and U85 NPUs. This follows the pattern established for linear operations, extending int16 support to add operations with hardware-specific testing. Changes: - Add INT16 dtype validation support in op_add.py - Add test_add_tensor_16a8w_tosa_INT test function with U55/U85 pipeline support - Add U55 and U85 specific 16A8W tests with proper xfail decorators - Fix U55/U85 test parameter usage (remove unsupported tosa_extensions, clean quantizer function calls) - Update xfail reasons to consistent 'Vela compilation fails with Invalid arguments' pattern Differential Revision: [D80510463](https://our.internmc.facebook.com/intern/diff/D80510463) cc digantdesai freddan80 per zingo oscarandersson8218 [ghstack-poisoned]
2 parents 70eb26f + 04c6534 commit acd368f
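
For intuition on why the 16A8W scheme described above helps accuracy while keeping weight memory flat, the back-of-the-envelope comparison below (standalone Python, not part of this commit) contrasts the activation step size of 8-bit and 16-bit symmetric quantization over the same calibrated range; the range value is an arbitrary example.

```python
# Standalone illustration (not from this commit): activation resolution of
# int8 vs. int16 symmetric quantization over the same calibrated range.
import torch


def symmetric_step(max_abs: float, dtype: torch.dtype) -> float:
    """Quantization step size for full-range symmetric per-tensor quantization."""
    return max_abs / torch.iinfo(dtype).max


max_abs_activation = 6.0  # example max |activation| seen during calibration
step_int8 = symmetric_step(max_abs_activation, torch.int8)    # 6 / 127   ~ 4.7e-2
step_int16 = symmetric_step(max_abs_activation, torch.int16)  # 6 / 32767 ~ 1.8e-4

# int16 activations give ~258x finer resolution, while weights stay 8-bit, so
# the parameter footprint is unchanged relative to 8A8W.
print(f"{step_int8:.6f} vs {step_int16:.6f} ({step_int8 / step_int16:.0f}x finer)")
```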

File tree

2 files changed: +83 -3 lines


backends/arm/operators/op_add.py

Lines changed: 3 additions & 3 deletions
@@ -65,15 +65,15 @@ def define_node(
                 tosa_graph, inputs, node, self.tosa_spec
             )
         else:
-            # input[0].dtype == ts.DType.INT32
+            # input[0].dtype == ts.DType.INT16 or ts.DType.INT32
             # Non quantized input, natively support by TOSA.ADD
             rescaled_inputs = inputs

         if output.dtype == ts.DType.INT8:
             broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
             add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
         else:
-            # output.dtype == ts.DType.INT32
+            # output.dtype == ts.DType.INT16 or ts.DType.INT32
             add_output = output

         input1, input2 = rescaled_inputs
@@ -123,7 +123,7 @@ def define_node(
         validate_num_inputs(self.target, inputs, 2)
         validate_same_dtype(self.target, [*inputs, output], ts)

-        if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
+        if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]:
            # Call the inherited define_node for handling integers
            super().define_node(node, tosa_graph, inputs, output)
         else:
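
To make the routing above easier to read at a glance, here is a standalone restatement of the two dtype checks, with plain enums standing in for ts.DType and no TOSA serializer involved; it mirrors the diff's logic but is not the backend's actual code.

```python
# Illustration only: the dtype routing that the diff above extends, restated
# without the TOSA serializer. DType stands in for ts.DType.
from enum import Enum, auto


class DType(Enum):
    INT8 = auto()
    INT16 = auto()
    INT32 = auto()
    FP32 = auto()


def add_lowering_plan(input_dtype: DType, output_dtype: DType) -> dict:
    """Summarize how an ADD with these dtypes is lowered after this change."""
    if input_dtype not in (DType.INT8, DType.INT16, DType.INT32):
        return {"path": "float ADD"}
    return {
        "path": "integer ADD",
        # Only INT8 inputs are rescaled to INT32; INT16/INT32 feed ADD natively.
        "rescale_inputs_to_int32": input_dtype == DType.INT8,
        # Only INT8 outputs go through an INT32 intermediate; INT16/INT32
        # outputs are written by ADD directly.
        "int32_intermediate_output": output_dtype == DType.INT8,
    }


# New with this commit: INT16 is accepted on the integer path.
print(add_lowering_plan(DType.INT16, DType.INT16))
```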

backends/arm/quantizer/arm_quantizer.py

Lines changed: 80 additions & 0 deletions
@@ -225,6 +225,86 @@ def get_symmetric_a16w8_quantization_config(
     return quantization_config


+@functools.lru_cache
+def get_symmetric_a16w8_quantization_config(
+    is_per_channel: bool = True,
+    is_qat: bool = False,
+    is_dynamic: bool = False,
+    weight_qmin: int = -127,
+    weight_qmax: int = 127,
+):
+    """
+    16A8W quantization config: 16-bit activations, 8-bit weights.
+
+    This configuration provides better accuracy than 8A8W while maintaining
+    reasonable memory usage through 8-bit weights.
+
+    Args:
+        is_per_channel: Whether to use per-channel quantization for weights
+        is_qat: Whether this is for Quantization Aware Training
+        is_dynamic: Whether to use dynamic quantization
+        weight_qmin: Minimum quantization value for weights
+        weight_qmax: Maximum quantization value for weights
+
+    Returns:
+        QuantizationConfig with 16-bit activations and 8-bit weights
+    """
+    extra_args: Dict[str, Any] = {"eps": 2**-12}
+
+    # Setup observer/fake-quant for 16-bit activations
+    if is_qat:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = FakeQuantize
+            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
+                averaging_constant=1
+            )
+            extra_args["observer"] = dynamic_quant_observer
+        else:
+            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
+    else:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
+        else:
+            # HistogramObserver works well for 16-bit range
+            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
+
+    # 16-bit activation quantization spec
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.int16,
+        quant_min=torch.iinfo(torch.int16).min,  # -32768
+        quant_max=torch.iinfo(torch.int16).max,  # 32767
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=is_dynamic,
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            **extra_args,
+        ),
+    )
+
+    # Instead of reconstructing quantization_config, just clone and update as needed
+    # Clone the quantization_config from get_symmetric_quantization_config and update activation spec
+    base_config = get_symmetric_quantization_config(
+        is_per_channel=is_per_channel,
+        is_qat=is_qat,
+        is_dynamic=is_dynamic,
+    )
+    # Replace activation quantization spec with 16-bit version
+    if is_dynamic:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,  # 16-bit input activations
+            None,
+            base_config.weight,  # 8-bit weights from base config
+            None,
+        )
+    else:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,  # 16-bit input activations
+            act_quantization_spec,  # 16-bit output activations
+            base_config.weight,  # 8-bit weights from base config
+            None,
+        )
+    return quantization_config
+
+
 NodeFilterType = Callable[[Node], bool]
 """Type for a Node Filter used by annotators. A Node filter is a function that takes
 a Node and returns whether the node should be annotated or not.
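
For orientation, here is a minimal sketch of how the new config could be wired into the standard PT2E prepare/convert flow. The quantizer class (`TOSAQuantizer`), its constructor argument, the `TosaSpecification` helper, and `set_global` are assumed from the surrounding Arm quantizer code rather than confirmed by this diff, so treat the snippet as illustrative; the tests added in this stack are the authoritative usage.

```python
# Hedged usage sketch. Assumptions (not verified against this commit):
# TOSAQuantizer, TosaSpecification, and set_global exist as named below;
# only get_symmetric_a16w8_quantization_config is confirmed by this diff.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.arm.quantizer.arm_quantizer import (
    TOSAQuantizer,
    get_symmetric_a16w8_quantization_config,
)
from executorch.backends.arm.tosa_specification import TosaSpecification


class Add(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


example_inputs = (torch.randn(1, 8), torch.randn(1, 8))
# Exact export entry point may vary by PyTorch version.
graph_module = torch.export.export(Add(), example_inputs).module()

quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
quantizer.set_global(get_symmetric_a16w8_quantization_config(is_per_channel=True))

prepared = prepare_pt2e(graph_module, quantizer)
prepared(*example_inputs)           # calibrate with representative inputs
quantized = convert_pt2e(prepared)  # int16 activation / int8 weight q-dq pairs
```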
