From 23d7b69a5f61c47b898b63e56a654642e2b7d61a Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Tue, 19 Aug 2025 13:22:31 +0200 Subject: [PATCH] Arm backend: Support per-channel in TOSA.RESCALE Adds support for per-channel rescaling in TOSA dialect RESCALE op. Signed-off-by: Oscar Andersson Change-Id: I4c779634f97b7c930ee76246758fd019e3a6c2e1 --- .../decompose_int16_activation_conv2d_pass.py | 6 +- backends/arm/_passes/insert_rescales_pass.py | 16 ++-- backends/arm/_passes/insert_table_ops.py | 2 +- backends/arm/_passes/rewrite_conv2d_pass.py | 76 +++++++++++++++++-- backends/arm/_passes/rewrite_matmul.py | 2 +- backends/arm/_passes/rewrite_upsample.py | 2 +- backends/arm/operators/op_tosa_conv2d.py | 59 +------------- .../arm/operators/op_tosa_depthwise_conv2d.py | 4 + backends/arm/operators/op_tosa_rescale.py | 6 +- .../arm/test/misc/test_tosa_dialect_conv2d.py | 4 +- .../test/misc/test_tosa_dialect_dw_conv2d.py | 4 +- backends/arm/test/passes/test_rescale_pass.py | 14 ++-- backends/arm/tosa/dialect/ops/conv2d.py | 3 +- backends/arm/tosa/dialect/ops/rescale.py | 6 +- 14 files changed, 110 insertions(+), 94 deletions(-) diff --git a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py index d43c2a8c89c..388ce217807 100644 --- a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py +++ b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py @@ -105,14 +105,14 @@ def call_operator(self, op, args, kwargs, meta): conv_output = super().call_operator( exir_ops.backend.tosa.RESCALE.default, - (convolution, torch.int32, conv_rescale_factor, 0, 0), + (convolution, torch.int32, [conv_rescale_factor], 0, 0), {}, new_meta, ) bias_rescaled = super().call_operator( exir_ops.backend.tosa.RESCALE.default, - (channel_bias, torch.int32, bias_rescale_factor, 0, 0), + (channel_bias, torch.int32, [bias_rescale_factor], 0, 0), {}, new_meta, ) @@ -129,7 +129,7 @@ def call_operator(self, op, args, kwargs, meta): ( add, output_dtype, - (common_scale / (conv_output_scale * (1 << bits_left_to_shift))), + [(common_scale / (conv_output_scale * (1 << bits_left_to_shift)))], 0, 0, ), diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index a7fa614c8c3..89630978366 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -45,7 +45,7 @@ def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule ( node.all_input_nodes[0], q_args.dtype, - new_scale, + [new_scale], dq_args.zp, q_args.zp, ), @@ -228,10 +228,10 @@ def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> b ( arg_node, torch.int32, - qp.get_scale_per_tensor() - / rescale_qargs[ - i - ].get_scale_per_tensor(), # Old scale / new scale + [ + qp.get_scale_per_tensor() + / rescale_qargs[i].get_scale_per_tensor() + ], # [Old scale / new scale] qp.get_zp_per_tensor(), # Old zero point rescale_qargs[i].get_zp_per_tensor(), # New zero point ), @@ -264,8 +264,10 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b ( node, qarg.dtype, - rescale_qargs.get_scale_per_tensor() - / qarg.get_scale_per_tensor(), # Old scale / new scale + [ + rescale_qargs.get_scale_per_tensor() + / qarg.get_scale_per_tensor() + ], # [Old scale / new scale] rescale_qargs.get_zp_per_tensor(), # Old zero point qarg.get_zp_per_tensor(), # New zero point ), diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index e77d0c64c71..8d8a1284011 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -286,7 +286,7 @@ def call(self, graph_module: GraphModule) -> PassResult: rescale_node = create_node( graph=graph_module.graph, op_target=exir_ops.backend.tosa.RESCALE.default, - args=(table_op_node, output_qparams[0].dtype, scale, 0, 0), + args=(table_op_node, output_qparams[0].dtype, [scale], 0, 0), ) output_node = rescale_node diff --git a/backends/arm/_passes/rewrite_conv2d_pass.py b/backends/arm/_passes/rewrite_conv2d_pass.py index 8b4f43c35c7..c46cfb3b205 100644 --- a/backends/arm/_passes/rewrite_conv2d_pass.py +++ b/backends/arm/_passes/rewrite_conv2d_pass.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. +import itertools from typing import Set, Type import torch @@ -16,6 +17,10 @@ is_buffer, is_param, ) +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + get_output_qparams, +) from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.transforms.utils import create_constant_placeholder @@ -156,6 +161,40 @@ def _add_bias( node.update_arg(2, bias_node) return bias_node + def insert_output_rescale(self, graph_module, node): + input_qparams = get_input_qparams(node) + output_qparams = get_output_qparams(node)[0] + weight_qparams = input_qparams[1] + input_qparams = input_qparams[0] + is_per_channel = weight_qparams.per_channel + if is_per_channel: + weight_scale = weight_qparams.get_scale_per_channel() + else: + weight_scale = [weight_qparams.get_scale_per_tensor()] + input_scale = input_qparams.get_scale_per_tensor() + post_conv2d_scale = [ + (inp * w) / out + for inp, w, out in zip( + itertools.cycle([input_scale]), + weight_scale, + itertools.cycle([output_qparams.get_scale_per_tensor()]), + ) + ] + with graph_module.graph.inserting_after(node): + rescale_node = create_node( + graph=graph_module.graph, + op_target=exir_ops.backend.tosa.RESCALE.default, + args=( + node, + output_qparams.dtype, + post_conv2d_scale, + 0, + output_qparams.get_zp_per_tensor(), + ), + from_node=node, + ) + return rescale_node + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: modified = False for node in graph_module.graph.nodes: @@ -180,20 +219,20 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: ) = node.args pad = [val for val in pad for _ in (0, 1)] - input_shape = get_first_fake_tensor(x).shape - weight_shape = get_first_fake_tensor(weight).shape + input_fake_tensor = get_first_fake_tensor(x) + weight_fake_tensor = get_first_fake_tensor(weight) # Adjust the pad value if needed to meet the # strict convolution output shape calculation. pad[1] = self._adjust_pad_if_needed( - input_shape[2], - weight_shape[2], + input_fake_tensor.shape[2], + weight_fake_tensor.shape[2], stride[0], pad[1], dilation[0], ) pad[3] = self._adjust_pad_if_needed( - input_shape[3], - weight_shape[3], + input_fake_tensor.shape[3], + weight_fake_tensor.shape[3], stride[1], pad[3], dilation[1], @@ -204,7 +243,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: if self._is_depthwise_conv2d(node): target_op = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default - self._reshape_weights(weight, input_shape[1]) + self._reshape_weights(weight, input_fake_tensor.shape[1]) + weight_fake_tensor = get_first_fake_tensor(weight) else: target_op = exir_ops.backend.tosa.CONV2D.default @@ -227,9 +267,29 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: args=conv2d_args, from_node=node, ) + bias_fake_tensor = get_first_fake_tensor(bias) if bias else None + tosa_node_fake_tensor = target_op( + input_fake_tensor, + weight_fake_tensor, + bias_fake_tensor, + *conv2d_args[3:], + ) + if ( + tosa_node_fake_tensor.dtype == torch.int32 + and input_fake_tensor.dtype == torch.int8 + ) or ( + tosa_node_fake_tensor.dtype == torch.int32 + and input_fake_tensor.dtype == torch.int16 + ): + output_rescale = self.insert_output_rescale(graph_module, tosa_op) + node.replace_all_uses_with(output_rescale) + if input_fake_tensor.dtype == torch.int16: + tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48 + else: node.replace_all_uses_with(tosa_op) - graph_module.graph.erase_node(node) + + graph_module.graph.erase_node(node) if modified: graph_module.recompile() diff --git a/backends/arm/_passes/rewrite_matmul.py b/backends/arm/_passes/rewrite_matmul.py index 28ff800792b..410f0d62bff 100644 --- a/backends/arm/_passes/rewrite_matmul.py +++ b/backends/arm/_passes/rewrite_matmul.py @@ -44,7 +44,7 @@ def _insert_output_rescale(self, graph_module, node, tosa_matmul_node, dtype): rescale_node.args = ( tosa_matmul_node, dtype, - scale, + [scale], 0, output_qparams.get_zp_per_tensor(), ) diff --git a/backends/arm/_passes/rewrite_upsample.py b/backends/arm/_passes/rewrite_upsample.py index c9f25a1e845..e0ef1dbcf4a 100644 --- a/backends/arm/_passes/rewrite_upsample.py +++ b/backends/arm/_passes/rewrite_upsample.py @@ -74,7 +74,7 @@ def call(self, graph_module): rescale_node.args = ( tosa_resize_node, output_dtype, - output_scale, + [output_scale], 0, # zero point 0, # zero point ) diff --git a/backends/arm/operators/op_tosa_conv2d.py b/backends/arm/operators/op_tosa_conv2d.py index 3631a143b50..0e10867da7e 100644 --- a/backends/arm/operators/op_tosa_conv2d.py +++ b/backends/arm/operators/op_tosa_conv2d.py @@ -8,14 +8,12 @@ """Provide a visitor for lowering 2D convolution to TOSA (INT/FP).""" -import itertools from typing import Any, List import torch from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, - get_output_qparams, ) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -26,9 +24,7 @@ validate_valid_dtype, ) from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.quant_utils import build_rescale from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification -from executorch.backends.arm.tosa.utils import tosa_shape @register_node_visitor @@ -58,7 +54,8 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - """Define the TOSA CONV2D/DEPTHWISE_CONV2D operator and post-rescale.""" + """Define the TOSA CONV2D/DEPTHWISE_CONV2D operator.""" + input, weight, bias, stride, pad, dilation, _, _, group = inputs validate_num_inputs(self.target, inputs, 9) @@ -105,23 +102,8 @@ def define_node( input_qparams = get_input_qparams(node) weight_zp = input_qparams[1].zp # type: ignore[assignment] - # The output type is int32 when input type is int8. - if inputs[0].dtype == ts.DType.INT8: - conv2d_res = tosa_graph.addIntermediate( - tosa_shape(output.shape, output.dim_order), ts.DType.INT32 - ) - conv2d_output_name = conv2d_res.name - acc_type = ts.DType.INT32 - elif inputs[0].dtype == ts.DType.INT16: - conv2d_res = tosa_graph.addIntermediate( - tosa_shape(output.shape, output.dim_order), ts.DType.INT48 - ) - conv2d_output_name = conv2d_res.name - acc_type = ts.DType.INT48 - else: - conv2d_output_name = output.name - conv2d_res = output - acc_type = ts.DType.FP32 + conv2d_output_name = output.name + acc_type = output.dtype tosa_graph.addConst( [1], inputs[0].dtype, [input_zp], name=f"{conv2d_output_name}_input_zp" @@ -158,36 +140,3 @@ def define_node( [conv2d_output_name], attr, ) - - # For quantized convolution, rescale the output value back to the same - # integer value domain of the next op. Otherwise return float32 output. - if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16: - # Get scale_factor from input, weight, and output. - input_scale = input_qparams[0].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore [61] - per_channel_quant = input_qparams[1].per_channel # pyre-ignore [61] - if per_channel_quant: - weight_scale = input_qparams[1].get_scale_per_channel() - else: - weight_scale = [ - input_qparams[1].get_scale_per_tensor() - ] # pyre-ignore [61] - output_qargs = get_output_qparams(node) - post_conv2d_scale = [ - (inp * w) / out - for inp, w, out in zip( - itertools.cycle([input_scale]), - weight_scale, - itertools.cycle([output_qargs[0].get_scale_per_tensor()]), - ) - ] - build_rescale( - tosa_fb=tosa_graph, - scale=post_conv2d_scale, - input_node=conv2d_res, # type: ignore[possibly-undefined] - output_name=output.name, - output_type=output.dtype, - input_zp=[0], - output_zp=[output_qargs[0].get_zp_per_tensor()], - per_channel=per_channel_quant, - rounding_mode=ts.RoundingMode.SINGLE_ROUND, - ) diff --git a/backends/arm/operators/op_tosa_depthwise_conv2d.py b/backends/arm/operators/op_tosa_depthwise_conv2d.py index 3538b6f31da..1d1c317a0b8 100644 --- a/backends/arm/operators/op_tosa_depthwise_conv2d.py +++ b/backends/arm/operators/op_tosa_depthwise_conv2d.py @@ -4,7 +4,11 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe + +"""Provide a visitor for lowering 2D depthwise convolution to TOSA (INT/FP).""" + import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import register_node_visitor from executorch.backends.arm.operators.op_tosa_conv2d import Conv2dVisitor from executorch.backends.arm.tosa import TosaSpecification diff --git a/backends/arm/operators/op_tosa_rescale.py b/backends/arm/operators/op_tosa_rescale.py index db3738a8fd1..26f27370bdd 100644 --- a/backends/arm/operators/op_tosa_rescale.py +++ b/backends/arm/operators/op_tosa_rescale.py @@ -41,7 +41,7 @@ def define_node( input_dtype = inputs[0].dtype output_dtype = cast(torch.dtype, node.args[1]) - scale = cast(float, node.args[2]) + scales = cast(list[float], node.args[2]) input_zp = cast(int, node.args[3]) output_zp = cast(int, node.args[4]) @@ -63,12 +63,12 @@ def define_node( build_rescale( tosa_graph, - scale=[scale], + scale=scales, input_node=inputs[0], output_name=output.name, output_type=output.dtype, input_zp=[input_zp], output_zp=[output_zp], rounding_mode=ts.RoundingMode.SINGLE_ROUND, - per_channel=False, + per_channel=len(scales) > 1, ) diff --git a/backends/arm/test/misc/test_tosa_dialect_conv2d.py b/backends/arm/test/misc/test_tosa_dialect_conv2d.py index 867578a4ff5..3496ca0d5b6 100644 --- a/backends/arm/test/misc/test_tosa_dialect_conv2d.py +++ b/backends/arm/test/misc/test_tosa_dialect_conv2d.py @@ -31,7 +31,7 @@ def test_conv2d_tosa_INT(): 4, ), (1, 8, 20, 20), - torch.int8, + torch.int32, ), ( ( @@ -46,7 +46,7 @@ def test_conv2d_tosa_INT(): 4, ), (1, 4, 10, 10), - torch.int8, + torch.int32, ), ] diff --git a/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py b/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py index 8d9224d90fe..8b50df20830 100644 --- a/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py +++ b/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py @@ -32,7 +32,7 @@ def test_depthwise_conv2d_tosa_INT(): 8, ), (1, 16, 20, 20), - torch.int8, + torch.int32, ), ( ( @@ -48,7 +48,7 @@ def test_depthwise_conv2d_tosa_INT(): 8, ), (1, 32, 10, 10), - torch.int8, + torch.int32, ), ] diff --git a/backends/arm/test/passes/test_rescale_pass.py b/backends/arm/test/passes/test_rescale_pass.py index 9774ebd2fcd..ecd1deadf4f 100644 --- a/backends/arm/test/passes/test_rescale_pass.py +++ b/backends/arm/test/passes/test_rescale_pass.py @@ -31,21 +31,21 @@ def test_rescale_op(): ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int8), torch.int32, - 0.2, + [0.2], 2, 0, ), ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int32), torch.int8, - 0.2, + [0.2], 0, -128, ), ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int8), torch.int8, - 0.8, + [0.8], 10, 127, ), @@ -71,14 +71,14 @@ def test_nonzero_zp_for_int32(): ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int8), torch.int32, - 0.2, + [0.2], 2, # Should be 0, expect error 1, ), ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int32), torch.int8, - 0.2, + [0.2], 1, 1, # Should be 0, expect error ), @@ -107,14 +107,14 @@ def test_zp_outside_range(): ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int8), torch.int32, - 0.2, + [0.2], 128, # Should be <128, expect error 0, ), ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int32), torch.int8, - 0.2, + [0.2], 0, -129, # Should be >-129m expect error ), diff --git a/backends/arm/tosa/dialect/ops/conv2d.py b/backends/arm/tosa/dialect/ops/conv2d.py index 052c1111615..45afae51708 100644 --- a/backends/arm/tosa/dialect/ops/conv2d.py +++ b/backends/arm/tosa/dialect/ops/conv2d.py @@ -45,8 +45,7 @@ def validate_conv2d_args_dtypes( f"TOSA spec {tosa_spec} only supports {torch.int32} bias for {x.dtype} input but found {bias.dtype}", op=op, ) - # TODO update to int32 for int8 inputs - output_dtype = torch.int8 if x.dtype == torch.int8 else torch.int16 + output_dtype = torch.int32 elif x.dtype in supported_float_types: if not tosa_spec.support_float(): diff --git a/backends/arm/tosa/dialect/ops/rescale.py b/backends/arm/tosa/dialect/ops/rescale.py index 5f0cf9d15dc..f622bbf115d 100644 --- a/backends/arm/tosa/dialect/ops/rescale.py +++ b/backends/arm/tosa/dialect/ops/rescale.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import List + import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op @@ -14,13 +16,13 @@ @register_fake_tosa_op( - "RESCALE(Tensor input1, ScalarType dtype, float scale, int in_zp, int out_zp) -> Tensor", # schema + "RESCALE(Tensor input1, ScalarType dtype, float[] scale, int in_zp, int out_zp) -> Tensor", # schema ( TosaSpecification.create_from_string("TOSA-1.0+INT"), ), # target TOSA specifications ) def RESCALE( - x: torch.Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int + x: torch.Tensor, dtype: torch.dtype, scales: List[float], in_zp: int, out_zp: int ) -> torch.Tensor: tosa_spec = get_context_spec() """Casts the input tensor to dtype `dtype` to produce the correct tensor meta for a _rescale op.