Commit bf22ab7

Qualcomm AI Engine Direct - Strengthen Unit Test Robustness (#13892)
Summary:
- Fix the seed for E2E model scripts in the unit tests so they run deterministically (a minimal seeding sketch follows below).
- Fix a bug that occurred while dumping the optrace.
- Derive the per-channel quant config of the conv op's bias from the activation and weight quantization parameters.
- Resolve an issue introduced by [the PR](#13606) that prevented the spec from being correctly updated to the quantized type.
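Below is a minimal, illustrative sketch (not the repository's actual test helper) of what fixing the seed for an E2E model script typically looks like; the function name and default seed value are assumptions.

import random

import numpy as np
import torch


def set_deterministic_seed(seed: int = 0) -> None:
    # Seed every RNG the E2E test path may touch so model initialization
    # and calibration inputs stay identical between runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


set_deterministic_seed(0)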
1 parent 1a7441f commit bf22ab7

Large commits have some content hidden by default; only two of the changed files are shown below.
42 files changed (+261 / -294 lines)

backends/qualcomm/_passes/build_quant_io.py

Lines changed: 22 additions & 24 deletions
@@ -5,9 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 import torch
 from executorch.backends.qualcomm.utils.constants import QCOM_QUANTIZED_IO
+from executorch.exir.delegate import executorch_call_delegate

-from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.pass_base import ExportPass, ProxyValue
 from executorch.exir.tensor import TensorSpec
+from torch.utils import _pytree as pytree


 class BuildQuantIo(ExportPass):
@@ -26,26 +28,22 @@ def _make_spec(self, x):
         else:
             return None

-    def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
-        # Forcedly update delegate node's meta['spec'] to get correct output
-        # tensor size in runtime
-        call_delegate = [
-            node
-            for node in graph_module.graph.nodes
-            if node.op == "call_function" and node.name == "executorch_call_delegate"
-        ]
-        assert len(call_delegate) == 1
-        for n in graph_module.graph.nodes:
-            if QCOM_QUANTIZED_IO in n.meta:
-                n.meta["val"] = n.meta["val"].to(dtype=n.meta[QCOM_QUANTIZED_IO])
-
-        spec = []
-        for user in list(call_delegate[0].users):
-            spec.append(self._make_spec(user.meta["val"]))
-        call_delegate[0].meta["spec"] = tuple(spec)
-
-    def call(self, graph_module: torch.fx.GraphModule):
-        self._build(graph_module)
-        graph_module.graph.eliminate_dead_code()
-        graph_module.recompile()
-        return PassResult(graph_module, True)
+    def placeholder(self, name: str, arg, meta):
+        if quantized_dtype := meta.data.get(QCOM_QUANTIZED_IO, None):
+            arg = arg.to(dtype=quantized_dtype)
+            meta["spec"] = self._make_spec(arg)
+        return super().placeholder(name, arg, meta)
+
+    def call_getitem(self, value, key: int, meta):
+        meta["spec"] = value.node.meta["spec"][key]
+        return super().call_getitem(value, key, meta)
+
+    def call_delegate(self, lowered_module, args, kwargs, meta):
+        args_data, _ = pytree.tree_map_only(
+            ProxyValue, lambda x: x.data, (args, kwargs)
+        )
+        meta["spec"] = pytree.tree_map(
+            self._make_spec,
+            executorch_call_delegate(lowered_module, *args_data),
+        )
+        return super().call_delegate(lowered_module, args, kwargs, meta)
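The rewritten call_delegate relies on torch's pytree utilities to unwrap ProxyValue leaves and to map _make_spec over the delegate's (possibly nested) outputs. The following standalone snippet, with a hypothetical Boxed class standing in for ProxyValue, illustrates how tree_map_only and tree_map behave on nested structures; it is not backend code.

import torch
from torch.utils import _pytree as pytree


class Boxed:
    # Hypothetical stand-in for ProxyValue: wraps a concrete tensor in .data.
    def __init__(self, data: torch.Tensor) -> None:
        self.data = data


nested = {"args": (Boxed(torch.ones(2)), 3), "kwargs": {"x": Boxed(torch.zeros(1))}}

# tree_map_only unwraps only leaves of the given type, leaving other leaves (the int 3) untouched.
unwrapped = pytree.tree_map_only(Boxed, lambda b: b.data, nested)

# tree_map applies a function to every leaf while preserving the nesting.
shapes = pytree.tree_map(
    lambda x: tuple(x.shape) if isinstance(x, torch.Tensor) else x, unwrapped
)
print(shapes)  # {'args': ((2,), 3), 'kwargs': {'x': (1,)}}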

backends/qualcomm/quantizer/qconfig.py

Lines changed: 2 additions & 16 deletions
@@ -396,7 +396,7 @@ def get_ptq_per_block_quant_config(
 )


-# TODO merge qat and ptq to a fucntion, and use a bool flag to control it
+# TODO merge qat and ptq to a function, and use a bool flag to control it
 def get_8a8w_qnn_qat_config(
     act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
 ) -> QuantizationConfig:
@@ -598,21 +598,7 @@ def get_qat_per_channel_quant_config(
         observer_or_fake_quant_ctr=weight_fake_quant_ctr,
     )

-    bias_fake_quant_ctr = FakeQuantize.with_args(
-        dtype=torch.int32,
-        quant_min=torch.iinfo(torch.int32).min,
-        quant_max=torch.iinfo(torch.int32).max,
-        qscheme=torch.per_tensor_symmetric,
-        reduce_range=True,
-        observer=MovingAverageMinMaxObserver,
-    )
-    bias_quantization_spec = QuantizationSpec(
-        dtype=torch.int32,
-        quant_min=torch.iinfo(torch.int32).min,
-        quant_max=torch.iinfo(torch.int32).max,
-        qscheme=torch.per_tensor_symmetric,
-        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
-    )
+    bias_quantization_spec = _derived_bias_quant_spec

     quantization_config = QuantizationConfig(
         input_activation=act_quantization_spec,
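For context on the switch to _derived_bias_quant_spec: with per-channel weight quantization, the bias quantization parameters should not be observed independently but derived from the activation and weight scales. A simplified, standalone sketch of that derivation (not the backend's actual implementation; the function name is illustrative) is shown below.

import torch


def derive_bias_qparams(act_scale: float, weight_scales: torch.Tensor):
    # Per-channel int32 bias scale is the product of the per-tensor activation
    # scale and the per-channel weight scale; the zero point is 0.
    bias_scales = act_scale * weight_scales
    bias_zero_points = torch.zeros_like(weight_scales, dtype=torch.int32)
    return bias_scales, bias_zero_points


# Example with 4 output channels.
scales, zps = derive_bias_qparams(0.02, torch.tensor([0.1, 0.2, 0.05, 0.4]))
print(scales)  # tensor([0.0020, 0.0040, 0.0010, 0.0080])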
