Commit 95856ad

Update
[ghstack-poisoned]
2 parents: c285ecf + edfd321

35 files changed: 1041 additions & 689 deletions

backends/apple/coreml/test/__init__.py

Lines changed: 0 additions & 17 deletions
This file was deleted.

backends/arm/operators/op_transpose.py

Lines changed: 8 additions & 8 deletions
@@ -44,17 +44,17 @@ def define_node(
 
         validate_num_inputs(self.target, inputs, 2)
         validate_same_dtype(self.target, [inputs[0], output], ts)
+
+        valid_dtypes = [ts.DType.BOOL]
+        if self.tosa_spec.support_integer():
+            valid_dtypes.extend([ts.DType.INT8, ts.DType.INT16])
+        if self.tosa_spec.support_float():
+            valid_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
+
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            [
-                ts.DType.INT8,
-                ts.DType.INT16,
-                ts.DType.INT32,
-                ts.DType.FP32,
-                ts.DType.BOOL,
-                ts.DType.FP16,
-            ],
+            valid_dtypes,
             output.tosa_spec,
         )

backends/arm/operators/op_view.py

Lines changed: 7 additions & 1 deletion
@@ -41,10 +41,16 @@ def define_node(
 
         validate_num_inputs(self.target, inputs, 2)
         validate_same_dtype(self.target, [inputs[0], output], ts)
+        valid_dtypes = [ts.DType.BOOL]
+        if self.tosa_spec.support_integer():
+            valid_dtypes.extend([ts.DType.INT8, ts.DType.INT16])
+        if self.tosa_spec.support_float():
+            valid_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
+
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
+            valid_dtypes,
             output.tosa_spec,
         )
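
For context, both op_transpose.py and op_view.py above now build the dtype allow-list from the capabilities of the active TOSA specification instead of hard-coding it. The sketch below is a hypothetical, self-contained illustration of that pattern; Spec and valid_dtypes_for are stand-ins invented here, and only the support_integer()/support_float() queries are taken from the diffs.

# Hypothetical stand-in for the TOSA specification object used above.
from dataclasses import dataclass

@dataclass
class Spec:
    has_integer: bool
    has_float: bool

    def support_integer(self) -> bool:
        return self.has_integer

    def support_float(self) -> bool:
        return self.has_float

def valid_dtypes_for(spec: Spec) -> list:
    dtypes = ["BOOL"]                 # always allowed
    if spec.support_integer():
        dtypes += ["INT8", "INT16"]   # integer-profile dtypes
    if spec.support_float():
        dtypes += ["FP16", "FP32"]    # float-profile dtypes
    return dtypes

print(valid_dtypes_for(Spec(has_integer=True, has_float=False)))   # ['BOOL', 'INT8', 'INT16']
print(valid_dtypes_for(Spec(has_integer=False, has_float=True)))   # ['BOOL', 'FP16', 'FP32']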

backends/arm/quantizer/arm_quantizer.py

Lines changed: 80 additions & 0 deletions
@@ -145,6 +145,86 @@ def get_symmetric_quantization_config(
     return quantization_config
 
 
+@functools.lru_cache
+def get_symmetric_a16w8_quantization_config(
+    is_per_channel: bool = True,
+    is_qat: bool = False,
+    is_dynamic: bool = False,
+    weight_qmin: int = -127,
+    weight_qmax: int = 127,
+):
+    """
+    16A8W quantization config: 16-bit activations, 8-bit weights.
+
+    This configuration provides better accuracy than 8A8W while maintaining
+    reasonable memory usage through 8-bit weights.
+
+    Args:
+        is_per_channel: Whether to use per-channel quantization for weights
+        is_qat: Whether this is for Quantization Aware Training
+        is_dynamic: Whether to use dynamic quantization
+        weight_qmin: Minimum quantization value for weights
+        weight_qmax: Maximum quantization value for weights
+
+    Returns:
+        QuantizationConfig with 16-bit activations and 8-bit weights
+    """
+    extra_args: Dict[str, Any] = {"eps": 2**-12}
+
+    # Setup observer/fake-quant for 16-bit activations
+    if is_qat:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = FakeQuantize
+            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
+                averaging_constant=1
+            )
+            extra_args["observer"] = dynamic_quant_observer
+        else:
+            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
+    else:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
+        else:
+            # HistogramObserver works well for 16-bit range
+            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
+
+    # 16-bit activation quantization spec
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.int16,
+        quant_min=torch.iinfo(torch.int16).min,  # -32768
+        quant_max=torch.iinfo(torch.int16).max,  # 32767
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=is_dynamic,
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            **extra_args,
+        ),
+    )
+
+    # Instead of reconstructing quantization_config, just clone and update as needed
+    # Clone the quantization_config from get_symmetric_quantization_config and update activation spec
+    base_config = get_symmetric_quantization_config(
+        is_per_channel=is_per_channel,
+        is_qat=is_qat,
+        is_dynamic=is_dynamic,
+    )
+    # Replace activation quantization spec with 16-bit version
+    if is_dynamic:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,  # 16-bit input activations
+            None,
+            base_config.weight,  # 8-bit weights from base config
+            None,
+        )
+    else:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,  # 16-bit input activations
+            act_quantization_spec,  # 16-bit output activations
+            base_config.weight,  # 8-bit weights from base config
+            None,
+        )
+    return quantization_config
+
+
 NodeFilterType = Callable[[Node], bool]
 """Type for a Node Filter used by annotators. A Node filter is a function that takes
 a Node and returns whether the node should be annotated or not.
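
A minimal usage sketch for the new config, mirroring how test_linear.py further down in this commit wires it into a quantizer; the import paths and the "TOSA-1.0+INT+int16" spec string are taken from that test, and running it assumes an ExecuTorch checkout with the Arm backend installed.

from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_a16w8_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.tosa_specification import TosaSpecification

# TOSA INT profile with the int16 extension, as used by the new linear test
spec = TosaSpecification.create_from_string("TOSA-1.0+INT+int16")
quantizer = TOSAQuantizer(spec)
# 16-bit activations, 8-bit per-channel weights for every annotated op
quantizer.set_global(get_symmetric_a16w8_quantization_config(is_per_channel=True))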

backends/arm/test/ops/test_linear.py

Lines changed: 68 additions & 3 deletions
@@ -9,9 +9,12 @@
 from typing import Tuple
 
 import pytest
-
 import torch
-from executorch.backends.arm.test import common
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test import common, conftest
 
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,

@@ -20,6 +23,8 @@
     TosaPipelineINT,
     VgfPipeline,
 )
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize
 
 aten_op = "torch.ops.aten.linear.default"
 

@@ -143,7 +148,6 @@ def test_linear_tosa_FP(test_data: torch.Tensor):
     pipeline.run()
 
 
-@pytest.mark.flaky(reruns=5)  # TODO: Investigate flakyness.
 @common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
 def test_linear_tosa_INT(test_data: torch.Tensor):
     test_data, out_features, has_bias, per_channel_quantization = test_data()

@@ -243,3 +247,64 @@ def test_linear_vgf_INT(test_data: torch.Tensor):
         per_channel_quantization=per_channel_quantization,
     )
     pipeline.run()
+
+
+def get_symmetric_a16w8_linear_quantizer(
+    u55_config=False, per_channel_quantization=False
+):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+    )
+    quantizer.set_module_type(
+        torch.nn.Linear,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
+@pytest.mark.xfail(
+    reason="missing int16 linear ops support; fails at TOSA reference model run with Invalid TOSA graph"
+)
+def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
+    """Test linear operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    test_data, out_features, has_bias, per_channel_quantization = test_data()
+    in_features = test_data.shape[-1]
+
+    # Create pipeline with custom 16A8W quantization config
+    pipeline = TosaPipelineINT[input_t1](
+        Linear(
+            in_features=in_features,
+            out_features=out_features,
+            bias=has_bias,
+        ),
+        (test_data,),
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_linear_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    # Run the pipeline
+    pipeline.run()

backends/cadence/aot/ops_registrations.py

Lines changed: 16 additions & 0 deletions
@@ -448,6 +448,12 @@
     "roi_align_box_processor(Tensor rois, int output_size_h, int output_size_w, "
     "int sampling_ratio, bool aligned) -> (Tensor out)"
 )
+lib.define(
+    "_softmax_f32_f32(Tensor self, int dim, bool? half_to_float) -> (Tensor out)"
+)
+lib.define(
+    "_softmax_f32_f32.out(Tensor self, int dim, bool? half_to_float, *, Tensor(a!) out) -> Tensor(a!)"
+)
 
 # Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
 aten_lib = Library("aten", "FRAGMENT")

@@ -2075,3 +2081,13 @@ def roi_align_box_processor_meta(
     aligned: bool,
 ) -> torch.Tensor:
     return rois.new_empty((rois.shape[0], 80), dtype=torch.uint8)
+
+
+@register_fake("cadence::_softmax_f32_f32")
+def softmax_f32_f32_meta(
+    self: torch.Tensor,
+    dim: int,
+    dtype: torch.dtype,
+    half_to_float: Optional[bool] = None,
+) -> torch.Tensor:
+    return self.new_empty(self.size(), dtype=self.dtype)
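
The registered fake kernel only encodes the op's shape/dtype contract: the output has the same shape and dtype as the input. The sketch below restates that contract using a plain float32 softmax as a hypothetical reference; it is not the Cadence kernel itself, whose numerical behavior is assumed to match a float32 softmax.

import torch

def softmax_f32_f32_reference(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Assumed reference behavior: float32 softmax over `dim`, honoring the
    # same-shape/same-dtype promise made by softmax_f32_f32_meta above.
    out = torch.softmax(x, dim=dim)
    assert out.shape == x.shape and out.dtype == x.dtype
    return out

x = torch.randn(2, 8, dtype=torch.float32)
print(softmax_f32_f32_reference(x, dim=-1).shape)  # torch.Size([2, 8])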

backends/cadence/aot/type_dispatch.py

Lines changed: 7 additions & 0 deletions
@@ -93,6 +93,13 @@ class CompileTimeTypeDispatchPass(ExportPass):
             },
             weight_arg_idx=3,
         ),
+        exir_ops.edge.aten._softmax.default: OpConfig(
+            "_softmax",
+            type_dispatch_suffixes={
+                (torch.float32,): "f32_f32",
+            },
+            variant="default",
+        ),
     }
 
     def call_operator(
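
With this entry, the pass can rewrite a float32 edge-dialect softmax to the typed Cadence variant registered above. The snippet below is a simplified, hypothetical illustration of that suffix lookup; the real logic lives in CompileTimeTypeDispatchPass.call_operator.

import torch

# Hypothetical mirror of the OpConfig table: op name plus observed input
# dtypes -> typed Cadence variant name.
TYPE_DISPATCH_SUFFIXES = {
    "aten._softmax.default": {(torch.float32,): "cadence::_softmax_f32_f32"},
}

def dispatch(op_name: str, input_dtypes: tuple) -> str:
    variants = TYPE_DISPATCH_SUFFIXES.get(op_name, {})
    if input_dtypes not in variants:
        raise RuntimeError(f"No typed variant of {op_name} for {input_dtypes}")
    return variants[input_dtypes]

print(dispatch("aten._softmax.default", (torch.float32,)))  # cadence::_softmax_f32_f32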
