
Commit fd62c81

Fix review comments

Co-authored-by: chen03.zhao <[email protected]>

Parent: 23942b6

9 files changed, +57 -121 lines changed

backends/samsung/_passes/annotate_qparams.py
13 additions, 54 deletions

@@ -30,7 +30,7 @@ class AnnotateQparamsPass(ExportPass):
     and add Q->DQ after removing all the Q->DQs.
     """
 
-    deliver_nodes = {
+    propagate_nodes = {
         exir_ops.edge.aten.view_copy.default,
         exir_ops.edge.aten.permute_copy.default,
         exir_ops.edge.aten.squeeze_copy.default,
@@ -83,7 +83,7 @@ def _impl(node: Node, res_list: List[Node]):
                 _impl(user, res_list)
         return res_list
 
-    def _deliver_quant_params(self, node: Node):
+    def _propagate_quant_params(self, node: Node):
         assert (
             quantize_attrs := node.meta.get("quantize_attrs")
         ), "Must be annotated node."
@@ -98,25 +98,25 @@ def _deliver_quant_params(self, node: Node):
             ):
                 break
             node = user
-        # Case1: ...-q-dq(cur)-deliver_node-node(not d-dq)
-        # Case2: deliver_node(delivered)-deliver_node-node(not q-dq)
+        # Case1: ...-q-dq(cur)-propagate_node-node(not q-dq)
+        # Case2: propagate_node(propagated)-propagate_node-node(not q-dq)
         for idx, user in enumerate(node.users.keys()):
-            # For the branch who need to be requantized, we deliver the requantize params
+            # For the branch that needs to be requantized, we propagate the requantize params
             user_attrs = requantize_map.get(idx, quantize_attrs)
-            if user.target not in self.deliver_nodes:
+            if user.target not in self.propagate_nodes:
                 continue
             if len(user.users) == 1:
                 # Possibily no need for checking len(users)>1
                 user_of_user = list(user.users)[0]
-                # node-q-dq-deliver-q-dq not need for delivery
+                # node-q-dq-propagate-q-dq: no need for propagation
                 if (
                     user_of_user.target in QuantConstants.QUANT_OPS_KEY_MAP
                     or user_of_user.target in QuantConstants.DEQUANT_OPS_KEY_MAP
                 ):
                     continue
-            # Deliver quant for node-q-dq-deliver_node-node(not qdq)
+            # Propagate quant for node-q-dq-propagate_node-node(not qdq)
             user.meta["quantize_attrs"] = user_attrs
-            self._deliver_quant_params(user)
+            self._propagate_quant_params(user)
 
     def _annotate_requantize(self, node: Node):
         assert (
@@ -153,16 +153,7 @@ def _check_same(requant_obj, ori_obj) -> bool:
 
     def _annotate(self, graph_module: GraphModule):
         for node in graph_module.graph.nodes:
-            if key_map := QuantConstants.DEQUANT_OPS_KEY_MAP.get(node.target, None):
-                # We will fold node with constant output in the future pass as a constant node
-                # example: Constant->Q->DQ->nodeN->Q->DQ, this seq will be folded to one
-                # We need to store the q-params from last DQ params for quantizing constant value
-                quant_attrs = self.get_quant_attrs(node, key_map)
-                node.meta["quantize_attrs"] = quant_attrs
-                continue
-            else:
-                key_map = QuantConstants.QUANT_OPS_KEY_MAP.get(node.target, None)
-                # ignore pre-quantized params now.
+            key_map = QuantConstants.QUANT_OPS_KEY_MAP.get(node.target, None)
             if not key_map:
                 continue
             source_node = node.args[0]
@@ -172,46 +163,15 @@ def _annotate(self, graph_module: GraphModule):
             ):
                 # Currently, don't add quant info for d_qd node here.
                 continue
+            elif source_node.target == operator.getitem:
+                source_node = source_node.args[0]
             quant_attrs = self.get_quant_attrs(node, key_map)
-            assert node.args[0].target != operator.getitem, "Not supported now."
-            source_node = node.args[0]
             source_node.meta["quantize_attrs"] = quant_attrs
             self._annotate_requantize(source_node)
-            self._deliver_quant_params(source_node)
-
-    def _annotate_real_out(self, graph_module: GraphModule):
-        for output_nodes in filter(
-            lambda x: x.op == "output", graph_module.graph.nodes
-        ):
-            output_nodes = list(output_nodes.args[0])
-            for idx, output_node in enumerate(output_nodes):
-                if output_node.target not in [
-                    *QuantConstants.QUANT_OPS_KEY_MAP.keys(),
-                    *QuantConstants.DEQUANT_OPS_KEY_MAP.keys(),
-                ]:
-                    continue
-                while output_node.args[0].target in [
-                    *QuantConstants.QUANT_OPS_KEY_MAP.keys(),
-                    *QuantConstants.DEQUANT_OPS_KEY_MAP.keys(),
-                ]:
-                    output_node = output_node.args[0]
-                output_nodes[idx] = output_node
-            for node in output_nodes:
-                if node.target in QuantConstants.QUANT_OPS_KEY_MAP:
-                    node.args[0].meta["real_out"] = True
-                else:
-                    node.meta["real_out"] = True
-
-    def _annotate_real_in(self, graph_module: GraphModule):
-        for in_node in filter(
-            lambda x: is_graph_input(self.edge_program, x), graph_module.graph.nodes
-        ):
-            in_node.meta["real_in"] = True
+            self._propagate_quant_params(source_node)
 
     def call(self, graph_module: GraphModule):
         self._annotate(graph_module)
-        self._annotate_real_out(graph_module)
-        self._annotate_real_in(graph_module)
         graph_module.recompile()
         return PassResult(graph_module, True)
 
@@ -223,7 +183,6 @@ def get_quant_attrs(
         for key, attr in zip(quant_attr_keys[1:], quant_node.args[1:]):
             # For channel-wise quantization, params are stored by buffer nodes.
             if isinstance(attr, torch.fx.Node):
-                assert isinstance(attr.target, str), "Not supported now. "
                 attr = get_buffer(self.edge_program, attr)
             quant_attrs[key] = attr
         quant_attrs["target"] = quant_node.target
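The new operator.getitem branch in _annotate replaces the removed assert: when a quantize node consumes one output of a multi-output op through getitem, the q-params now land on the producing node instead of failing. A minimal sketch of that resolution step, assuming standard torch.fx getitem semantics (the helper name is illustrative, not part of the commit):

import operator

from torch.fx import Node


def resolve_annotation_target(quant_node: Node) -> Node:
    # args[0] of a quantize node is the tensor being quantized.
    source = quant_node.args[0]
    if source.target == operator.getitem:
        # getitem selects one output of a multi-output op; annotate the
        # producing op itself rather than the getitem proxy node.
        source = source.args[0]
    return source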

backends/samsung/_passes/conv1d_to_conv2d.py
1 addition, 1 deletion

@@ -93,5 +93,5 @@ def call(self, graph_module: torch.fx.GraphModule):
             unsqueeze_before.meta["quantize_attrs"] = prev_qparams
 
         graph_module.recompile()
-        graph_module = super().call(graph_module).graph_module
+        _ = super().call(graph_module).graph_module
         return PassResult(graph_module, True)
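The _ = super().call(...) form introduced here recurs in fold_qdq.py, fuse_conv_act.py, and remove_useless_ops.py below. A minimal sketch of the shared idiom, assuming an ExportPass base whose call() retraces the module and returns a PassResult wrapping a new graph module; the retrace appears to run as a consistency check while the pass keeps returning the original, in-place-mutated module:

import torch

from executorch.exir.pass_base import ExportPass, PassResult


class ExamplePass(ExportPass):  # illustrative, not part of the commit
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # ... mutate graph_module in place here ...
        graph_module.recompile()
        # Retrace via the base class but discard the retraced module, so
        # the returned PassResult keeps the original nodes and their meta
        # (e.g. the "quantize_attrs" entries set by AnnotateQparamsPass).
        _ = super().call(graph_module).graph_module
        return PassResult(graph_module, True)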

backends/samsung/_passes/fold_qdq.py
1 addition, 0 deletions

@@ -32,4 +32,5 @@ def call(self, graph_module: GraphModule):
         self._fold(graph_module)
         graph_module.recompile()
         dead_code_elimination_pass(graph_module)
+        _ = super().call(graph_module).graph_module
         return PassResult(graph_module, True)

backends/samsung/_passes/fold_redundant_as_strided_copy.py
0 additions, 53 deletions

This file was deleted; its as_strided_copy folding is re-implemented in remove_useless_ops.py below.

backends/samsung/_passes/fuse_conv_act.py
1 addition, 0 deletions

@@ -73,4 +73,5 @@ def call(self, graph_module: GraphModule):
         self._fuse(graph_module)
         graph_module.recompile()
         dead_code_elimination_pass(graph_module)
+        _ = super().call(graph_module).graph_module
         return PassResult(graph_module, True)

backends/samsung/_passes/remove_useless_ops.py
37 additions, 0 deletions

@@ -24,6 +24,41 @@ class RemoveUselessOpPass(ExportPass):
     def __init__(self):
         super().__init__()
 
+    def gen_pattern_as_strided_copy(self, graph_module: GraphModule):
+        for node in list(graph_module.graph.nodes):  # noqa: C416
+            if node.target != exir_ops.edge.aten.mean.dim:
+                continue
+            if len(node.users) != 1:
+                continue
+            successor = list(node.users.keys())[0]
+            if successor.target != exir_ops.edge.aten.as_strided_copy.default:
+                continue
+            is_pattern = True
+            count = 0
+            for i, stride in enumerate(successor.args[2]):
+                if stride < node.meta["val"].size()[i]:
+                    if stride == 1:
+                        count += 1
+                    else:
+                        is_pattern = False
+                        break
+                    if count >= 2:
+                        is_pattern = False
+                        break
+            if is_pattern:
+                yield successor
+
+    def _fold_as_strided_copy(
+        self,
+        graph_module: GraphModule,
+    ):
+        for as_strided_copy_node in self.gen_pattern_as_strided_copy(graph_module):
+            for user in list(as_strided_copy_node.users.keys()):
+                user.replace_input_with(
+                    as_strided_copy_node, as_strided_copy_node.args[0]
+                )
+            graph_module.graph.erase_node(as_strided_copy_node)
+
     def _remove_useless(
         self,
         graph_module: GraphModule,
@@ -42,9 +77,11 @@ def _remove_useless(
             for user in [user for user in node.users.keys()]:  # noqa: C416
                 user.replace_input_with(node, node.all_input_nodes[0])
             graph_module.graph.erase_node(node)
+        self._fold_as_strided_copy(graph_module)
 
     def call(self, graph_module: GraphModule):
         self._remove_useless(graph_module)
         graph_module.recompile()
         dead_code_elimination_pass(graph_module)
+        _ = super().call(graph_module).graph_module
         return PassResult(graph_module, True)
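The stride check in gen_pattern_as_strided_copy takes over for the deleted FoldRudundantAsStridedCopyPass: an as_strided_copy directly following mean.dim is treated as redundant when every stride that is smaller than the corresponding output size equals 1, with at most one such dimension. A standalone sketch of that predicate with illustrative shapes (not taken from the commit):

from typing import Sequence


def is_redundant_as_strided(sizes: Sequence[int], strides: Sequence[int]) -> bool:
    # sizes: the mean.dim output shape (node.meta["val"].size() above);
    # strides: the as_strided_copy stride argument (successor.args[2]).
    count = 0
    for size, stride in zip(sizes, strides):
        if stride < size:
            if stride != 1:
                return False
            count += 1
            if count >= 2:
                return False
    return True


# mean over the last dim with keepdim on a (2, 3, 4) tensor gives shape
# (2, 3, 1); an as_strided_copy with strides (3, 1, 1) over that output
# changes nothing, so the node can be folded away.
print(is_redundant_as_strided(sizes=(2, 3, 1), strides=(3, 1, 1)))  # True
print(is_redundant_as_strided(sizes=(2, 3, 1), strides=(3, 2, 1)))  # False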

backends/samsung/enn_preprocess.py
2 additions, 0 deletions

@@ -16,6 +16,7 @@
 from executorch.backends.samsung._passes.customized_constant_prop import (
     ConstantPropPass,
 )
+from executorch.backends.samsung._passes.annotate_qparams import AnnotateQparamsPass
 from executorch.backends.samsung._passes.fold_qdq import FoldQDQPass
 from executorch.backends.samsung._passes.insert_qdq import InsertQDQPass
 from executorch.backends.samsung._passes.replace_scalar_ops import ReplaceOpsWithScalar
@@ -58,6 +59,7 @@ def preprocess(
 
         enn_preprocess_passes = PassManager(
             passes=[
+                AnnotateQparamsPass(edge_program),
                 FoldQDQPass(),
                 ConstantPropPass(edge_program),
                 Conv1dToConv2d(edge_program),
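Together with the export_utils.py change below, this moves AnnotateQparamsPass from the ahead-of-time pass list into the backend's preprocess pipeline, where it now runs before FoldQDQPass, presumably so that q-params are recorded on nodes before the Q->DQ pairs are folded away. A sketch of the resulting ordering, with the remaining passes elided and the PassManager import path assumed:

from torch.fx.passes.infra.pass_manager import PassManager

from executorch.backends.samsung._passes.annotate_qparams import AnnotateQparamsPass
from executorch.backends.samsung._passes.fold_qdq import FoldQDQPass


def build_enn_preprocess_passes(edge_program):  # edge_program: the ExportedProgram
    return PassManager(
        passes=[
            AnnotateQparamsPass(edge_program),  # annotate q-params first
            FoldQDQPass(),  # then fold the now-redundant Q->DQ pairs
            # ConstantPropPass(edge_program), Conv1dToConv2d(edge_program), ...
        ]
    )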

backends/samsung/serialization/enn_graph_schema.py
0 additions, 5 deletions

@@ -79,9 +79,6 @@ def define_tensor(  # noqa: C901
 
         if quant_param is not None:
             need_quantize = True
-            if quant_param.get(QuantConstants.QUANT_KEY.quant_dtype) == torch.int32:
-                quant_param = none_quant_tensor_quant_meta()
-                need_quantize = False
 
             scales = self._affine_meta_param(
                 quant_param[QuantConstants.QUANT_KEY.scale]
@@ -131,8 +128,6 @@ def serialize(self):
     def _affine_meta_param(param: Any) -> str:
         type_str_affine_table = {
             torch.int8: "AINT8",
-            torch.int32: "FLOAT32",  # INT32 just used for HW quant.
-            torch.int16: "AINT16",  # INT32 just used for HW quant.
         }
         if isinstance(param, str):
             return param

backends/samsung/utils/export_utils.py
2 additions, 8 deletions

@@ -9,10 +9,6 @@
 
 import executorch.exir as exir
 import torch
-from executorch.backends.samsung._passes.annotate_qparams import AnnotateQparamsPass
-from executorch.backends.samsung._passes.fold_redundant_as_strided_copy import (
-    FoldRudundantAsStridedCopyPass,
-)
 from executorch.backends.samsung._passes.fuse_conv_act import FuseConvActPass
 from executorch.backends.samsung._passes.remove_useless_ops import RemoveUselessOpPass
 from executorch.backends.samsung.partition.enn_partitioner import EnnPartitioner
@@ -50,11 +46,9 @@ def get_edge_compile_config():
     )
 
 
-def get_enn_pass_list(edge_program: ExportedProgram) -> List[PassType]:
+def get_enn_pass_list() -> List[PassType]:
     return [
         RemoveUselessOpPass(),
-        FoldRudundantAsStridedCopyPass(),
-        AnnotateQparamsPass(edge_program),
         RemoveCloneOpsTransform(),
         FuseConvActPass(),
     ]
@@ -90,7 +84,7 @@ def to_edge_transform_and_lower_to_enn(
 ) -> exir.ExecutorchProgramManager:
     assert compile_specs is not None, "For now, we must deliver complile specs"
     prog = torch.export.export(module, inputs)
-    pass_list = get_enn_pass_list(prog)
+    pass_list = get_enn_pass_list()
     if custom_pass_config:
         pass_list.extend(custom_pass_config)
     return to_edge_transform_and_lower(
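A hedged usage sketch of the updated entry point: get_enn_pass_list() no longer takes the exported program, since AnnotateQparamsPass (the only pass that consumed it) now runs inside enn_preprocess. The model and compile-spec construction below are placeholders; only the calls shown in the diff are assumed:

import torch

from executorch.backends.samsung.utils.export_utils import (
    get_enn_pass_list,
    to_edge_transform_and_lower_to_enn,
)


class TinyModel(torch.nn.Module):  # placeholder module
    def forward(self, x):
        return x.mean(dim=-1, keepdim=True)


example_inputs = (torch.randn(2, 3, 4),)
passes = get_enn_pass_list()  # now argument-free

compile_specs = ...  # backend-specific; construction is not part of this commit
program_manager = to_edge_transform_and_lower_to_enn(
    TinyModel(), example_inputs, compile_specs=compile_specs
)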
