Commit 8b67236

Enable Exynos backend Quantization (pytorch#14464)

### Summary

- Implemented quantized strategies for the enn-backend.
- Added support for ENN's quantization strategies.
- Successfully verified multiple quantized models.

### Test plan

python -m executorch.examples.samsung.scripts.${MODEL_NAME} -c e9955 -p A8W8

cc @SS-JIA @digantdesai @kimishpatel

---

Signed-off-by: jiseong.oh <[email protected]>
Co-authored-by: chen.zhao <[email protected]>
Co-authored-by: sangsoo.Ko <[email protected]>
Co-authored-by: chong-chen <[email protected]>
Co-authored-by: xz-linghu <[email protected]>
1 parent 7f31fd8 commit 8b67236

57 files changed: +4079 −99 lines

.ci/scripts/setup-samsung-linux-deps.sh (2 additions, 2 deletions)

```diff
@@ -13,7 +13,7 @@ download_ai_lite_core() {
   API_BASE="https://soc-developer.semiconductor.samsung.com/api/v1/resource/ai-litecore/download"
   API_KEY=$SAMSUNG_AI_LITECORE_KEY

-  VERSION="0.5"
+  VERSION="0.7"
   OS_NAME="Ubuntu 22.04"
   OUT_FILE="/tmp/exynos-ai-litecore-v${VERSION}.tar.gz"
   TARGET_PATH="/tmp/exynos_ai_lite_core"
@@ -62,7 +62,7 @@ install_enn_backend() {
   export PYTHONPATH=${PYTHONPATH:-}:${EXECUTORCH_ROOT}/..
 }

-AI_LITE_CORE_VERSION=0.5.0
+AI_LITE_CORE_VERSION=0.7.0

 download_ai_lite_core ${AI_LITE_CORE_VERSION}
 install_enn_backend
```
.github/workflows/pull.yml (6 additions, 0 deletions)

```diff
@@ -935,6 +935,12 @@ jobs:
           python -m executorch.examples.samsung.aot_compiler --model_name=$model -c E9955
         done

+        # Test quant models
+        model_scripts="deeplab_v3 edsr inception_v3 inception_v4 mobilenet_v2 mobilenet_v3 resnet18 resnet50 vit wav2letter"
+        for m_script in $model_scripts; do
+          python -m executorch.examples.samsung.scripts.${m_script} -c e9955 -p A8W8
+        done
+
         # Test ops
         python -m unittest discover -s backends/samsung/test/ops -p "test_*.py"
```
Lines changed: 201 additions & 0 deletions (new file)

```python
# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
from typing import Any, Dict, List, Optional

import torch
from executorch.backends.samsung.utils.constants import QuantConstants
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch._export.utils import get_buffer
from torch.export import ExportedProgram
from torch.fx import GraphModule, Node


class AnnotateQparamsPass(ExportPass):
    """This pass adds quantize properties to nodes that need to be quantized.

    Annotate quant params:
        For src_node->Q->DQ->..., add the quant params from the Q->DQ pair
        to src_node.

    Annotate requantize:
        For src_node->Q->DQ->Q->DQ->..., if the multiple Q->DQ pairs carry
        different quant params, mark src_node as needing requantization,
        and add Q->DQ back after removing all the Q->DQs.
    """

    propagate_nodes = {
        exir_ops.edge.aten.view_copy.default,
        exir_ops.edge.aten.permute_copy.default,
        exir_ops.edge.aten.squeeze_copy.default,
        exir_ops.edge.aten.squeeze_copy.dim,
        exir_ops.edge.aten.squeeze_copy.dims,
        exir_ops.edge.aten.slice_copy.Tensor,
        exir_ops.edge.aten.unsqueeze_copy.default,
        exir_ops.edge.aten.concat.default,
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.aten.expand_copy.default,
    }

    def __init__(self, edge_program: ExportedProgram):
        super().__init__()
        self.edge_program = edge_program

    def _get_last_dqs(self, node: Node) -> List[Node]:
        r"""From one Q-DQ node, find the last DQs in the quantization node chain.

        Cases like the following need to be considered:
                         /--Q-DQ--node1
            node->Q->DQ--node----node2
                         \--Q-DQ--node3
        This is a DFS implementation, so the result stays in original order.

        Args:
            node (Node): Search for DQs starting from this node.

        Returns:
            List[Node]: list of DQ nodes in original sequence
        """

        def _impl(node: Node, res_list: List[Node]):
            if (
                node.target not in QuantConstants.QUANT_OPS_KEY_MAP
                and node.target not in QuantConstants.DEQUANT_OPS_KEY_MAP
            ):
                return
            for user in node.users.keys():
                if (
                    user.target not in QuantConstants.QUANT_OPS_KEY_MAP
                    and user.target not in QuantConstants.DEQUANT_OPS_KEY_MAP
                ):
                    res_list.append(node)
                else:
                    _impl(user, res_list)

        res_list: List[Node] = []
        for user in node.users:
            _impl(user, res_list)
        return res_list

    def _propagate_quant_params(self, node: Node):
        assert (
            quantize_attrs := node.meta.get("quantize_attrs")
        ), "Must be annotated node."
        requantize_map: Dict[int, Dict] = node.meta.get("requantize", {})
        while node.users:
            if len(node.users) != 1:
                break
            user = list(node.users.keys())[0]
            if (
                user.target not in QuantConstants.QUANT_OPS_KEY_MAP
                and user.target not in QuantConstants.DEQUANT_OPS_KEY_MAP
            ):
                break
            node = user
        # Case 1: ...-q-dq(cur)-propagate_node-node(not q-dq)
        # Case 2: propagate_node(propagated)-propagate_node-node(not q-dq)
        for idx, user in enumerate(node.users.keys()):
            # For a branch that needs requantization, propagate the requantize params
            user_attrs = requantize_map.get(idx, quantize_attrs)
            if user.target not in self.propagate_nodes:
                continue
            if len(user.users) == 1:
                # Possibly no need to check len(users) > 1
                user_of_user = list(user.users)[0]
                # node-q-dq-propagate-q-dq needs no propagation
                if (
                    user_of_user.target in QuantConstants.QUANT_OPS_KEY_MAP
                    or user_of_user.target in QuantConstants.DEQUANT_OPS_KEY_MAP
                ):
                    continue
            # propagate quant for node-q-dq-propagate_node-node(not q-dq)
            user.meta["quantize_attrs"] = user_attrs
            self._propagate_quant_params(user)

    def _annotate_requantize(self, node: Node):
        assert (
            ori_quant_attrs := node.meta.get("quantize_attrs")
        ), "No quant parameters found"
        list_for_requantize = self._get_last_dqs(node)
        node.meta["requantize"] = node.meta.get("requantize", {})

        # We use the index to mark the output to be requantized,
        # because the user object and name may change when we requantize them.

        def _check_same(requant_obj, ori_obj) -> bool:
            if type(requant_obj) != type(ori_obj):  # noqa E721
                # We need exactly the same type here.
                return False
            if not isinstance(requant_obj, torch.Tensor):
                return requant_obj == ori_obj
            if requant_obj.shape != ori_obj.shape:
                return False
            return bool((requant_obj == ori_obj).all())

        requantize_map: Dict[int, Dict] = node.meta["requantize"]
        for idx, dq in enumerate(list_for_requantize):
            q = dq.all_input_nodes[0]
            if q.target not in QuantConstants.QUANT_OPS_KEY_MAP:
                continue
            key_map = QuantConstants.DEQUANT_OPS_KEY_MAP[dq.target]
            requantize_attrs = self.get_quant_attrs(q, key_map)
            if not all(
                _check_same(ori_quant_attrs[key], requantize_attrs[key])
                for key in key_map.values()
            ):
                requantize_map[idx] = requantize_attrs

    def _annotate(self, graph_module: GraphModule):
        for node in graph_module.graph.nodes:
            key_map = QuantConstants.QUANT_OPS_KEY_MAP.get(node.target, None)
            if not key_map:
                continue
            source_node = node.args[0]
            if source_node.target in (
                *QuantConstants.QUANT_OPS_KEY_MAP,
                *QuantConstants.DEQUANT_OPS_KEY_MAP,
            ):
                # Currently, don't add quant info for a q-dq node here.
                continue
            elif source_node.target == operator.getitem:
                source_node = source_node.args[0]
            quant_attrs = self.get_quant_attrs(node, key_map)
            source_node.meta["quantize_attrs"] = quant_attrs
            self._annotate_requantize(source_node)
            self._propagate_quant_params(source_node)

    def call(self, graph_module: GraphModule):
        self._annotate(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)

    def get_quant_attrs(
        self, quant_node: torch.fx.Node, key_map: Optional[Dict] = None
    ) -> Dict[str, Any]:
        quant_attr_keys = [arg.name for arg in quant_node.target._schema.arguments]
        quant_attrs = dict.fromkeys(quant_attr_keys)
        for key, attr in zip(quant_attr_keys[1:], quant_node.args[1:]):
            # For channel-wise quantization, params are stored in buffer nodes.
            if isinstance(attr, torch.fx.Node):
                attr = get_buffer(self.edge_program, attr)
            quant_attrs[key] = attr
        quant_attrs["target"] = quant_node.target
        if key_map is None:
            return quant_attrs
        miss_attrs = []
        for aten_attr, snc_attr in key_map.items():
            if aten_attr not in quant_attrs:
                miss_attrs.append(aten_attr)
                continue
            attr = quant_attrs[aten_attr]
            quant_attrs.pop(aten_attr)
            quant_attrs[snc_attr] = attr
        assert (
            not miss_attrs
        ), f"Missing quant attrs {miss_attrs} for node {quant_node.name}"
        return quant_attrs
```
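
The pass above copies the parameters recorded on a Q->DQ pair back onto the source node, then forwards them through shape-only ops (view, permute, etc.). Below is a minimal, self-contained sketch of that annotate-then-propagate idea on toy node objects; `ToyNode`, `PROPAGATE_OPS`, and the `meta["source"]` convention are illustrative stand-ins, not the real torch.fx / ExecuTorch API.

```python
from dataclasses import dataclass, field
from typing import Dict, List

QUANT_OPS = {"quantize"}
DEQUANT_OPS = {"dequantize"}
PROPAGATE_OPS = {"view", "permute"}  # shape-only ops that preserve qparams


@dataclass
class ToyNode:
    name: str
    target: str
    users: List["ToyNode"] = field(default_factory=list)
    meta: Dict = field(default_factory=dict)


def annotate(nodes: List[ToyNode]) -> None:
    for node in nodes:
        if node.target not in QUANT_OPS:
            continue
        src = node.meta["source"]  # toy stand-in for node.args[0]
        if src.target in QUANT_OPS | DEQUANT_OPS:
            continue  # don't annotate q/dq nodes themselves
        src.meta["quantize_attrs"] = node.meta["qparams"]
        _propagate(src)


def _propagate(node: ToyNode) -> None:
    attrs = node.meta["quantize_attrs"]
    # First skip forward over the trailing q->dq chain, if any.
    while len(node.users) == 1 and node.users[0].target in QUANT_OPS | DEQUANT_OPS:
        node = node.users[0]
    # Then forward the attrs through shape-only users, recursively.
    for user in node.users:
        if user.target in PROPAGATE_OPS and "quantize_attrs" not in user.meta:
            user.meta["quantize_attrs"] = attrs
            _propagate(user)


# conv -> quantize -> dequantize -> view: the view inherits conv's qparams.
conv = ToyNode("conv", "conv")
q = ToyNode("q", "quantize",
            meta={"source": conv, "qparams": {"scale": 0.02, "zero_point": 0}})
dq = ToyNode("dq", "dequantize")
view = ToyNode("view", "view")
conv.users, q.users, dq.users = [q], [dq], [view]

annotate([conv, q, dq, view])
print(conv.meta["quantize_attrs"])  # {'scale': 0.02, 'zero_point': 0}
print(view.meta["quantize_attrs"])  # same attrs, propagated through the view
```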
Lines changed: 65 additions & 0 deletions (new file)

```python
# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.samsung.quantizer.quantizer import global_quant_info
from executorch.backends.samsung.utils.constants import QuantConstants
from executorch.backends.transforms.utils import get_param_tensor, is_param_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.export import ExportedProgram


class AnnotateScalarParametersPass(ExportPass):
    """
    Add quantization parameters for the scalar inputs of some ops:
        Ifm(Quantized)------TargetOP---
        Scalar(Non-Quant)---/
    Note: such scalars are converted to tensor nodes by a default pass.
    """

    TARGET_OPS = {
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.div.Tensor,
    }

    def __init__(self, edge_program: ExportedProgram):
        super().__init__()
        self.edge_program = edge_program

    def annotate(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
            if node.target not in self.TARGET_OPS or "quantize_attrs" not in node.meta:
                continue
            torch_quant_dtype = global_quant_info.weight_precison.torch_dtype
            for input_arg in node.all_input_nodes:
                if input_arg.op not in ("placeholder", "get_attr") or not is_param_node(
                    self.edge_program, input_arg
                ):
                    continue
                tensor = get_param_tensor(self.edge_program, input_arg)
                if not tensor.shape:
                    qparams = {
                        QuantConstants.QUANT_KEY.scale: float(tensor),
                        QuantConstants.QUANT_KEY.quant_dtype: torch_quant_dtype,
                        QuantConstants.QUANT_KEY.quant_max: torch.iinfo(
                            torch_quant_dtype
                        ).max,
                        QuantConstants.QUANT_KEY.quant_min: torch.iinfo(
                            torch_quant_dtype
                        ).min,
                        QuantConstants.QUANT_KEY.zero_point: 0,
                    }
                    input_arg.meta["quantize_attrs"] = qparams

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        self.annotate(graph_module)
        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
```
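
The qparam recipe this pass applies to a 0-d scalar tensor: the scale is the scalar's own value, the zero point is 0, and the clamping range comes from `torch.iinfo` of the configured weight dtype, so the quantized integer 1 dequantizes exactly back to the scalar. A small standalone sketch of that recipe (`scalar_qparams` is a hypothetical helper, not part of the backend):

```python
import torch


def scalar_qparams(scalar: torch.Tensor, dtype: torch.dtype = torch.int8) -> dict:
    """Hypothetical helper mirroring the qparam recipe used above."""
    assert scalar.shape == torch.Size([]), "expected a 0-d (scalar) tensor"
    info = torch.iinfo(dtype)  # integer range of the quantized dtype
    return {
        "scale": float(scalar),  # scale equals the scalar value itself
        "zero_point": 0,
        "quant_min": info.min,
        "quant_max": info.max,
        "dtype": dtype,
    }


print(scalar_qparams(torch.tensor(0.5)))
# With scale == 0.5 and zero_point == 0, the quantized value 1 dequantizes
# exactly back to 0.5 via dq = (q - zero_point) * scale.
```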
