Skip to content

Commit fc926e6

Browse files
committed
Revert "Add helper functions for Q/DQ folding pass"
This reverts commit fd9eb28. Signed-off-by: Digant Desai <[email protected]>
1 parent e4658a3 commit fc926e6

File tree

5 files changed

+45
-77
lines changed

5 files changed

+45
-77
lines changed

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,6 @@
1616
from torch.fx import GraphModule, Node
1717

1818

19-
def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
20-
"""
21-
Get the input quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
22-
Raises a ValueError if the node doesn't have any parameters set.
23-
"""
24-
if "input_qparams" not in node.meta.keys():
25-
raise ValueError(f"No input quantization parameter found in node {node}")
26-
input_qparams = cast(dict[int, QuantArgs], node.meta["input_qparams"])
27-
if len(input_qparams) == 0:
28-
raise ValueError(f"No input quantization parameter found in node {node}")
29-
return input_qparams
30-
31-
32-
def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
33-
"""
34-
Get the output quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
35-
Raises a ValueError if the node doesn't have any parameters set.
36-
"""
37-
if "output_qparams" not in node.meta.keys():
38-
raise ValueError(f"No output quantization parameter found in node {node}")
39-
input_qparams = cast(dict[int, QuantArgs], node.meta["output_qparams"])
40-
if len(input_qparams) == 0:
41-
raise ValueError(f"No output quantization parameter found in node {node}")
42-
return input_qparams
43-
44-
4519
class FoldAndAnnotateQParamsPass(ExportPass):
4620
"""
4721
A pass that walks the graph and removes any DQ and Q nodes before and after the target

backends/arm/operators/op_add.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ def define_node(
7676
if output.dtype == ts.DType.INT8:
7777
# Scale output back to 8 bit
7878
# pyre-ignore
79-
tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node)
79+
tqutils.insert_rescale_node_back_to_int8(
80+
tosa_graph, add_output, scale_back, node
81+
)
8082

8183

8284
@register_node_visitor

backends/arm/operators/op_max.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,11 @@
55

66
# pyre-unsafe
77

8-
from typing import List
8+
from typing import cast, List
99

1010
import executorch.backends.arm.tosa_quant_utils as tqutils
11+
1112
import serializer.tosa_serializer as ts
12-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
13-
get_input_qparams,
14-
)
1513
from executorch.backends.arm.operators.node_visitor import (
1614
NodeVisitor,
1715
register_node_visitor,
@@ -40,23 +38,30 @@ def define_node(
4038
) -> None:
4139
assert inputs[0].dtype == inputs[1].dtype
4240

43-
max_output = output
41+
input_qparams = cast(dict[int, tqutils.QuantArgs], node.meta["input_qparams"])
42+
min_output = output
43+
4444
if inputs[0].dtype == ts.DType.INT8:
45-
input_qparams = get_input_qparams(node)
46-
assert (
47-
len(input_qparams) == 2
48-
), f"Both inputs needs to have quantization information for {node}"
4945
# insert RESCALEs to int32
46+
x_scale = input_qparams[0].scale
47+
x_zp = input_qparams[0].zp
48+
49+
y_scale = input_qparams[1].scale
50+
y_zp = input_qparams[1].zp
51+
5052
assert (
51-
input_qparams[0] == input_qparams[1]
52-
), "Both inputs must have same quantization for MAX"
53+
x_zp == y_zp
54+
), "Different zp for inputs, MAX should be quantized with shared quantization!"
55+
assert (
56+
x_scale == y_scale
57+
), "Different scale for input, MAX should be quantized with shared quantization!"
5358

5459
operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
5560
tosa_graph, inputs, node
5661
)
5762

5863
output.shape = tosa_shape(output.shape, output.dim_order)
59-
max_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
64+
min_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
6065
else:
6166
operand_inputs = inputs
6267

@@ -66,9 +71,11 @@ def define_node(
6671
operand_inputs[0].name,
6772
operand_inputs[1].name,
6873
],
69-
[max_output.name],
74+
[min_output.name],
7075
)
7176

7277
if output.dtype == ts.DType.INT8:
7378
# insert RESCALE from int32 back to int8
74-
tqutils.insert_rescale_op_to_int8(tosa_graph, max_output, scale_back, node)
79+
tqutils.insert_rescale_node_back_to_int8(
80+
tosa_graph, min_output, scale_back, node
81+
)

backends/arm/operators/op_min.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,11 @@
55

66
# pyre-unsafe
77

8-
from typing import List
8+
from typing import cast, List
99

1010
import executorch.backends.arm.tosa_quant_utils as tqutils
1111

1212
import serializer.tosa_serializer as ts
13-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
14-
get_input_qparams,
15-
)
1613
from executorch.backends.arm.operators.node_visitor import (
1714
NodeVisitor,
1815
register_node_visitor,
@@ -41,16 +38,23 @@ def define_node(
4138
) -> None:
4239
assert inputs[0].dtype == inputs[1].dtype
4340

41+
input_qparams = cast(dict[int, tqutils.QuantArgs], node.meta["input_qparams"])
4442
min_output = output
43+
4544
if inputs[0].dtype == ts.DType.INT8:
46-
input_qparams = get_input_qparams(node)
47-
assert (
48-
len(input_qparams) == 2
49-
), f"Both inputs needs to have quantization information for {node}"
5045
# insert RESCALEs to int32
46+
x_scale = input_qparams[0].scale
47+
x_zp = input_qparams[0].zp
48+
49+
y_scale = input_qparams[1].scale
50+
y_zp = input_qparams[1].zp
51+
52+
assert (
53+
x_zp == y_zp
54+
), "Different zp for inputs, MIN should be quantized with shared quantization!"
5155
assert (
52-
input_qparams[0] == input_qparams[1]
53-
), "Both inputs must have same quantization for MIN"
56+
x_scale == y_scale
57+
), "Different scale for input, MIN should be quantized with shared quantization!"
5458

5559
operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
5660
tosa_graph, inputs, node
@@ -72,4 +76,6 @@ def define_node(
7276

7377
if output.dtype == ts.DType.INT8:
7478
# insert RESCALE from int32 back to int8
75-
tqutils.insert_rescale_op_to_int8(tosa_graph, min_output, scale_back, node)
79+
tqutils.insert_rescale_node_back_to_int8(
80+
tosa_graph, min_output, scale_back, node
81+
)

backends/arm/tosa_quant_utils.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,14 @@ def insert_rescale_ops_to_int32(
5757
the graph upstream for DQ nodes.
5858
"""
5959

60-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
61-
get_input_qparams,
62-
)
63-
6460
tensors = inputs.copy()
6561

6662
# Reshape tensor according to TOSA dim order
6763
for tensor in tensors:
6864
dim_order = tensor.dim_order
6965
tensor.shape = [tensor.shape[i] for i in dim_order]
7066

71-
input_qparams = get_input_qparams(node)
72-
qargs = input_qparams.values()
67+
qargs = list(cast(dict[int, QuantArgs], node.meta["input_qparams"]).values())
7368

7469
# Scale the int8 quantized input to a common scale in the integer
7570
# domain
@@ -89,7 +84,7 @@ def insert_rescale_ops_to_int32(
8984
return rescaled_nodes, min_scale
9085

9186

92-
def insert_rescale_op_to_int8(
87+
def insert_rescale_node_back_to_int8(
9388
tosa_graph: ts.TosaSerializer,
9489
last_tensor: TosaArg,
9590
scale: float,
@@ -107,14 +102,9 @@ def insert_rescale_op_to_int8(
107102
in the node meta dict as opposed to 'rescale_node_back_to_int8' which search
108103
the graph downstream for Q nodes.
109104
"""
110-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
111-
get_output_qparams,
112-
)
113-
114-
output_qparams = get_output_qparams(node)
115-
assert len(output_qparams) == 1, "More than one output not supported"
105+
assert len(node.meta["output_qparams"]) == 1
116106

117-
qargs_out = output_qparams[0]
107+
qargs_out = cast(dict[int, QuantArgs], node.meta["output_qparams"])[0]
118108
output_rescale_scale = scale / qargs_out.scale
119109

120110
# Rescale Back to INT8
@@ -146,17 +136,6 @@ def quantize_value(self, x):
146136
def dequantize_value(self, qx: int) -> float:
147137
return (qx - self.zp) * self.scale
148138

149-
def __eq__(self, other):
150-
if isinstance(other, QuantArgs):
151-
return (
152-
self.scale == other.scale
153-
and self.zp == other.zp
154-
and self.qmin == other.qmin
155-
and self.qmax == other.qmax
156-
and self.dtype == other.dtype
157-
)
158-
return False
159-
160139
@classmethod
161140
def from_operator(cls, op, args):
162141
if op in dq_q_ops:

0 commit comments

Comments (0)