Skip to content

Commit d9e99cb

Browse files
AdrianLundellMartin Lindström
andauthored
Arm backend: Grouped conv per-channel quant support (#12671)
Adds support for per-channel quantization of grouped convolution. Signed-off-by: Oscar Andersson <[email protected]> Co-authored-by: Martin Lindström <[email protected]>
1 parent b7eff08 commit d9e99cb

File tree

3 files changed

+46
-12
lines changed

3 files changed

+46
-12
lines changed

backends/arm/_passes/decompose_grouped_conv.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from copy import copy
77

88
import torch
9+
from executorch.backends.arm.tosa_quant_utils import QuantArgs
910
from executorch.exir.dialects._ops import ops as exir_ops
1011
from executorch.exir.pass_base import ExportPass
1112

@@ -48,7 +49,40 @@ def _get_decomposition(op):
4849
torch.ops.aten.cat.default,
4950
)
5051
case _:
51-
raise RuntimeError("Unvalid op for grouped conv decomposition.")
52+
raise RuntimeError("Invalid op for grouped conv decomposition")
53+
54+
@staticmethod
55+
def _split_per_channel_qparams(qarg, index, output_slice_size):
56+
if qarg is not None and qarg.per_channel:
57+
start_index = index * output_slice_size
58+
stop_index = (index + 1) * output_slice_size
59+
return QuantArgs(
60+
scale=qarg.scale[start_index:stop_index],
61+
zp=qarg.zp[start_index:stop_index],
62+
qmin=qarg.qmin,
63+
qmax=qarg.qmax,
64+
dtype=qarg.dtype,
65+
axis=qarg.axis,
66+
per_channel=qarg.per_channel,
67+
)
68+
return qarg
69+
70+
@staticmethod
71+
def _get_meta_copy(meta, i, output_slice_size):
72+
meta_copy = meta.copy()
73+
if "input_qparams" in meta.data and len(meta.data["input_qparams"]) > 0:
74+
# Handle per-channel quantization by splitting quantization params
75+
# similarly to how activations/weights/biases are split.
76+
new_qparams = meta.data.get("input_qparams").copy()
77+
# Get quantization params of the weights and slice them.
78+
qarg = new_qparams[1]
79+
new_qparams[1] = DecomposeGroupedConv._split_per_channel_qparams(
80+
qarg, index=i, output_slice_size=output_slice_size
81+
)
82+
83+
meta_copy.data["input_qparams"] = new_qparams
84+
85+
return meta_copy
5286

5387
def call_operator(self, op, args, kwargs, meta):
5488
if op == exir_ops.edge.aten.convolution.default:
@@ -105,7 +139,6 @@ def call_operator(self, op, args, kwargs, meta):
105139
if bias_node is None:
106140
bias_slices.append(None)
107141
else:
108-
109142
start_index = i * output_slice_size
110143
stop_index = (i + 1) * output_slice_size
111144
slice_args = (bias_node, 0, start_index, stop_index)
@@ -115,20 +148,23 @@ def call_operator(self, op, args, kwargs, meta):
115148
)
116149

117150
output_slices = []
118-
for input_slice, filter_slice, bias_slice in zip(
119-
input_slices, filter_slices, bias_slices
151+
for i, (input_slice, filter_slice, bias_slice) in enumerate(
152+
zip(input_slices, filter_slices, bias_slices)
120153
):
121154

155+
meta_copy = DecomposeGroupedConv._get_meta_copy(meta, i, output_slice_size)
156+
122157
if op == exir_ops.edge.aten.convolution.default:
123158
conv_args = (input_slice, filter_slice, bias_slice, *args[3:8], 1)
124159
elif op == torch.ops.aten.conv2d.default:
125160
conv_args = (input_slice, filter_slice, bias_slice, *args[3:6], 1)
126161
else:
127-
raise RuntimeError("Unvalid op for grouped conv decomposition.")
162+
raise RuntimeError("Invalid op for grouped conv decomposition")
128163

129164
output_slices.append(
130-
super().call_operator(conv_op, conv_args, kwargs, meta)
165+
super().call_operator(conv_op, conv_args, kwargs, meta_copy)
131166
)
132167

133168
cat_args = (output_slices, 1)
134-
return super().call_operator(cat_op, cat_args, kwargs, no_q_dq_meta)
169+
# propagate original metadata (including quantization params) to the concatenated output
170+
return super().call_operator(cat_op, cat_args, kwargs, meta)

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class FoldAndAnnotateQParamsPass(ArmPass):
7575
node.
7676
The quantization parameters from the DQ/Q nodes are stored as meta values to be
7777
accessible for later lowering and serialization passes.
78-
The assumption is that the quantization annotatation adds DQ nodes for all tensor
78+
The assumption is that the quantization annotation adds DQ nodes for all tensor
7979
inputs to the target one Q node to the output.
8080
8181
Example ('executorch_exir_dialects_edge__ops_' prefix removed from operators for readability):
@@ -95,7 +95,7 @@ class FoldAndAnnotateQParamsPass(ArmPass):
9595
9696
output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor_q, 0.05487706884741783, -128, -128, 127, torch.int8)
9797
98-
The quantization parameters for x_dq and aten_add_tensor_q are store in meta for the aten_add_tensor node.
98+
The quantization parameters for x_dq and aten_add_tensor_q are stored in meta for the aten_add_tensor node.
9999
100100
"""
101101

@@ -132,7 +132,7 @@ def fold_and_annotate_arg(
132132
nodes_to_remove.add(arg)
133133
if input_qparams is not None and input_qparams != arg_quant_params:
134134
# Two args are quantized differently
135-
raise RuntimeError("Input qparams does not match!")
135+
raise RuntimeError("Input qparams do not match")
136136
input_qparams = arg_quant_params
137137
if input_qparams is not None:
138138
node.meta["input_qparams"][i] = input_qparams

backends/arm/test/ops/test_conv2d.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,8 +385,6 @@ def forward(self, x):
385385
f"{k},per_channel_quant={q}": (lambda v=v, q=q: (v(), q))
386386
for (k, v) in test_data_MI.items()
387387
for q in [True, False]
388-
# TODO: Invalid TOSA graph (MLETORCH-1144)
389-
if (k not in ["groups", "groups_bias"]) and (q is True)
390388
}
391389

392390
fvp_xfails = {

0 commit comments

Comments
 (0)