Skip to content

Commit d9d5fd8

Browse files
Merge branch 'main' into upcast-index-to-int64-for-index_copy
2 parents 0b3ddd2 + bf79544 commit d9d5fd8

23 files changed

+806
-264
lines changed

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass
1414
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
1515
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass
16+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
1617
from executorch.exir.backend.utils import WhyNoPartitionReporter
1718
from executorch.exir.dialects._ops import ops as exir_ops
1819
from executorch.exir.pass_base import ExportPass
@@ -50,6 +51,15 @@ def get_view(op):
5051
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
5152

5253

54+
def get_quantization(op):
55+
"""Returns quant and dequant op of same type (per_channel/ tensor) as op if op is a dequant node, None otherwise."""
56+
if op in DQ_OPS:
57+
# Input of op can be placeholder, can't use that to get quant node directly.
58+
quant_type_index = DQ_OPS.index(op)
59+
return Q_OPS[quant_type_index], op
60+
return None
61+
62+
5363
class DecomposeMeanDimPass(ArmPass):
5464
"""
5565
Decomposes a meandim into avg_pool and/or sum + mul (1/N) depending on which dims the mean is taken for:
@@ -121,6 +131,7 @@ def call_operator(self, op, args, kwargs, meta):
121131
dims_to_reduce = [dim - 1 for dim in dims_to_reduce]
122132

123133
x = super().call_operator(view_op, (x, new_shape), {}, meta, True)
134+
x = self._maybe_insert_q_dq_after(x, meta)
124135

125136
# Reduce (h,w) dims by avg pool if possible
126137
x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta)
@@ -133,7 +144,7 @@ def call_operator(self, op, args, kwargs, meta):
133144
dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce]
134145

135146
x = super().call_operator(view_op, (x, temp_shape), {}, meta, True)
136-
147+
x = self._maybe_insert_q_dq_after(x, meta)
137148
# Reduce remaining dims by sum
138149
x = self._reduce_by_sum(op, x, dims_to_reduce, meta, dtype)
139150

@@ -156,6 +167,45 @@ def _reduce_by_sum(self, op, input_node, dims, meta, dtype):
156167
full = super().call_operator(
157168
full_op, ([1] * len(output_shape), 1 / N), {"dtype": dtype}, meta, True
158169
)
170+
if (quant_ops := get_quantization(input_node.node.target)) is not None:
171+
# Insert Q and DQ nodes after full op.
172+
# Since the value of full is known, we can compute quant params such that dq(q_max_value)
173+
q_op, dq_op = quant_ops
174+
qmax = input_node.node.args[4]
175+
full_quant_args = (
176+
1 / (N * qmax), # Scale to map qmax to 1/N
177+
0, # Zero point
178+
*input_node.node.args[3:],
179+
)
180+
q_args = (full, *full_quant_args)
181+
full = super().call_operator(
182+
q_op,
183+
q_args,
184+
kwargs={},
185+
meta=meta,
186+
updated=True,
187+
)
188+
dq_args = (full, *full_quant_args)
189+
full = super().call_operator(
190+
dq_op, dq_args, kwargs={}, meta=meta, updated=True
191+
)
192+
193+
# Insert Q and DQ nodes after sum op.
194+
# Scale needs to be adjusted with N, since it was computed on data after the division with N.
195+
sum_quant_args = (input_node.node.args[1] * N, *input_node.node.args[2:])
196+
q_args = (sum, *sum_quant_args)
197+
sum = super().call_operator(
198+
q_op,
199+
q_args,
200+
kwargs={},
201+
meta=meta,
202+
updated=True,
203+
)
204+
dq_args = (sum, *sum_quant_args)
205+
sum = super().call_operator(
206+
dq_op, dq_args, kwargs={}, meta=meta, updated=True
207+
)
208+
159209
return super().call_operator(mul_op, (sum, full), {}, meta, True)
160210

161211
def _reduce_by_average_pool(self, op, input_node, dims, meta):
@@ -190,10 +240,38 @@ def _reduce_by_average_pool(self, op, input_node, dims, meta):
190240
)
191241

192242
if is_supported:
243+
out = super().call_operator(avgpool_op, args, {}, meta, True)
244+
out = self._maybe_insert_q_dq_after(out, meta)
193245
return (
194-
super().call_operator(avgpool_op, args, {}, meta, True),
246+
out,
195247
dims_to_reduce_by_sum,
196248
)
197249

198250
else:
199251
return input_node, dims
252+
253+
def _maybe_insert_q_dq_after(self, op, meta):
254+
"""If the input node of op is a dequant node, insert a q-dq pair after op with identical quantization parameters."""
255+
256+
if len(op.node.all_input_nodes) > 1:
257+
raise ValueError(
258+
f"Expected one input to {op.node}, got inputs {op.node.all_input_nodes}"
259+
)
260+
input_node = op.node.all_input_nodes[0]
261+
if (quant_ops := get_quantization(input_node.target)) is not None:
262+
q_op, dq_op = quant_ops
263+
quant_args = list(input_node.args[1:])
264+
q_args = (op, *quant_args)
265+
out = super().call_operator(
266+
q_op,
267+
q_args,
268+
kwargs={},
269+
meta=meta,
270+
updated=True,
271+
)
272+
dq_args = (out, *quant_args)
273+
return super().call_operator(
274+
dq_op, dq_args, kwargs={}, meta=meta, updated=True
275+
)
276+
else:
277+
return op

backends/arm/test/ops/test_avg_pool2d.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
VgfPipeline,
2424
)
2525

26-
aten_op = "torch.ops.aten.avg_pool2d.default"
26+
aten_op = "avg_pool2d.default"
2727
exir_op = "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"
2828

2929
input_t = Tuple[torch.Tensor]
@@ -34,6 +34,15 @@ def forward(self, *args, **kwargs):
3434
return super().forward(*args, **kwargs)
3535

3636

37+
class BecomesMeanInToEdge(torch.nn.Module):
38+
"""This averagepool will be converted to mean when lowering to edge. This causes the decompose_meandim pass to not
39+
trigger until the backend pipeline, which requires extra care.
40+
"""
41+
42+
def forward(self, x: torch.Tensor):
43+
return torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
44+
45+
3746
test_modules = {
3847
"zeros": lambda: (AvgPool2d(4, 2, 0, False), (torch.zeros(1, 16, 50, 32),)),
3948
"ones": lambda: (AvgPool2d(4, 2, 0, False, True), (torch.ones(1, 16, 50, 32),)),
@@ -110,6 +119,9 @@ def forward(self, *args, **kwargs):
110119
AvgPool2d(3, (1, 3), 1, count_include_pad=False),
111120
(torch.rand(1, 16, 54, 54),),
112121
),
122+
"becomes_mean_rank3": lambda: (BecomesMeanInToEdge(), (torch.rand(2, 8, 8),)),
123+
"becomes_mean_rank4": lambda: (BecomesMeanInToEdge(), (torch.rand(2, 2, 8, 8),)),
124+
"becomes_mean_rank5": lambda: (BecomesMeanInToEdge(), (torch.rand(2, 2, 8, 8),)),
113125
}
114126

115127

backends/arm/test/tester/analyze_output_utils.py

Lines changed: 74 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import logging
77
import tempfile
8+
from typing import Any, cast, Sequence
89

910
import torch
1011
from executorch.backends.arm.test.runner_utils import (
@@ -17,9 +18,30 @@
1718
logger = logging.getLogger(__name__)
1819

1920

20-
def _print_channels(result, reference, channels_close, C, H, W, rtol, atol):
21+
TensorLike = torch.Tensor | tuple[torch.Tensor, ...]
22+
23+
24+
def _ensure_tensor(value: TensorLike) -> torch.Tensor:
25+
if isinstance(value, torch.Tensor):
26+
return value
27+
if value and isinstance(value[0], torch.Tensor):
28+
return value[0]
29+
raise TypeError("Expected a Tensor or a non-empty tuple of Tensors")
30+
31+
32+
def _print_channels(
33+
result: torch.Tensor,
34+
reference: torch.Tensor,
35+
channels_close: Sequence[bool],
36+
C: int,
37+
H: int,
38+
W: int,
39+
rtol: float,
40+
atol: float,
41+
) -> str:
2142

2243
output_str = ""
44+
exp = "000"
2345
booldata = False
2446
if reference.dtype == torch.bool or result.dtype == torch.bool:
2547
booldata = True
@@ -62,7 +84,15 @@ def _print_channels(result, reference, channels_close, C, H, W, rtol, atol):
6284
return output_str
6385

6486

65-
def _print_elements(result, reference, C, H, W, rtol, atol):
87+
def _print_elements(
88+
result: torch.Tensor,
89+
reference: torch.Tensor,
90+
C: int,
91+
H: int,
92+
W: int,
93+
rtol: float,
94+
atol: float,
95+
) -> str:
6696
output_str = ""
6797
for y in range(H):
6898
res = "["
@@ -92,14 +122,16 @@ def _print_elements(result, reference, C, H, W, rtol, atol):
92122

93123

94124
def print_error_diffs(
95-
tester,
96-
result: torch.Tensor | tuple,
97-
reference: torch.Tensor | tuple,
98-
quantization_scale=None,
99-
atol=1e-03,
100-
rtol=1e-03,
101-
qtol=0,
102-
):
125+
tester_or_result: Any,
126+
result_or_reference: TensorLike,
127+
reference: TensorLike | None = None,
128+
# Force remaining args to be keyword-only to keep the two positional call patterns unambiguous.
129+
*,
130+
quantization_scale: float | None = None,
131+
atol: float = 1e-03,
132+
rtol: float = 1e-03,
133+
qtol: float = 0,
134+
) -> None:
103135
"""
104136
Prints the error difference between a result tensor and a reference tensor in NCHW format.
105137
Certain formatting rules are applied to clarify errors:
@@ -130,15 +162,16 @@ def print_error_diffs(
130162
131163
132164
"""
133-
134-
if isinstance(reference, tuple):
135-
reference = reference[0]
136-
if isinstance(result, tuple):
137-
result = result[0]
138-
139-
if not result.shape == reference.shape:
165+
if reference is None:
166+
result = _ensure_tensor(cast(TensorLike, tester_or_result))
167+
reference_tensor = _ensure_tensor(result_or_reference)
168+
else:
169+
result = _ensure_tensor(result_or_reference)
170+
reference_tensor = _ensure_tensor(reference)
171+
172+
if result.shape != reference_tensor.shape:
140173
raise ValueError(
141-
f"Output needs to be of same shape: {result.shape} != {reference.shape}"
174+
f"Output needs to be of same shape: {result.shape} != {reference_tensor.shape}"
142175
)
143176
shape = result.shape
144177

@@ -161,29 +194,29 @@ def print_error_diffs(
161194

162195
# Reshape tensors to 4D NCHW format
163196
result = torch.reshape(result, (N, C, H, W))
164-
reference = torch.reshape(reference, (N, C, H, W))
197+
reference_tensor = torch.reshape(reference_tensor, (N, C, H, W))
165198

166199
output_str = ""
167200
for n in range(N):
168201
output_str += f"BATCH {n}\n"
169202
result_batch = result[n, :, :, :]
170-
reference_batch = reference[n, :, :, :]
203+
reference_batch = reference_tensor[n, :, :, :]
171204

172205
is_close = torch.allclose(result_batch, reference_batch, rtol, atol)
173206
if is_close:
174207
output_str += ".\n"
175208
else:
176-
channels_close = [None] * C
209+
channels_close: list[bool] = [False] * C
177210
for c in range(C):
178211
result_hw = result[n, c, :, :]
179-
reference_hw = reference[n, c, :, :]
212+
reference_hw = reference_tensor[n, c, :, :]
180213

181214
channels_close[c] = torch.allclose(result_hw, reference_hw, rtol, atol)
182215

183216
if any(channels_close) or len(channels_close) == 1:
184217
output_str += _print_channels(
185218
result[n, :, :, :],
186-
reference[n, :, :, :],
219+
reference_tensor[n, :, :, :],
187220
channels_close,
188221
C,
189222
H,
@@ -193,17 +226,23 @@ def print_error_diffs(
193226
)
194227
else:
195228
output_str += _print_elements(
196-
result[n, :, :, :], reference[n, :, :, :], C, H, W, rtol, atol
229+
result[n, :, :, :],
230+
reference_tensor[n, :, :, :],
231+
C,
232+
H,
233+
W,
234+
rtol,
235+
atol,
197236
)
198237
if reference_batch.dtype == torch.bool or result_batch.dtype == torch.bool:
199238
mismatches = (reference_batch != result_batch).sum().item()
200239
total = reference_batch.numel()
201240
output_str += f"(BOOLEAN tensor) {mismatches} / {total} elements differ ({mismatches / total:.2%})\n"
202241

203242
# Only compute numeric error metrics if tensor is not boolean
204-
if reference.dtype != torch.bool and result.dtype != torch.bool:
205-
reference_range = torch.max(reference) - torch.min(reference)
206-
diff = torch.abs(reference - result).flatten()
243+
if reference_tensor.dtype != torch.bool and result.dtype != torch.bool:
244+
reference_range = torch.max(reference_tensor) - torch.min(reference_tensor)
245+
diff = torch.abs(reference_tensor - result).flatten()
207246
diff = diff[diff.nonzero()]
208247
if not len(diff) == 0:
209248
diff_percent = diff / reference_range
@@ -230,14 +269,14 @@ def print_error_diffs(
230269

231270

232271
def dump_error_output(
233-
tester,
234-
reference_output,
235-
stage_output,
236-
quantization_scale=None,
237-
atol=1e-03,
238-
rtol=1e-03,
239-
qtol=0,
240-
):
272+
tester: Any,
273+
reference_output: TensorLike,
274+
stage_output: TensorLike,
275+
quantization_scale: float | None = None,
276+
atol: float = 1e-03,
277+
rtol: float = 1e-03,
278+
qtol: float = 0,
279+
) -> None:
241280
"""
242281
Prints Quantization info and error tolerances, and saves the differing tensors to disc.
243282
"""

0 commit comments

Comments
 (0)