
Commit dd18bfe

Update
[ghstack-poisoned]
2 parents 0d164ce + 02d6d67 commit dd18bfe


14 files changed (+138, -53 lines)


README.md

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ Platform Support:
 - Arm
 - Cadence
 - MediaTek
+- NXP
 - OpenVINO
 - Qualcomm
 - Vulkan

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -67,8 +67,8 @@
 )
 from .insert_rescales_pass import InsertRescalePass  # noqa
 from .insert_table_ops import InsertTableOpsPass  # noqa
+from .match_arg_dtype_pass import MatchArgDtypePass  # noqa
 from .match_arg_ranks_pass import MatchArgRanksPass  # noqa
-from .match_where_self_arg_dtype_pass import MatchWhereSelfDtypePass  # noqa
 from .mm_to_bmm_pass import ConvertMmToBmmPass  # noqa
 from .remove_clone_pass import RemoveClonePass  # noqa
 from .replace_scalar_with_tensor_pass import (  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 4 deletions
@@ -66,8 +66,8 @@
     InsertCastForOpsWithInt64InputPass,
     InsertRescalePass,
     InsertTableOpsPass,
+    MatchArgDtypePass,
     MatchArgRanksPass,
-    MatchWhereSelfDtypePass,
     QuantizeOperatorArguments,
     RemoveClonePass,
     ReplaceInfValues,

@@ -116,7 +116,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ConvertAnyDefaultDimDimsPass())
-        self.add_pass(MatchWhereSelfDtypePass())
+        self.add_pass(MatchArgDtypePass())
         if self.tosa_spec.is_U55_subset:
             self.add_pass(CastToInt32Pass())

@@ -193,8 +193,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ConvertAnyDefaultDimDimsPass())
-        self.add_pass(MatchWhereSelfDtypePass())
-
+        self.add_pass(MatchArgDtypePass())
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass(exported_program))  # type: ignore[call-arg]

backends/arm/_passes/decompose_grouped_conv.py

Lines changed: 43 additions & 7 deletions
@@ -6,6 +6,7 @@
 from copy import copy

 import torch
+from executorch.backends.arm.tosa_quant_utils import QuantArgs
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass

@@ -48,7 +49,40 @@ def _get_decomposition(op):
                     torch.ops.aten.cat.default,
                 )
             case _:
-                raise RuntimeError("Unvalid op for grouped conv decomposition.")
+                raise RuntimeError("Invalid op for grouped conv decomposition")
+
+    @staticmethod
+    def _split_per_channel_qparams(qarg, index, output_slice_size):
+        if qarg is not None and qarg.per_channel:
+            start_index = index * output_slice_size
+            stop_index = (index + 1) * output_slice_size
+            return QuantArgs(
+                scale=qarg.scale[start_index:stop_index],
+                zp=qarg.zp[start_index:stop_index],
+                qmin=qarg.qmin,
+                qmax=qarg.qmax,
+                dtype=qarg.dtype,
+                axis=qarg.axis,
+                per_channel=qarg.per_channel,
+            )
+        return qarg
+
+    @staticmethod
+    def _get_meta_copy(meta, i, output_slice_size):
+        meta_copy = meta.copy()
+        if "input_qparams" in meta.data and len(meta.data["input_qparams"]) > 0:
+            # Handle per-channel quantization by splitting quantization params
+            # similarly to how activations/weights/biases are split.
+            new_qparams = meta.data.get("input_qparams").copy()
+            # Get quantization params of the weights and slice them.
+            qarg = new_qparams[1]
+            new_qparams[1] = DecomposeGroupedConv._split_per_channel_qparams(
+                qarg, index=i, output_slice_size=output_slice_size
+            )
+
+            meta_copy.data["input_qparams"] = new_qparams
+
+        return meta_copy

     def call_operator(self, op, args, kwargs, meta):
         if op == exir_ops.edge.aten.convolution.default:

@@ -105,7 +139,6 @@ def call_operator(self, op, args, kwargs, meta):
             if bias_node is None:
                 bias_slices.append(None)
             else:
-
                 start_index = i * output_slice_size
                 stop_index = (i + 1) * output_slice_size
                 slice_args = (bias_node, 0, start_index, stop_index)

@@ -115,20 +148,23 @@
                 )

         output_slices = []
-        for input_slice, filter_slice, bias_slice in zip(
-            input_slices, filter_slices, bias_slices
+        for i, (input_slice, filter_slice, bias_slice) in enumerate(
+            zip(input_slices, filter_slices, bias_slices)
         ):

+            meta_copy = DecomposeGroupedConv._get_meta_copy(meta, i, output_slice_size)
+
             if op == exir_ops.edge.aten.convolution.default:
                 conv_args = (input_slice, filter_slice, bias_slice, *args[3:8], 1)
             elif op == torch.ops.aten.conv2d.default:
                 conv_args = (input_slice, filter_slice, bias_slice, *args[3:6], 1)
             else:
-                raise RuntimeError("Unvalid op for grouped conv decomposition.")
+                raise RuntimeError("Invalid op for grouped conv decomposition")

             output_slices.append(
-                super().call_operator(conv_op, conv_args, kwargs, meta)
+                super().call_operator(conv_op, conv_args, kwargs, meta_copy)
             )

         cat_args = (output_slices, 1)
-        return super().call_operator(cat_op, cat_args, kwargs, no_q_dq_meta)
+        # propagate original metadata (including quantization params) to the concatenated output
+        return super().call_operator(cat_op, cat_args, kwargs, meta)
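
The new _split_per_channel_qparams / _get_meta_copy helpers slice the weight tensor's per-channel quantization parameters the same way the weights themselves are sliced, so each per-group convolution carries only the scales and zero-points of its own output channels. Below is a minimal, self-contained sketch of that slicing; the QArgs dataclass is a simplified stand-in for ExecuTorch's QuantArgs, not the real class:

from dataclasses import dataclass

import torch


@dataclass
class QArgs:
    # Simplified stand-in for executorch.backends.arm.tosa_quant_utils.QuantArgs.
    scale: torch.Tensor  # one scale per output channel
    zp: torch.Tensor     # one zero-point per output channel
    per_channel: bool = True


def split_weight_qargs(qargs, groups):
    # Slice per-channel weight qparams into one QArgs per group, mirroring
    # how the weight tensor is sliced along its output-channel dimension.
    slice_size = qargs.scale.numel() // groups
    return [
        QArgs(
            scale=qargs.scale[i * slice_size : (i + 1) * slice_size],
            zp=qargs.zp[i * slice_size : (i + 1) * slice_size],
            per_channel=qargs.per_channel,
        )
        for i in range(groups)
    ]


if __name__ == "__main__":
    qargs = QArgs(scale=torch.linspace(0.1, 0.6, 6), zp=torch.zeros(6, dtype=torch.int32))
    for i, part in enumerate(split_weight_qargs(qargs, groups=3)):
        print(i, part.scale.tolist())  # two scales per group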

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 25 additions & 12 deletions
@@ -3,6 +3,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from copy import copy
 from math import prod

 import torch

@@ -75,35 +76,47 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)

         x = get_node_arg(args, 0)
-        input_shape = x.data.size()
-        output_shape = meta["val"].size()
+        input_shape = list(x.data.shape)
+        output_shape = list(meta["val"].shape)
         dims_to_reduce = get_node_arg(args, 1)
         dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce]
+        dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1]

         dtype = meta["val"].dtype
         view_op = get_view(op)

-        if len(input_shape) > 4:
-            raise NotImplementedError(
-                f"{op} with rank > 4 is currently not supported for the TOSA backend."
-            )
+        # Reshape to 4D
+        if len(input_shape) != 4:
+            new_shape = copy(input_shape)
+
+            while len(new_shape) < 4:
+                new_shape.insert(0, 1)
+                dims_to_reduce = [dim + 1 for dim in dims_to_reduce]

-        # Unsqueeze to 4D
-        if len(input_shape) < 4:
-            pad_n = 4 - len(input_shape)
-            new_shape = [1] * pad_n + list(input_shape)
-            dims_to_reduce = [dim + pad_n for dim in dims_to_reduce]
+            while len(new_shape) > 4:
+                i = new_shape.pop(0)
+                new_shape[0] = new_shape[0] * i
+                dims_to_reduce = [dim - 1 for dim in dims_to_reduce]

             x = super().call_operator(view_op, (x, new_shape), {}, meta, True)

         # Reduce (h,w) dims by avg pool if possible
         x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta)

+        # Reshape back to 5D if necessary
+        if len(input_shape) > 4:
+            original_dims = input_shape[0:-4]
+            temp_shape = list(x.data.shape)[1:]
+            temp_shape = original_dims + temp_shape
+            dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce]
+
+            x = super().call_operator(view_op, (x, temp_shape), {}, meta, True)
+
         # Reduce remaining dims by sum
         x = self._reduce_by_sum(op, x, dims_to_reduce, meta, dtype)

         # Reshape to correct output shape if necessary
-        if x.data.size() != output_shape:
+        if list(x.data.shape) != output_shape:
             x = super().call_operator(view_op, (x, output_shape), {}, meta, True)

         return x
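
The rewritten pass normalizes every input to rank 4 before reducing: ranks below 4 get leading 1-dims inserted, ranks above 4 get their leading dims folded into one, and dims_to_reduce is shifted so it keeps pointing at the same axes (size-1 dims are dropped from the reduction set beforehand). A standalone sketch of that shape bookkeeping, with no ExecuTorch dependencies:

from copy import copy


def normalize_to_rank4(shape, dims_to_reduce):
    # Return a rank-4 shape and the reduction dims remapped to it.
    # Rank < 4: insert leading 1-dims; rank > 4: fold leading dims into one.
    new_shape = copy(list(shape))
    dims = [d % len(new_shape) for d in dims_to_reduce]

    while len(new_shape) < 4:
        new_shape.insert(0, 1)
        dims = [d + 1 for d in dims]

    while len(new_shape) > 4:
        lead = new_shape.pop(0)
        new_shape[0] *= lead
        dims = [d - 1 for d in dims]

    return new_shape, dims


print(normalize_to_rank4((7, 3), (-1,)))                  # ([1, 1, 7, 3], [3])
print(normalize_to_rank4((1, 1, 7, 3, 2), (-3, -2, -1)))  # ([1, 7, 3, 2], [1, 2, 3])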

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 3 additions & 3 deletions
@@ -75,7 +75,7 @@ class FoldAndAnnotateQParamsPass(ArmPass):
     node.
     The quantization parameters from the DQ/Q nodes are stored as meta values to be
     accessible for later lowering and serialization passes.
-    The assumption is that the quantization annotatation adds DQ nodes for all tensor
+    The assumption is that the quantization annotation adds DQ nodes for all tensor
     inputs to the target one Q node to the output.

     Example ('executorch_exir_dialects_edge__ops_' prefix removed from operators for readability):

@@ -95,7 +95,7 @@ class FoldAndAnnotateQParamsPass(ArmPass):

     output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor_q, 0.05487706884741783, -128, -128, 127, torch.int8)

-    The quantization parameters for x_dq and aten_add_tensor_q are store in meta for the aten_add_tensor node.
+    The quantization parameters for x_dq and aten_add_tensor_q are stored in meta for the aten_add_tensor node.

     """

@@ -132,7 +132,7 @@ def fold_and_annotate_arg(
             nodes_to_remove.add(arg)
         if input_qparams is not None and input_qparams != arg_quant_params:
             # Two args are quantized differently
-            raise RuntimeError("Input qparams does not match!")
+            raise RuntimeError("Input qparams do not match")
         input_qparams = arg_quant_params
     if input_qparams is not None:
         node.meta["input_qparams"][i] = input_qparams

backends/arm/_passes/match_where_self_arg_dtype_pass.py renamed to backends/arm/_passes/match_arg_dtype_pass.py

Lines changed: 11 additions & 7 deletions
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.backends.arm._passes.arm_pass_utils import create_node, get_node_arg
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult

@@ -26,7 +26,7 @@ def get_largest_dtype(dtype_1, dtype_2):
     return dtype_1 if DTYPE_RANK[dtype_1] > DTYPE_RANK[dtype_2] else dtype_2


-class MatchWhereSelfDtypePass(ExportPass):
+class MatchArgDtypePass(ExportPass):
     """Pass to match data types of non-condition input tensors.

     Edge dialect allows different data types for non-condition tensors, while TOSA

@@ -38,14 +38,18 @@ class MatchWhereSelfDtypePass(ExportPass):

     """

+    targeted_ops = {exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.where.self}
+
     def call(self, graph_module: torch.fx.GraphModule):
         modified_graph = False
         graph = graph_module.graph
-        node_list = graph.find_nodes(
-            op="call_function", target=exir_ops.edge.aten.where.self
-        )
-        for node in node_list:
-            cond, input_, other_ = node.args
+
+        for node in list(graph.nodes):
+            if node.op != "call_function" or node.target not in self.targeted_ops:
+                continue
+
+            input_ = get_node_arg(node.args, 0)
+            other_ = get_node_arg(node.args, 1)

             input_dtype = input_.meta["val"].dtype
             other_dtype = other_.meta["val"].dtype
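
MatchArgDtypePass now covers both where.self and sub.Tensor: when the non-condition operands disagree on dtype, the pass casts the lower-ranked dtype up to the larger one so that TOSA sees matching element types. A rough, self-contained sketch of that promotion rule follows; the DTYPE_RANK ordering here is an assumption for illustration, not the exact table used by the pass:

import torch

# Assumed ordering, lowest to highest; the real DTYPE_RANK table may differ.
DTYPE_RANK = {
    torch.bool: 0,
    torch.int8: 1,
    torch.int16: 2,
    torch.int32: 3,
    torch.int64: 4,
    torch.float16: 5,
    torch.float32: 6,
}


def get_largest_dtype(dtype_1, dtype_2):
    # Return whichever dtype ranks higher; the other operand gets cast to it.
    return dtype_1 if DTYPE_RANK[dtype_1] > DTYPE_RANK[dtype_2] else dtype_2


# Example: sub(int32, float32) -> both operands are promoted to float32.
a = torch.ones(3, dtype=torch.int32)
b = torch.full((3,), 0.5, dtype=torch.float32)
target = get_largest_dtype(a.dtype, b.dtype)
print(a.to(target) - b.to(target))  # tensor([0.5000, 0.5000, 0.5000])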

backends/arm/test/ops/test_conv2d.py

Lines changed: 0 additions & 2 deletions
@@ -385,8 +385,6 @@ def forward(self, x):
     f"{k},per_channel_quant={q}": (lambda v=v, q=q: (v(), q))
     for (k, v) in test_data_MI.items()
     for q in [True, False]
-    # TODO: Invalid TOSA graph (MLETORCH-1144)
-    if (k not in ["groups", "groups_bias"]) and (q is True)
 }

 fvp_xfails = {

backends/arm/test/ops/test_mean_dim.py

Lines changed: 24 additions & 2 deletions
@@ -195,6 +195,21 @@ class MeanDim(torch.nn.Module):
             (-4, -3, -2, -1),
             False,
         ),
+        "rank5_01234": lambda: (
+            torch.rand(1, 1, 7, 3, 2),
+            (-5, -4, -3, -2, -1),
+            False,
+        ),
+        "rank5_234": lambda: (
+            torch.rand(1, 1, 7, 3, 2),
+            (-3, -2, -1),
+            False,
+        ),
+        "rank5_12": lambda: (
+            torch.rand(1, 1, 7, 3, 2),
+            (1, 2),
+            False,
+        ),
         "u55_avg_pool_not_supported": lambda: (
             torch.rand(1, 1, 1, 257),
             (0, 1, 2, 3),

@@ -236,7 +251,14 @@ def test_mean_dim_tosa_BI(test_data):
     pipeline.run()


-@common.parametrize("test_data", MeanDim.test_data_suite)
+xfails = {
+    "rank5_01234": "Rank 5 graph input currently not supported in EthosUBackend (passes since CHW are all averaged over so data order does not matter in this case)",
+    "rank5_234": "Rank 5 graph input currently not supported in EthosUBackend (passes since CHW are all averaged over so data order does not matter in this case)",
+    "rank5_12": "Rank 5 graph input currently not supported in EthosUBackend",
+}
+
+
+@common.parametrize("test_data", MeanDim.test_data_suite, xfails=xfails, strict=False)
 @common.XfailIfNoCorstone300
 def test_mean_dim_u55_BI(test_data):
     test_data, dim, keep_dim = test_data()

@@ -256,7 +278,7 @@ def test_mean_dim_u55_BI(test_data):
     pipeline.run()


-@common.parametrize("test_data", MeanDim.test_data_suite)
+@common.parametrize("test_data", MeanDim.test_data_suite, xfails=xfails, strict=False)
 @common.XfailIfNoCorstone320
 def test_mean_dim_u85_BI(test_data):
     test_data, dim, keep_dim = test_data()
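
The new rank-5 cases exercise the mean-dim decomposition above, which relies on the fact that folding the leading dims of a rank-5 tensor into one before averaging over the trailing dims gives the same result as reducing the original tensor. A quick standalone check of that equivalence (plain PyTorch, not part of the test suite):

import torch

x = torch.rand(1, 1, 7, 3, 2)

# Reference: mean over the last three dims of the rank-5 tensor.
ref = x.mean(dim=(-3, -2, -1))

# Decomposed path: fold the two leading dims into one, reduce, reshape back.
folded = x.reshape(1, 7, 3, 2)
out = folded.mean(dim=(-3, -2, -1)).reshape(ref.shape)

assert torch.allclose(ref, out)
print(ref.shape, out.shape)  # torch.Size([1, 1]) torch.Size([1, 1])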

backends/arm/test/ops/test_scalars.py

Lines changed: 2 additions & 7 deletions
@@ -242,21 +242,16 @@ def test_add_scalar_u85_BI():


 # SUB MI ------------------------------------------------------
-mi_sub_xfails = {
-    "int_r1_ts": "TypeError: All IO needs to have the same data type, got input 1: 8, input 2: 6 and output: 8",
-    "int_r4_ts": "TypeError: All IO needs to have the same data type, got input 1: 8, input 2: 6 and output: 8",
-    **xfails,
-}


-@common.parametrize("test_data", tensor_scalar_tests, xfails=mi_sub_xfails)
+@common.parametrize("test_data", tensor_scalar_tests, xfails=xfails)
 def test_sub_tensor_tosa_MI_scalar(test_data):
     """Tests regular sub with one scalar input."""
     pipeline = TosaPipelineMI[input_t1](Sub(), test_data, aten_op=Sub.aten_op)
     pipeline.run()


-@common.parametrize("test_data", tensor_scalar_tests, xfails=mi_sub_xfails)
+@common.parametrize("test_data", tensor_scalar_tests, xfails=xfails)
 def test_sub_tensor_tosa_MI_inplace(test_data):
     """Tests inplace sub with one scalar input."""
     pipeline = TosaPipelineMI[input_t1](SubInplace(), test_data, aten_op=[])
