
Commit 1739639

Arm backend: Support per-channel in TOSA.RESCALE
Adds support for per-channel rescaling in TOSA dialect RESCALE op.

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: I4c779634f97b7c930ee76246758fd019e3a6c2e1
1 parent 7215b88 commit 1739639

14 files changed: +107 -95 lines changed
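At a glance, every producer of the TOSA dialect RESCALE op now passes its scale argument as a list: a single-element list keeps the old per-tensor behaviour, while a list with one entry per output channel enables per-channel rescaling. A minimal sketch of the new argument layout; the helper name and values below are illustrative, not part of the patch:

import torch

# Hypothetical helper mirroring the updated argument order of
# exir_ops.backend.tosa.RESCALE.default: (tensor, out_dtype, scales, in_zp, out_zp),
# where scales is now a list instead of a single float.
def make_rescale_args(tensor, out_dtype, scales, in_zp=0, out_zp=0):
    assert isinstance(scales, list) and len(scales) >= 1
    return (tensor, out_dtype, scales, in_zp, out_zp)

per_tensor_args = make_rescale_args("conv_out", torch.int8, [0.017])          # per-tensor
per_channel_args = make_rescale_args("conv_out", torch.int8, [0.017, 0.021])  # per-channel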

backends/arm/_passes/decompose_int16_activation_conv2d_pass.py

Lines changed: 3 additions & 3 deletions

@@ -105,14 +105,14 @@ def call_operator(self, op, args, kwargs, meta):
 
         conv_output = super().call_operator(
            exir_ops.backend.tosa.RESCALE.default,
-            (convolution, torch.int32, conv_rescale_factor, 0, 0),
+            (convolution, torch.int32, [conv_rescale_factor], 0, 0),
            {},
            new_meta,
        )
 
        bias_rescaled = super().call_operator(
            exir_ops.backend.tosa.RESCALE.default,
-            (channel_bias, torch.int32, bias_rescale_factor, 0, 0),
+            (channel_bias, torch.int32, [bias_rescale_factor], 0, 0),
            {},
            new_meta,
        )
@@ -129,7 +129,7 @@ def call_operator(self, op, args, kwargs, meta):
            (
                add,
                output_dtype,
-                (common_scale / (conv_output_scale * (1 << bits_left_to_shift))),
+                [(common_scale / (conv_output_scale * (1 << bits_left_to_shift)))],
                0,
                0,
            ),

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 9 additions & 7 deletions

@@ -45,7 +45,7 @@ def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule
            (
                node.all_input_nodes[0],
                q_args.dtype,
-                new_scale,
+                [new_scale],
                dq_args.zp,
                q_args.zp,
            ),
@@ -228,10 +228,10 @@ def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> bool:
                (
                    arg_node,
                    torch.int32,
-                    qp.get_scale_per_tensor()
-                    / rescale_qargs[
-                        i
-                    ].get_scale_per_tensor(),  # Old scale / new scale
+                    [
+                        qp.get_scale_per_tensor()
+                        / rescale_qargs[i].get_scale_per_tensor()
+                    ],  # [Old scale / new scale]
                    qp.get_zp_per_tensor(),  # Old zero point
                    rescale_qargs[i].get_zp_per_tensor(),  # New zero point
                ),
@@ -264,8 +264,10 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> bool:
                (
                    node,
                    qarg.dtype,
-                    rescale_qargs.get_scale_per_tensor()
-                    / qarg.get_scale_per_tensor(),  # Old scale / new scale
+                    [
+                        rescale_qargs.get_scale_per_tensor()
+                        / qarg.get_scale_per_tensor()
+                    ],  # [Old scale / new scale]
                    rescale_qargs.get_zp_per_tensor(),  # Old zero point
                    qarg.get_zp_per_tensor(),  # New zero point
                ),
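For the per-tensor code paths touched above, the only behavioural change is that the old-scale/new-scale ratio is wrapped in a single-element list. A small numeric sketch with assumed quantization parameters (values are illustrative, not from the tree):

old_scale = 0.02   # producer's quantization scale
new_scale = 0.05   # consumer's quantization scale
scales_arg = [old_scale / new_scale]   # [0.4] -- the "[Old scale / new scale]" list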

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 1 deletion

@@ -286,7 +286,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
            rescale_node = create_node(
                graph=graph_module.graph,
                op_target=exir_ops.backend.tosa.RESCALE.default,
-                args=(table_op_node, output_qparams[0].dtype, scale, 0, 0),
+                args=(table_op_node, output_qparams[0].dtype, [scale], 0, 0),
            )
            output_node = rescale_node
 

backends/arm/_passes/rewrite_conv2d_pass.py

Lines changed: 68 additions & 8 deletions

@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import itertools
 from typing import Set, Type
 
 import torch
@@ -16,6 +17,10 @@
     is_buffer,
     is_param,
 )
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+    get_input_qparams,
+    get_output_qparams,
+)
 from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER
 from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
 from executorch.backends.transforms.utils import create_constant_placeholder
@@ -156,6 +161,40 @@ def _add_bias(
        node.update_arg(2, bias_node)
        return bias_node
 
+    def insert_output_rescale(self, graph_module, node):
+        input_qparams = get_input_qparams(node)
+        output_qparams = get_output_qparams(node)[0]
+        weight_qparams = input_qparams[1]
+        input_qparams = input_qparams[0]
+        is_per_channel = weight_qparams.per_channel
+        if is_per_channel:
+            weight_scale = weight_qparams.get_scale_per_channel()
+        else:
+            weight_scale = [weight_qparams.get_scale_per_tensor()]
+        input_scale = input_qparams.get_scale_per_tensor()
+        post_conv2d_scale = [
+            (inp * w) / out
+            for inp, w, out in zip(
+                itertools.cycle([input_scale]),
+                weight_scale,
+                itertools.cycle([output_qparams.get_scale_per_tensor()]),
+            )
+        ]
+        with graph_module.graph.inserting_after(node):
+            rescale_node = create_node(
+                graph=graph_module.graph,
+                op_target=exir_ops.backend.tosa.RESCALE.default,
+                args=(
+                    node,
+                    output_qparams.dtype,
+                    post_conv2d_scale,
+                    0,
+                    output_qparams.get_zp_per_tensor(),
+                ),
+                from_node=node,
+            )
+        return rescale_node
+
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        modified = False
        for node in graph_module.graph.nodes:
@@ -180,20 +219,20 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
            ) = node.args
 
            pad = [val for val in pad for _ in (0, 1)]
-            input_shape = get_first_fake_tensor(x).shape
-            weight_shape = get_first_fake_tensor(weight).shape
+            input_fake_tensor = get_first_fake_tensor(x)
+            weight_fake_tensor = get_first_fake_tensor(weight)
            # Adjust the pad value if needed to meet the
            # strict convolution output shape calculation.
            pad[1] = self._adjust_pad_if_needed(
-                input_shape[2],
-                weight_shape[2],
+                input_fake_tensor.shape[2],
+                weight_fake_tensor.shape[2],
                stride[0],
                pad[1],
                dilation[0],
            )
            pad[3] = self._adjust_pad_if_needed(
-                input_shape[3],
-                weight_shape[3],
+                input_fake_tensor.shape[3],
+                weight_fake_tensor.shape[3],
                stride[1],
                pad[3],
                dilation[1],
@@ -204,7 +243,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 
            if self._is_depthwise_conv2d(node):
                target_op = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default
-                self._reshape_weights(weight, input_shape[1])
+                self._reshape_weights(weight, input_fake_tensor.shape[1])
+                weight_fake_tensor = get_first_fake_tensor(weight)
            else:
                target_op = exir_ops.backend.tosa.CONV2D.default
 
@@ -227,9 +267,29 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                args=conv2d_args,
                from_node=node,
            )
+            bias_fake_tensor = get_first_fake_tensor(bias) if bias else None
+            tosa_node_fake_tensor = target_op(
+                input_fake_tensor,
+                weight_fake_tensor,
+                bias_fake_tensor,
+                *conv2d_args[3:],
+            )
 
+            if (
+                tosa_node_fake_tensor.dtype == torch.int32
+                and input_fake_tensor.dtype == torch.int8
+            ) or (
+                tosa_node_fake_tensor.dtype == torch.int32
+                and input_fake_tensor.dtype == torch.int16
+            ):
+                output_rescale = self.insert_output_rescale(graph_module, tosa_op)
+                node.replace_all_uses_with(output_rescale)
+                if input_fake_tensor.dtype == torch.int16:
+                    tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48
+            else:
                node.replace_all_uses_with(tosa_op)
-                graph_module.graph.erase_node(node)
+
+            graph_module.graph.erase_node(node)
 
        if modified:
            graph_module.recompile()
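The new insert_output_rescale helper computes one rescale factor per output channel when the weights are per-channel quantized, and a single factor otherwise; itertools.cycle broadcasts the per-tensor input and output scales across the channel axis. A standalone sketch of that computation with assumed scales:

import itertools

# Assumed example scales: per-tensor input/output, per-channel weights.
input_scale = 0.05
weight_scale = [0.01, 0.02, 0.04]   # one entry per output channel
output_scale = 0.1

post_conv2d_scale = [
    (inp * w) / out
    for inp, w, out in zip(
        itertools.cycle([input_scale]),    # same input scale for every channel
        weight_scale,                      # drives the length of the zip
        itertools.cycle([output_scale]),   # same output scale for every channel
    )
]
# post_conv2d_scale is roughly [0.005, 0.01, 0.02], passed as the scales list of tosa.RESCALE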

backends/arm/_passes/rewrite_matmul.py

Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ def _insert_output_rescale(self, graph_module, node, tosa_matmul_node, dtype):
        rescale_node.args = (
            tosa_matmul_node,
            dtype,
-            scale,
+            [scale],
            0,
            output_qparams.get_zp_per_tensor(),
        )

backends/arm/_passes/rewrite_upsample.py

Lines changed: 1 addition & 1 deletion

@@ -74,7 +74,7 @@ def call(self, graph_module):
            rescale_node.args = (
                tosa_resize_node,
                output_dtype,
-                output_scale,
+                [output_scale],
                0,  # zero point
                0,  # zero point
            )

backends/arm/operators/op_tosa_conv2d.py

Lines changed: 3 additions & 56 deletions

@@ -8,14 +8,12 @@
 
 """Provide a visitor for lowering 2D convolution to TOSA (INT/FP)."""
 
-import itertools
 from typing import Any, List
 
 import torch
 
 from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
     get_input_qparams,
-    get_output_qparams,
 )
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
@@ -26,9 +24,7 @@
     validate_valid_dtype,
 )
 from executorch.backends.arm.tosa.mapping import TosaArg
-from executorch.backends.arm.tosa.quant_utils import build_rescale
 from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification
-from executorch.backends.arm.tosa.utils import tosa_shape
 
 
 @register_node_visitor
@@ -60,8 +56,7 @@ def define_node(
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
-        """Define the TOSA CONV2D/DEPTHWISE_CONV2D operator and post-rescale."""
-        from tosa.RoundingMode import RoundingMode  # type: ignore
+        """Define the TOSA CONV2D/DEPTHWISE_CONV2D operator."""
 
        input, weight, bias, stride, pad, dilation, _, _, group = inputs
        validate_num_inputs(self.target, inputs, 9)
@@ -109,23 +104,8 @@ def define_node(
            input_qparams = get_input_qparams(node)
            weight_zp = input_qparams[1].zp  # type: ignore[assignment]
 
-        # The output type is int32 when input type is int8.
-        if inputs[0].dtype == ts.DType.INT8:
-            conv2d_res = tosa_graph.addIntermediate(
-                tosa_shape(output.shape, output.dim_order), ts.DType.INT32
-            )
-            conv2d_output_name = conv2d_res.name
-            acc_type = ts.DType.INT32
-        elif inputs[0].dtype == ts.DType.INT16:
-            conv2d_res = tosa_graph.addIntermediate(
-                tosa_shape(output.shape, output.dim_order), ts.DType.INT48
-            )
-            conv2d_output_name = conv2d_res.name
-            acc_type = ts.DType.INT48
-        else:
-            conv2d_output_name = output.name
-            conv2d_res = output
-            acc_type = ts.DType.FP32
+        conv2d_output_name = output.name
+        acc_type = output.dtype
 
        tosa_graph.addConst(
            [1], inputs[0].dtype, [input_zp], name=f"{conv2d_output_name}_input_zp"
@@ -162,36 +142,3 @@ def define_node(
            [conv2d_output_name],
            attr,
        )
-
-        # For quantized convolution, rescale the output value back to the same
-        # integer value domain of the next op. Otherwise return float32 output.
-        if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16:
-            # Get scale_factor from input, weight, and output.
-            input_scale = input_qparams[0].get_scale_per_tensor()  # type: ignore[possibly-undefined]  # pyre-ignore [61]
-            per_channel_quant = input_qparams[1].per_channel  # pyre-ignore [61]
-            if per_channel_quant:
-                weight_scale = input_qparams[1].get_scale_per_channel()
-            else:
-                weight_scale = [
-                    input_qparams[1].get_scale_per_tensor()
-                ]  # pyre-ignore [61]
-            output_qargs = get_output_qparams(node)
-            post_conv2d_scale = [
-                (inp * w) / out
-                for inp, w, out in zip(
-                    itertools.cycle([input_scale]),
-                    weight_scale,
-                    itertools.cycle([output_qargs[0].get_scale_per_tensor()]),
-                )
-            ]
-            build_rescale(
-                tosa_fb=tosa_graph,
-                scale=post_conv2d_scale,
-                input_node=conv2d_res,  # type: ignore[possibly-undefined]
-                output_name=output.name,
-                output_type=output.dtype,
-                input_zp=[0],
-                output_zp=[output_qargs[0].get_zp_per_tensor()],
-                per_channel=per_channel_quant,
-                rounding_mode=RoundingMode.SINGLE_ROUND,
-            )

backends/arm/operators/op_tosa_depthwise_conv2d.py

Lines changed: 2 additions & 0 deletions

@@ -4,6 +4,8 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-unsafe
+"""Provide a visitor for lowering 2D depthwise convolution to TOSA (INT/FP)."""
+
 from executorch.backends.arm.operators.node_visitor import register_node_visitor
 from executorch.backends.arm.operators.op_tosa_conv2d import Conv2dVisitor
 from executorch.backends.arm.tosa import TosaSpecification

backends/arm/operators/op_tosa_rescale.py

Lines changed: 3 additions & 3 deletions

@@ -43,7 +43,7 @@ def define_node(
 
        input_dtype = inputs[0].dtype
        output_dtype = cast(torch.dtype, node.args[1])
-        scale = cast(float, node.args[2])
+        scales = cast(list[float], node.args[2])
        input_zp = cast(int, node.args[3])
        output_zp = cast(int, node.args[4])
 
@@ -65,12 +65,12 @@
 
        build_rescale(
            tosa_graph,
-            scale=[scale],
+            scale=scales,
            input_node=inputs[0],
            output_name=output.name,
            output_type=output.dtype,
            input_zp=[input_zp],
            output_zp=[output_zp],
            rounding_mode=RoundingMode.SINGLE_ROUND,
-            per_channel=False,
+            per_channel=len(scales) > 1,
        )
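On the serialization side the visitor now forwards the whole list to build_rescale and infers the per-channel flag from its length; a single-element list keeps the previous per-tensor behaviour. A minimal sketch of that convention (values are illustrative):

scales = [0.005, 0.01, 0.02]       # e.g. the scales list carried in node.args[2]
per_channel = len(scales) > 1      # True: one scale per output channel
per_tensor = len([0.4]) > 1        # False: single scale, per-tensor rescale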

backends/arm/test/misc/test_tosa_dialect_conv2d.py

Lines changed: 2 additions & 2 deletions

@@ -31,7 +31,7 @@ def test_conv2d_tosa_INT():
                4,
            ),
            (1, 8, 20, 20),
-            torch.int8,
+            torch.int32,
        ),
        (
            (
@@ -46,7 +46,7 @@ def test_conv2d_tosa_INT():
                4,
                4,
            (1, 4, 10, 10),
-            torch.int8,
+            torch.int32,
        ),
    ]
 
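The test update reflects that the dialect-level CONV2D op now reports its int32 accumulator type for int8 inputs; narrowing back to int8/int16 happens in the RESCALE node that RewriteConv2dPass inserts afterwards. A hypothetical helper (not part of the patch) expressing that expectation:

import torch

def expected_conv_output_dtype(input_dtype: torch.dtype) -> torch.dtype:
    # int8/int16 inputs accumulate in int32; the inserted tosa.RESCALE
    # narrows the result back to the quantized output dtype afterwards.
    if input_dtype in (torch.int8, torch.int16):
        return torch.int32
    return input_dtype  # float path is unchanged

assert expected_conv_output_dtype(torch.int8) == torch.int32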
