Skip to content

Commit ce9e326

Browse files
authored
Merge branch 'main' into export-D85704977
2 parents fce1940 + 1523606 commit ce9e326

File tree

81 files changed

+2645
-839
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

81 files changed

+2645
-839
lines changed

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass
1414
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
1515
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass
16+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
1617
from executorch.exir.backend.utils import WhyNoPartitionReporter
1718
from executorch.exir.dialects._ops import ops as exir_ops
1819
from executorch.exir.pass_base import ExportPass
@@ -50,6 +51,15 @@ def get_view(op):
5051
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
5152

5253

54+
def get_quantization(op):
    """Look up the quantize/dequantize op pair that matches a dequantize op.

    Args:
        op: Operator target to inspect.

    Returns:
        A ``(q_op, dq_op)`` tuple where ``dq_op`` is ``op`` itself and ``q_op``
        is the entry of ``Q_OPS`` at the position ``op`` occupies in
        ``DQ_OPS``; ``None`` when ``op`` is not listed in ``DQ_OPS``.
    """
    if op not in DQ_OPS:
        return None
    # The input of op can be a placeholder, so the quant op cannot be read off
    # a neighbouring node; it is derived from op's position in DQ_OPS instead.
    # NOTE(review): relies on Q_OPS and DQ_OPS being index-aligned - confirm.
    matching_q_op = Q_OPS[DQ_OPS.index(op)]
    return matching_q_op, op
61+
62+
5363
class DecomposeMeanDimPass(ArmPass):
5464
"""
5565
Decomposes a meandim into avg_pool and/or sum + mul (1/N) depending on which dims the mean is taken for:
@@ -121,6 +131,7 @@ def call_operator(self, op, args, kwargs, meta):
121131
dims_to_reduce = [dim - 1 for dim in dims_to_reduce]
122132

123133
x = super().call_operator(view_op, (x, new_shape), {}, meta, True)
134+
x = self._maybe_insert_q_dq_after(x, meta)
124135

125136
# Reduce (h,w) dims by avg pool if possible
126137
x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta)
@@ -133,7 +144,7 @@ def call_operator(self, op, args, kwargs, meta):
133144
dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce]
134145

135146
x = super().call_operator(view_op, (x, temp_shape), {}, meta, True)
136-
147+
x = self._maybe_insert_q_dq_after(x, meta)
137148
# Reduce remaining dims by sum
138149
x = self._reduce_by_sum(op, x, dims_to_reduce, meta, dtype)
139150

@@ -156,6 +167,45 @@ def _reduce_by_sum(self, op, input_node, dims, meta, dtype):
156167
full = super().call_operator(
157168
full_op, ([1] * len(output_shape), 1 / N), {"dtype": dtype}, meta, True
158169
)
170+
if (quant_ops := get_quantization(input_node.node.target)) is not None:
171+
# Insert Q and DQ nodes after full op.
172+
# Since the value of full is known, we can compute quant params such that dq(q_max_value)
173+
q_op, dq_op = quant_ops
174+
qmax = input_node.node.args[4]
175+
full_quant_args = (
176+
1 / (N * qmax), # Scale to map qmax to 1/N
177+
0, # Zero point
178+
*input_node.node.args[3:],
179+
)
180+
q_args = (full, *full_quant_args)
181+
full = super().call_operator(
182+
q_op,
183+
q_args,
184+
kwargs={},
185+
meta=meta,
186+
updated=True,
187+
)
188+
dq_args = (full, *full_quant_args)
189+
full = super().call_operator(
190+
dq_op, dq_args, kwargs={}, meta=meta, updated=True
191+
)
192+
193+
# Insert Q and DQ nodes after sum op.
194+
# Scale needs to be adjusted with N, since it was computed on data after the division with N.
195+
sum_quant_args = (input_node.node.args[1] * N, *input_node.node.args[2:])
196+
q_args = (sum, *sum_quant_args)
197+
sum = super().call_operator(
198+
q_op,
199+
q_args,
200+
kwargs={},
201+
meta=meta,
202+
updated=True,
203+
)
204+
dq_args = (sum, *sum_quant_args)
205+
sum = super().call_operator(
206+
dq_op, dq_args, kwargs={}, meta=meta, updated=True
207+
)
208+
159209
return super().call_operator(mul_op, (sum, full), {}, meta, True)
160210

161211
def _reduce_by_average_pool(self, op, input_node, dims, meta):
@@ -190,10 +240,38 @@ def _reduce_by_average_pool(self, op, input_node, dims, meta):
190240
)
191241

192242
if is_supported:
243+
out = super().call_operator(avgpool_op, args, {}, meta, True)
244+
out = self._maybe_insert_q_dq_after(out, meta)
193245
return (
194-
super().call_operator(avgpool_op, args, {}, meta, True),
246+
out,
195247
dims_to_reduce_by_sum,
196248
)
197249

198250
else:
199251
return input_node, dims
252+
253+
def _maybe_insert_q_dq_after(self, op, meta):
254+
"""If the input node of op is a dequant node, insert a q-dq pair after op with identical quantization parameters."""
255+
256+
if len(op.node.all_input_nodes) > 1:
257+
raise ValueError(
258+
f"Expected one input to {op.node}, got inputs {op.node.all_input_nodes}"
259+
)
260+
input_node = op.node.all_input_nodes[0]
261+
if (quant_ops := get_quantization(input_node.target)) is not None:
262+
q_op, dq_op = quant_ops
263+
quant_args = list(input_node.args[1:])
264+
q_args = (op, *quant_args)
265+
out = super().call_operator(
266+
q_op,
267+
q_args,
268+
kwargs={},
269+
meta=meta,
270+
updated=True,
271+
)
272+
dq_args = (out, *quant_args)
273+
return super().call_operator(
274+
dq_op, dq_args, kwargs={}, meta=meta, updated=True
275+
)
276+
else:
277+
return op

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def resolve_arg(arg):
6565
if isinstance(arg, torch.fx.Node) and arg in input_nodes:
6666
idx = input_nodes.index(arg)
6767
t = get_param_tensor(self.exported_program, arg)
68-
if qparams:
68+
# Check if qparams exist for this arg
69+
if qparams and idx in qparams.keys():
6970
t = qparams[idx].dequantize_value(t)
7071
return t
7172
if isinstance(arg, tuple):

backends/arm/operators/op_avg_pool2d.py

Lines changed: 10 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class AvgPool2dVisitor(NodeVisitor):
3333

3434
tosa_specs = [
3535
TosaSpecification.create_from_string("TOSA-1.0+INT"),
36+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
3637
]
3738

3839
def __init__(self, *args):
@@ -105,43 +106,6 @@ def _build_generic_avgpool2d(
105106
attr,
106107
)
107108

108-
def define_node(
109-
self,
110-
node: torch.fx.Node,
111-
tosa_graph: Any,
112-
inputs: List[TosaArg],
113-
output: TosaArg,
114-
) -> None:
115-
validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
116-
validate_same_dtype(self.target, [inputs[0], output], ts)
117-
validate_valid_dtype(
118-
self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec
119-
)
120-
121-
accumulator_type = ts.DType.INT32
122-
123-
input_qargs = get_input_qparams(node)
124-
input_zp = input_qargs[0].get_zp_per_tensor()
125-
126-
output_qargs = get_output_qparams(node)
127-
output_zp = output_qargs[0].get_zp_per_tensor()
128-
129-
self._build_generic_avgpool2d(
130-
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
131-
)
132-
133-
134-
@register_node_visitor
135-
class AvgPool2dVisitor_FP(AvgPool2dVisitor):
136-
target = "aten.avg_pool2d.default"
137-
138-
tosa_specs = [
139-
TosaSpecification.create_from_string("TOSA-1.0+FP"),
140-
]
141-
142-
def __init__(self, *args):
143-
super().__init__(*args)
144-
145109
def define_node(
146110
self,
147111
node: torch.fx.Node,
@@ -159,14 +123,17 @@ def define_node(
159123
)
160124

161125
if inputs[0].dtype == ts.DType.INT8:
162-
super().define_node(node, tosa_graph, inputs, output)
126+
accumulator_type = ts.DType.INT32
127+
input_qargs = get_input_qparams(node)
128+
input_zp = input_qargs[0].get_zp_per_tensor()
163129

164-
if inputs[0].dtype == ts.DType.FP32:
130+
output_qargs = get_output_qparams(node)
131+
output_zp = output_qargs[0].get_zp_per_tensor()
132+
else:
165133
accumulator_type = ts.DType.FP32
166-
# Initilize zero point to zero.
167134
input_zp = 0
168135
output_zp = 0
169136

170-
self._build_generic_avgpool2d(
171-
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
172-
)
137+
self._build_generic_avgpool2d(
138+
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
139+
)

backends/arm/operators/op_clamp.py

Lines changed: 29 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright 2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree
@@ -27,18 +26,19 @@
2726

2827

2928
@register_node_visitor
30-
class ClampVisitor_INT(NodeVisitor):
29+
class ClampVisitor(NodeVisitor):
3130
target = "aten.clamp.default"
3231

3332
tosa_specs = [
3433
TosaSpecification.create_from_string("TOSA-1.0+INT"),
34+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
3535
]
3636

3737
def __init__(self, *args):
3838
super().__init__(*args)
3939

4040
def _get_min_max_arguments(
41-
self, node: Node, dtype_min: int | float, dtype_max: int | float
41+
self, node: Node, dtype: torch.dtype
4242
) -> Tuple[int | float, int | float]:
4343

4444
def cast_type(value: Any) -> int | float:
@@ -48,6 +48,13 @@ def cast_type(value: Any) -> int | float:
4848
# Attempt to cast to float
4949
return float(value)
5050

51+
if dtype.is_floating_point:
52+
dtype_min = torch.finfo(dtype).min
53+
dtype_max = torch.finfo(dtype).max
54+
else:
55+
dtype_min = torch.iinfo(dtype).min
56+
dtype_max = torch.iinfo(dtype).max
57+
5158
min_arg = dtype_min
5259
max_arg = dtype_max
5360

@@ -60,53 +67,15 @@ def cast_type(value: Any) -> int | float:
6067

6168
return min_arg, max_arg
6269

63-
def define_node(
64-
self,
65-
node: Node,
66-
tosa_graph: Any,
67-
inputs: List[TosaArg],
68-
output: TosaArg,
69-
) -> None:
70-
validate_num_inputs(self.target, inputs, [2, 3])
71-
validate_same_dtype(self.target, [inputs[0], output], ts)
72-
validate_valid_dtype(
73-
self.target, [inputs[0], output], [ts.DType.INT8], output.tosa_spec
74-
)
75-
76-
# NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
77-
min_int8, max_int8 = self._get_min_max_arguments(
78-
node,
79-
torch.iinfo(torch.int8).min,
80-
torch.iinfo(torch.int8).max,
81-
)
82-
83-
attr = ts.TosaSerializerAttribute()
84-
attr.ClampAttribute(
85-
np.frombuffer(np.int8(min_int8).tobytes(), dtype=np.uint8).tolist(),
86-
np.frombuffer(np.int8(max_int8).tobytes(), dtype=np.uint8).tolist(),
87-
ts.NanPropagationMode.PROPAGATE,
88-
)
89-
90-
self._serialize_operator(
91-
node,
92-
tosa_graph,
93-
ts.Op.CLAMP,
94-
[inputs[0].name],
95-
[output.name],
96-
attr,
97-
)
98-
99-
100-
@register_node_visitor
101-
class ClampVisitor_FP(ClampVisitor_INT):
102-
# inheriting 'target' from INT class
103-
104-
tosa_specs = [
105-
TosaSpecification.create_from_string("TOSA-1.0+FP"),
106-
]
107-
108-
def __init__(self, *args):
109-
super().__init__(*args)
70+
def _to_bytes(self, value: int | float, dtype: torch.dtype) -> bytes:
71+
if dtype == torch.float32:
72+
return np.frombuffer(np.float32(value).tobytes(), dtype=np.uint8).tolist()
73+
elif dtype == torch.float16:
74+
return np.frombuffer(np.float16(value).tobytes(), dtype=np.uint8).tolist()
75+
elif dtype == torch.int8:
76+
return np.frombuffer(np.int8(value).tobytes(), dtype=np.uint8).tolist()
77+
else:
78+
raise ValueError(f"Unsupported dtype for to_bytes: {dtype}")
11079

11180
def define_node(
11281
self,
@@ -120,42 +89,20 @@ def define_node(
12089
validate_valid_dtype(
12190
self.target,
12291
[inputs[0], output],
123-
[ts.DType.FP16, ts.DType.FP32],
92+
[ts.DType.INT8, ts.DType.FP16, ts.DType.FP32],
12493
output.tosa_spec,
12594
)
12695

96+
node_input_dtype = node.meta["val"].dtype
97+
# NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
98+
min_val, max_val = self._get_min_max_arguments(node, node_input_dtype)
99+
127100
attr = ts.TosaSerializerAttribute()
128-
match inputs[0].dtype:
129-
case ts.DType.FP16:
130-
min_f, max_f = self._get_min_max_arguments(
131-
node,
132-
torch.finfo(torch.float16).min,
133-
torch.finfo(torch.float16).max,
134-
)
135-
min_bytes = np.frombuffer(
136-
np.float16(min_f).tobytes(), dtype=np.uint8
137-
).tolist()
138-
max_bytes = np.frombuffer(
139-
np.float16(max_f).tobytes(), dtype=np.uint8
140-
).tolist()
141-
case ts.DType.FP32:
142-
min_f, max_f = self._get_min_max_arguments(
143-
node,
144-
torch.finfo(torch.float32).min,
145-
torch.finfo(torch.float32).max,
146-
)
147-
min_bytes = np.frombuffer(
148-
np.float32(min_f).tobytes(), dtype=np.uint8
149-
).tolist()
150-
max_bytes = np.frombuffer(
151-
np.float32(max_f).tobytes(), dtype=np.uint8
152-
).tolist()
153-
case _:
154-
raise RuntimeError(
155-
f"Internal error: Unsupported dtype {inputs[0].dtype} in {self.target}"
156-
)
157-
158-
attr.ClampAttribute(min_bytes, max_bytes, ts.NanPropagationMode.PROPAGATE)
101+
attr.ClampAttribute(
102+
self._to_bytes(min_val, node_input_dtype),
103+
self._to_bytes(max_val, node_input_dtype),
104+
nan_mode=ts.NanPropagationMode.PROPAGATE,
105+
)
159106

160107
self._serialize_operator(
161108
node,

0 commit comments

Comments
 (0)