
Commit f9a1d79

Jiseong-oh, Chen03ZhaoSamsung, and sangsoo.Ko committed

Add quantization feature and example code for MV2

1. Add quant strategies for the enn backend
2. Add support for enn's quant strategies
3. Provide example code for MV2

Co-authored-by: chen.zhao <[email protected]>
Co-authored-by: sangsoo.Ko <[email protected]>

1 parent 9ab5592 · commit f9a1d79


44 files changed: +2590 -105 lines (large commits hide some content by default, so only a subset of the changed files is shown below)
Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
from typing import Any, Dict, List, Optional

import torch
from executorch.backends.samsung.utils.constants import QuantConstants
from executorch.backends.samsung.utils.utils import is_graph_input
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch._export.utils import get_buffer
from torch.export import ExportedProgram
from torch.fx import GraphModule, Node

class AnnotateQparamsPass(ExportPass):
    """This pass adds quantization properties to nodes that need to be quantized.

    Annotate quant params:
        For src_node->Q->DQ->..., we add the quant params from the Q->DQ pair
        to src_node.

    Annotate requantize:
        For src_node->Q->DQ->Q->DQ->..., if the multiple Q->DQ pairs carry
        different quant params, we mark src_node as needing requantization
        and re-insert Q->DQ pairs after removing all the original ones.
    """

    deliver_nodes = {
        exir_ops.edge.aten.view_copy.default,
        exir_ops.edge.aten.permute_copy.default,
        exir_ops.edge.aten.squeeze_copy.default,
        exir_ops.edge.aten.squeeze_copy.dim,
        exir_ops.edge.aten.squeeze_copy.dims,
        exir_ops.edge.aten.slice_copy.Tensor,
        exir_ops.edge.aten.unsqueeze_copy.default,
        exir_ops.edge.aten.concat.default,
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.aten.expand_copy.default,
    }

    def __init__(self, edge_program: ExportedProgram):
        super().__init__()
        self.edge_program = edge_program

    def _get_last_dqs(self, node: Node) -> List[Node]:
        r"""From one Q-DQ node, find the last DQs in the quantization node chain.

        We need to consider cases such as:

                              /--Q-DQ-node1
            node->Q->DQ--node---node2
                              \--Q-DQ-node3

        This is a DFS implementation, so the result stays in the original order.

        Args:
            node (Node): Search for DQs starting from this node.

        Returns:
            List[Node]: DQ nodes in their original sequence.
        """

        def _impl(node: Node, res_list: List[Node]):
            if (
                node.target not in QuantConstants.QUANT_OPS_KEY_MAP
                and node.target not in QuantConstants.DEQUANT_OPS_KEY_MAP
            ):
                return
            for user in node.users.keys():
                if (
                    user.target not in QuantConstants.QUANT_OPS_KEY_MAP
                    and user.target not in QuantConstants.DEQUANT_OPS_KEY_MAP
                ):
                    res_list.append(node)
                else:
                    _impl(user, res_list)

        res_list: List[Node] = []
        for user in node.users:
            _impl(user, res_list)
        return res_list

    def _deliver_quant_params(self, node: Node):
        assert (
            quantize_attrs := node.meta.get("quantize_attrs")
        ), "Must be an annotated node."
        requantize_map: Dict[int, Dict] = node.meta.get("requantize", {})
        # Walk down single-user chains of Q/DQ nodes to reach the last one.
        while node.users:
            if len(node.users) != 1:
                break
            user = list(node.users.keys())[0]
            if (
                user.target not in QuantConstants.QUANT_OPS_KEY_MAP
                and user.target not in QuantConstants.DEQUANT_OPS_KEY_MAP
            ):
                break
            node = user
        # Case 1: ...-q-dq(cur)-deliver_node-node(not q-dq)
        # Case 2: deliver_node(delivered)-deliver_node-node(not q-dq)
        for idx, user in enumerate(node.users.keys()):
            # For branches that need to be requantized, deliver the
            # requantize params instead.
            user_attrs = requantize_map.get(idx, quantize_attrs)
            if user.target not in self.deliver_nodes:
                continue
            if len(user.users) == 1:
                # Possibly no need to check for len(users) > 1.
                user_of_user = list(user.users)[0]
                # node-q-dq-deliver-q-dq needs no delivery.
                if (
                    user_of_user.target in QuantConstants.QUANT_OPS_KEY_MAP
                    or user_of_user.target in QuantConstants.DEQUANT_OPS_KEY_MAP
                ):
                    continue
            # Deliver quant params for node-q-dq-deliver_node-node(not q-dq).
            user.meta["quantize_attrs"] = user_attrs
            self._deliver_quant_params(user)

    def _annotate_requantize(self, node: Node):
        assert (
            ori_quant_attrs := node.meta.get("quantize_attrs")
        ), "No quant parameters found"
        list_for_requantize = self._get_last_dqs(node)
        node.meta["requantize"] = node.meta.get("requantize", {})

        # We use the index to mark the output to be requantized, because the
        # user object and its name may change when we requantize them.

        def _check_same(requant_obj, ori_obj) -> bool:
            if type(requant_obj) != type(ori_obj):  # noqa E721
                # We need exactly the same type here.
                return False
            if not isinstance(requant_obj, torch.Tensor):
                return requant_obj == ori_obj
            if requant_obj.shape != ori_obj.shape:
                return False
            return bool((requant_obj == ori_obj).all())

        requantize_map: Dict[int, Dict] = node.meta["requantize"]
        for idx, dq in enumerate(list_for_requantize):
            q = dq.all_input_nodes[0]
            if q.target not in QuantConstants.QUANT_OPS_KEY_MAP:
                continue
            key_map = QuantConstants.DEQUANT_OPS_KEY_MAP[dq.target]
            requantize_attrs = self.get_quant_attrs(q, key_map)
            if not all(
                _check_same(ori_quant_attrs[key], requantize_attrs[key])
                for key in key_map.values()
            ):
                requantize_map[idx] = requantize_attrs

    def _annotate(self, graph_module: GraphModule):
        for node in graph_module.graph.nodes:
            if key_map := QuantConstants.DEQUANT_OPS_KEY_MAP.get(node.target, None):
                # A node with constant output will be folded into a constant
                # node by a later pass, e.g. Constant->Q->DQ->nodeN->Q->DQ folds
                # into one node, so we store the q-params from the last DQ for
                # quantizing the constant value.
                quant_attrs = self.get_quant_attrs(node, key_map)
                node.meta["quantize_attrs"] = quant_attrs
                continue
            else:
                key_map = QuantConstants.QUANT_OPS_KEY_MAP.get(node.target, None)
                # Ignore pre-quantized params for now.
                if not key_map:
                    continue
                source_node = node.args[0]
                if source_node.target in (
                    *QuantConstants.QUANT_OPS_KEY_MAP,
                    *QuantConstants.DEQUANT_OPS_KEY_MAP,
                ):
                    # Currently, don't add quant info for a q-dq node here.
                    continue
                quant_attrs = self.get_quant_attrs(node, key_map)
                assert node.args[0].target != operator.getitem, "Not supported now."
                source_node = node.args[0]
                source_node.meta["quantize_attrs"] = quant_attrs
                self._annotate_requantize(source_node)
                self._deliver_quant_params(source_node)

    def _annotate_real_out(self, graph_module: GraphModule):
        for output_nodes in filter(
            lambda x: x.op == "output", graph_module.graph.nodes
        ):
            output_nodes = list(output_nodes.args[0])
            for idx, output_node in enumerate(output_nodes):
                if output_node.target not in [
                    *QuantConstants.QUANT_OPS_KEY_MAP.keys(),
                    *QuantConstants.DEQUANT_OPS_KEY_MAP.keys(),
                ]:
                    continue
                # Skip over chained Q/DQ nodes to find the real producer.
                while output_node.args[0].target in [
                    *QuantConstants.QUANT_OPS_KEY_MAP.keys(),
                    *QuantConstants.DEQUANT_OPS_KEY_MAP.keys(),
                ]:
                    output_node = output_node.args[0]
                output_nodes[idx] = output_node
            for node in output_nodes:
                if node.target in QuantConstants.QUANT_OPS_KEY_MAP:
                    node.args[0].meta["real_out"] = True
                else:
                    node.meta["real_out"] = True

    def _annotate_real_in(self, graph_module: GraphModule):
        for in_node in filter(
            lambda x: is_graph_input(self.edge_program, x), graph_module.graph.nodes
        ):
            in_node.meta["real_in"] = True

    def call(self, graph_module: GraphModule):
        self._annotate(graph_module)
        self._annotate_real_out(graph_module)
        self._annotate_real_in(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)

    def get_quant_attrs(
        self, quant_node: torch.fx.Node, key_map: Optional[Dict] = None
    ) -> Dict[str, Any]:
        quant_attr_keys = [arg.name for arg in quant_node.target._schema.arguments]
        quant_attrs = dict.fromkeys(quant_attr_keys)
        for key, attr in zip(quant_attr_keys[1:], quant_node.args[1:]):
            # For channel-wise quantization, params are stored in buffer nodes.
            if isinstance(attr, torch.fx.Node):
                assert isinstance(attr.target, str), "Not supported now."
                attr = get_buffer(self.edge_program, attr)
            quant_attrs[key] = attr
        quant_attrs["target"] = quant_node.target
        if key_map is None:
            return quant_attrs
        # Rename aten attr keys to backend-specific keys, tracking any misses.
        miss_attrs = []
        for aten_attr, snc_attr in key_map.items():
            if aten_attr not in quant_attrs:
                miss_attrs.append(aten_attr)
                continue
            attr = quant_attrs[aten_attr]
            quant_attrs.pop(aten_attr)
            quant_attrs[snc_attr] = attr
        assert (
            not miss_attrs
        ), f"Missing quant attrs {miss_attrs} for node {quant_node.name}"
        return quant_attrs
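
For orientation, here is a minimal sketch of how this pass could be driven on an edge program. The export/to_edge flow is standard ExecuTorch, but the toy model and the direct call() invocation are illustrative assumptions rather than code from this commit; in a real flow the graph would already contain Q->DQ pairs produced by the quantizer.

# Hypothetical usage sketch; ToyModel and the direct call() invocation are
# assumptions for illustration, not part of this commit.
import torch
from executorch.exir import to_edge

class ToyModel(torch.nn.Module):
    def forward(self, x):
        # view maps to view_copy, one of the shape-only ops in
        # `deliver_nodes`, so an annotated producer would deliver its
        # quant params through it.
        return torch.nn.functional.relu(x).view(1, -1)

exported = torch.export.export(ToyModel(), (torch.randn(1, 4, 4),))
edge_program = to_edge(exported).exported_program()

# On a quantized graph (e.g. after convert_pt2e), Q->DQ params are copied
# back onto source nodes and propagated through the shape-only ops.
result = AnnotateQparamsPass(edge_program).call(edge_program.graph_module)
assert result.modified  # the pass always recompiles and reports success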
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.samsung.quantizer.quantizer import global_quant_info
from executorch.backends.samsung.utils.constants import QuantConstants
from executorch.backends.transforms.utils import get_param_tensor, is_param_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.export import ExportedProgram

class AnnotateScalarParametersPass(ExportPass):
    """Add quantization parameters to the scalar inputs of certain ops.

        Ifm(Quantized)------TargetOP---
        Scalar(Non-Quant)---/

    Note: such scalars are converted to tensor nodes by a default pass.
    """

    TARGET_OPS = {
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.div.Tensor,
    }

    def __init__(self, edge_program: ExportedProgram):
        super().__init__()
        self.edge_program = edge_program

    def annotate(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
            if node.target not in self.TARGET_OPS or "quantize_attrs" not in node.meta:
                continue
            torch_quant_dtype = global_quant_info.weight_precison.torch_dtype
            for input_arg in node.all_input_nodes:
                if input_arg.op not in ("placeholder", "get_attr") or not is_param_node(
                    self.edge_program, input_arg
                ):
                    continue
                else:
                    tensor = get_param_tensor(self.edge_program, input_arg)
                    if not tensor.shape:
                        # 0-dim tensor: treat the scalar value itself as the
                        # scale, with a zero offset.
                        qparams = {
                            QuantConstants.QUANT_KEY.scale: float(tensor),
                            QuantConstants.QUANT_KEY.quant_dtype: torch_quant_dtype,
                            QuantConstants.QUANT_KEY.quant_max: torch.iinfo(
                                torch_quant_dtype
                            ).max,
                            QuantConstants.QUANT_KEY.quant_min: torch.iinfo(
                                torch_quant_dtype
                            ).min,
                            QuantConstants.QUANT_KEY.zero_point: 0,
                        }
                        input_arg.meta["quantize_attrs"] = qparams

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        self.annotate(graph_module)
        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
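
For intuition, a small sketch of the pattern this pass targets. The module below is an illustrative assumption: after export, a bare Python scalar such as 0.5 is lifted to a 0-dim constant tensor (the "default pass" the docstring refers to), and once the surrounding mul node carries quantize_attrs, this pass attaches synthetic scale/zero-point attributes to that scalar input.

# Hypothetical illustration, not code from this commit; the quantizer must
# already have populated global_quant_info for the dtype lookup to work.
import torch

class ScaledAdd(torch.nn.Module):
    def forward(self, x, y):
        # 0.5 becomes a 0-dim constant tensor node after export; this pass
        # would annotate it with scale=0.5 and zero_point=0.
        return x * 0.5 + y

exported = torch.export.export(
    ScaledAdd(), (torch.randn(2, 2), torch.randn(2, 2))
)
# Assuming `edge_program` is the quantized edge program and the mul node was
# already annotated by AnnotateQparamsPass:
#     AnnotateScalarParametersPass(edge_program).call(edge_program.graph_module)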
