diff --git a/backends/arm/operator_support/convolution_support.py b/backends/arm/operator_support/convolution_support.py index 75899eb7425..5b4fefdbf81 100644 --- a/backends/arm/operator_support/convolution_support.py +++ b/backends/arm/operator_support/convolution_support.py @@ -11,7 +11,11 @@ register_tosa_support_check, SupportedTOSAOperatorCheck, ) -from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification +from executorch.backends.arm.tosa_specification import ( + Tosa_0_80, + Tosa_1_00, + TosaSpecification, +) from executorch.exir.dialects._ops import ops as exir_ops @@ -43,6 +47,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): # Hardware specific constraints if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset): + # TODO remove this once TOSA 1.0 support for u55 is added. + if isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions: + return False return True else: return self._is_node_supported_u55(node) diff --git a/backends/arm/operators/op_abs.py b/backends/arm/operators/op_abs.py index 648edde04f4..a6d649fd92e 100644 --- a/backends/arm/operators/op_abs.py +++ b/backends/arm/operators/op_abs.py @@ -4,12 +4,11 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import List +from typing import Any, List import executorch.backends.arm.tosa_quant_utils as tqutils import executorch.backends.arm.tosa_utils as tutils -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -33,10 +32,13 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + # Specification (0.80) states that input and output types # should all be the same if not (inputs[0].dtype == output.dtype): @@ -53,7 +55,7 @@ def define_node( if inputs[0].dtype == ts.DType.INT8: rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( tosa_graph, inputs, node - ) + ) # type: ignore[possibly-undefined] else: # input[0].dtype == ts.DType.INT32 # Non quantized input, natively support by TOSA.abs @@ -96,10 +98,13 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + # Specification (0.80) states that input and output types # should all be the same if not (inputs[0].dtype == output.dtype): @@ -129,3 +134,122 @@ def define_node( [output.name], None, ) + + +@register_node_visitor +class AbsVisitor_INT(NodeVisitor): + target = "aten.abs.default" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + # Specification (1.0) states that input and output types + # should all be the same + if not (inputs[0].dtype == output.dtype): + raise ValueError( + "All inputs and outputs need same dtype." + f"Got {inputs[0].dtype=}, {output.dtype=}" + ) + # Handle int8 (quantized) and int32 + if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]): + raise ValueError( + "All inputs need to be INT8 or INT32." f"Got {inputs[0].dtype=}" + ) + + scale_back = 1.0 + if inputs[0].dtype == ts.DType.INT8: + rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node, self.tosa_specs + ) # type: ignore[possibly-undefined] + else: + # input[0].dtype == ts.DType.INT32 + # Non quantized input, natively support by TOSA.abs + rescaled_inputs = inputs + + if output.dtype == ts.DType.INT8: + broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) + abs_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) + else: + # output.dtype == ts.DType.INT32 + abs_output = output + + # Do the INT32 Abs + tosa_graph.addOperator( + ts.TosaOp.Op().ABS, + [ + rescaled_inputs[0].name, + ], + [abs_output.name], + None, + ) + + if output.dtype == ts.DType.INT8: + # Scale output back to 8 bit + # pyre-ignore + tqutils.insert_rescale_op_to_int8( + tosa_graph, abs_output, scale_back, node, self.tosa_specs + ) # type: ignore[possibly-undefined] + + +@register_node_visitor +class AbsVisitor_FP(AbsVisitor_INT): + # inheriting 'target' from BI class + + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + # Specification (1.0) states that input and output types + # should all be the same + if not (inputs[0].dtype == output.dtype): + raise ValueError( + "All inputs and output need same dtype." + f"Got {inputs[0].dtype=}, {output.dtype=}" + ) + + if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: + # Call the inherited define_node for handling integers + super().define_node(node, tosa_graph, inputs, output) + else: + # FP32 Abs lowering + + if not (inputs[0].dtype == ts.DType.FP32): + raise ValueError( + "All inputs need to be FP32." f"Got {inputs[0].dtype=}" + ) + + if not (output.dtype == ts.DType.FP32): + raise ValueError("All outputs need to be FP32." f"Got {output.dtype=}") + + # MI lowering + tosa_graph.addOperator( + ts.TosaOp.Op().ABS, + [inputs[0].name], + [output.name], + None, + ) diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index 904a2405047..11c32a3ae5f 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -5,12 +5,11 @@ # pyre-unsafe -from typing import List +from typing import Any, List import executorch.backends.arm.tosa_quant_utils as tqutils import executorch.backends.arm.tosa_utils as tutils -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -34,10 +33,13 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + # Specification (0.80) states that input and output types # should all be the same if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: @@ -58,7 +60,7 @@ def define_node( if len(inputs[0].shape) > len(inputs[1].shape) else inputs[1].dim_order ) - + scale_back = 1.0 if inputs[0].dtype == ts.DType.INT8: rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( tosa_graph, inputs, node @@ -90,7 +92,9 @@ def define_node( if output.dtype == ts.DType.INT8: # Scale output back to 8 bit # pyre-ignore - tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node) # type: ignore[possibly-undefined] + tqutils.insert_rescale_op_to_int8( + tosa_graph, add_output, scale_back, node + ) # type: ignore[possibly-undefined] @register_node_visitor @@ -107,10 +111,13 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + # Specification (0.80) states that input and output types # should all be the same if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: @@ -130,7 +137,7 @@ def define_node( f"Expected IO data type to be FP32, got {inputs[0].dtype}" ) - input1, input2 = tutils.reshape_for_broadcast(tosa_graph, inputs) + input1, input2 = inputs # MI lowering tosa_graph.addOperator( @@ -139,3 +146,122 @@ def define_node( [output.name], None, ) + + +@register_node_visitor +class AddVisitor_INT(NodeVisitor): + target = "aten.add.Tensor" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + # Specification (1.0) states that input and output types + # should all be the same + if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: + raise TypeError( + f"All IO needs to have the same data type, got input 1: " + f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: " + f"{output.dtype}" + ) + # Handle int8 (quantized) and int32 + supported_dtypes = [ts.DType.INT8, ts.DType.INT32] + if inputs[0].dtype not in supported_dtypes: + raise TypeError( + f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"' + ) + scale_back = 1.0 + if inputs[0].dtype == ts.DType.INT8: + rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node, self.tosa_specs + ) + else: + # input[0].dtype == ts.DType.INT32 + # Non quantized input, natively support by TOSA.ADD + rescaled_inputs = inputs + + if output.dtype == ts.DType.INT8: + broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) + add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) + else: + # output.dtype == ts.DType.INT32 + add_output = output + + input1, input2 = rescaled_inputs + + # Do the INT32 Add + tosa_graph.addOperator( + ts.TosaOp.Op().ADD, + [input1.name, input2.name], + [add_output.name], + None, + ) + + if output.dtype == ts.DType.INT8: + # Scale output back to 8 bit + # pyre-ignore + tqutils.insert_rescale_op_to_int8( + tosa_graph, add_output, scale_back, node, self.tosa_specs + ) # type: ignore[possibly-undefined] + + +@register_node_visitor +class AddVisitor_FP(AddVisitor_INT): + # inheriting 'target' from INT class + + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + # Specification (1.0) states that input and output types + # should all be the same + if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: + raise TypeError( + f"All IO needs to have the same data type, got input 1: " + f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: " + f"{output.dtype}" + ) + + if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: + # Call the inherited define_node for handling integers + super().define_node(node, tosa_graph, inputs, output) + else: + # FP32 Add lowering + if inputs[0].dtype != ts.DType.FP32: + raise TypeError( + f"Expected IO data type to be FP32, got {inputs[0].dtype}" + ) + + input1, input2 = inputs + + # FP lowering + tosa_graph.addOperator( + ts.TosaOp.Op().ADD, + [input1.name, input2.name], + [output.name], + None, + ) diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py index 69c8283bc8c..ebc43ca33f6 100644 --- a/backends/arm/operators/op_bmm.py +++ b/backends/arm/operators/op_bmm.py @@ -5,12 +5,10 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import List +from typing import Any, List import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, get_output_qparams, @@ -19,30 +17,42 @@ NodeVisitor, register_node_visitor, ) + from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import build_rescale +from executorch.backends.arm.tosa_quant_utils import build_rescale, build_rescale_v0_80 +from executorch.backends.arm.tosa_specification import TosaSpecification +from tosa.RoundingMode import RoundingMode # type: ignore @register_node_visitor -class BMMVisitor(NodeVisitor): +class BMMVisitor_0_80(NodeVisitor): target = "aten.bmm.default" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + def __init__(self, *args): super().__init__(*args) def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: raise TypeError( f"All IO needs to have the same data type, got: " f"{inputs[0].dtype=}, {inputs[1].dtype=} and {output.dtype=}" ) + # aten.bmm maps directly to MATMUL # NOTE: For now, only INT8 & FP32 is supported supported_dtypes = [ts.DType.INT8, ts.DType.FP32] for input in inputs: @@ -83,15 +93,102 @@ def define_node( input_qparams[0].scale * input_qparams[1].scale # type: ignore[possibly-undefined] # pyre-ignore[61] ) / output_qparams.scale - build_rescale( + build_rescale_v0_80( tosa_fb=tosa_graph, scale=[final_output_scale], # pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined. input_node=bmm_result, # type: ignore[possibly-undefined] output_name=output.name, output_type=ts.DType.INT8, - output_shape=bmm_result.shape, input_zp=0, output_zp=output_qparams.zp, is_double_round=False, ) + + +@register_node_visitor +class BMMVisitor(NodeVisitor): + target = "aten.bmm.default" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: + raise TypeError( + f"All IO needs to have the same data type, got: " + f"{inputs[0].dtype=}, {inputs[1].dtype=} and {output.dtype=}" + ) + + # aten.bmm maps directly to MATMUL + # NOTE: For now, only INT8 & FP32 is supported + supported_dtypes = [ts.DType.INT8, ts.DType.FP32] + for input in inputs: + if input.dtype not in supported_dtypes: + raise TypeError( + f'IO data type needs to be {supported_dtypes}, got "{input.dtype}"' + ) + + # aten.bmm maps directly to MATMUL + # NOTE: For now, only INT8 & FP32 is supported + + # For INT8, we need to get the zero points and add an intermediate tensor + # for a later rescale. + + if inputs[0].dtype == ts.DType.INT8: + input_qparams = get_input_qparams(node) + input0_zp = input_qparams[0].zp + input1_zp = input_qparams[1].zp + bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) + bmm_output_name = bmm_result.name + else: + bmm_output_name = output.name + input0_zp, input1_zp = 0, 0 + + tosa_graph.addConst([1], inputs[0].dtype, [input0_zp], name=f"{node.name}_A_ZP") + tosa_graph.addConst([1], inputs[1].dtype, [input1_zp], name=f"{node.name}_B_ZP") + + # Add the MATMUL to the TOSA graph. + tosa_graph.addOperator( + ts.TosaOp.Op().MATMUL, + [ + inputs[0].name, + inputs[1].name, + f"{node.name}_A_ZP", + f"{node.name}_B_ZP", + ], + [bmm_output_name], + ) + + # As INT8 accumulates into INT32, we need to rescale it back to INT8 + if output.dtype == ts.DType.INT8: + output_qparams = get_output_qparams(node)[0] + final_output_scale = ( + input_qparams[0].scale * input_qparams[1].scale # type: ignore[possibly-undefined] # pyre-ignore[61] + ) / output_qparams.scale + + build_rescale( + tosa_fb=tosa_graph, + scale=[final_output_scale], + # pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined. + input_node=bmm_result, # type: ignore[possibly-undefined] + output_name=output.name, + output_type=ts.DType.INT8, + input_zp=0, + output_zp=output_qparams.zp, + rounding_mode=RoundingMode.SINGLE_ROUND, + ) diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index 90475af1476..6f91c181bd2 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -4,12 +4,11 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import List +from typing import Any, List +import numpy as np import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, get_output_qparams, @@ -19,14 +18,20 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import build_rescale_conv_output +from executorch.backends.arm.tosa_quant_utils import build_rescale, build_rescale_v0_80 +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape @register_node_visitor -class Conv2dVisitor(NodeVisitor): +class Conv2dVisitor_0_80(NodeVisitor): target = "aten.convolution.default" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + def __init__(self, *args): super().__init__(*args) @@ -54,10 +59,13 @@ def adjust_pad_if_needed( def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + input, weight, bias, stride, pad, dilation, _, _, group = inputs # Get the attributes of convolution. @@ -170,14 +178,224 @@ def define_node( input_scale = input_qparams[0].scale # type: ignore[possibly-undefined] # pyre-ignore [61] weight_scale = input_qparams[1].scale # pyre-ignore [61] output_qargs = get_output_qparams(node) - build_rescale_conv_output( - tosa_graph, - # pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined. - conv2d_res, # type: ignore[possibly-undefined] - output.name, - output.dtype, - [input_scale], - [weight_scale], - [output_qargs[0].scale], - output_qargs[0].zp, + post_conv2d_scale = [ + (inp * w) / out + for inp, w, out in zip( + [input_scale], [weight_scale], [output_qargs[0].scale] + ) + ] + + build_rescale_v0_80( + tosa_fb=tosa_graph, + scale=post_conv2d_scale, + input_node=conv2d_res, # type: ignore[possibly-undefined] + output_name=output.name, + output_type=output.dtype, + input_zp=0, + output_zp=output_qargs[0].zp, + per_channel=isinstance(weight_scale, torch.Tensor), + ) # type: ignore[call-arg] + + +@register_node_visitor +class Conv2dVisitor(NodeVisitor): + target = "aten.convolution.default" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + # torch.nn.Conv2d does not require the result of + # `(input + 2 * pad - dilation * (weight - 1) - 1) / stride` + # to be an integer, but tosa currently strictly require this property. + # This function adjusts the pad value to meet the requirement. + def adjust_pad_if_needed( + self, input_size: int, input_weight: int, stride: int, pad: int, dilation: int + ) -> int: + mod_remainder = ( + input_size + 2 * pad - dilation * (input_weight - 1) - 1 + ) % stride + + # No need to adjust + if mod_remainder == 0: + return pad + + if mod_remainder > pad: + raise RuntimeError( + "This case should be handled by the SizeAdjustConv2d pass, is it enabled?" + ) + return pad - mod_remainder + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + from tosa.RoundingMode import RoundingMode # type: ignore + + input, weight, bias, stride, pad, dilation, _, _, group = inputs + + # Get the attributes of convolution. + attr = ts.TosaSerializerAttribute() + pad_attr = [val for val in pad.special for _ in (0, 1)] + stride_attr = stride.special + dilation_attr = dilation.special + + # Adjust the pad value if needed to meet the + # strict convolution output shape calculation. + pad_attr[1] = self.adjust_pad_if_needed( + input.shape[2], + weight.shape[2], + stride_attr[0], + pad_attr[1], + dilation_attr[0], + ) + pad_attr[3] = self.adjust_pad_if_needed( + input.shape[3], + weight.shape[3], + stride_attr[1], + pad_attr[3], + dilation_attr[1], + ) + + input_zp = 0 + if inputs[0].dtype == ts.DType.INT8: + # int8 input requires quantization information + input_qparams = get_input_qparams(node) + input_zp = input_qparams[0].zp + + tosa_graph.addConst([1], output.dtype, [input_zp], name=f"{node.name}_input_zp") + tosa_graph.addConst([1], output.dtype, [0], name=f"{node.name}_weight_zp") + acc_type = ( + inputs[0].dtype if inputs[0].dtype == ts.DType.FP32 else ts.DType.INT32 + ) + + # Non-bias case. + if len(node.all_input_nodes) == 2: + # Create a zero bias tensor if not presented + out_channels = weight.shape[0] + bias_name = "bias" + node.name.split("default", 1)[1] + bias_type = output.dtype + if output.dtype == ts.DType.INT8: + # Conv is quantized to int8, but the TOSA operator has + # output type int32, and the bias must be the same type + # as the TOSA output type + bias_type = ts.DType.INT32 + bias = tosa_graph.addConst( + [out_channels], + bias_type, + [0] * out_channels, + name=bias_name, + ) + + # The output type is int32 when input type is int8. + conv2d_output_name = output.name + if output.dtype == ts.DType.INT8: + conv2d_res = tosa_graph.addIntermediate( + tosa_shape(output.shape, output.dim_order), ts.DType.INT32 + ) + conv2d_output_name = conv2d_res.name + + # Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W) + in_channels = input.shape[1] + out_channels = weight.shape[0] + if (in_channels == group.number) and (out_channels % in_channels) == 0: + """Depthwise convolution case""" + # Reshape torch shape format of weight tensor to tosa required format. + # https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d + m_length = int(out_channels / in_channels) + weight_post_shape = [ + weight.shape[2], + weight.shape[3], + in_channels, + m_length, + ] + + weight_reshaped = tosa_graph.addIntermediate( + weight_post_shape, + weight.dtype, + ) + shape = tosa_graph.addConst( + np.array(weight_post_shape).shape, + ts.DType.SHAPE, + np.array(weight_post_shape), + name=weight_reshaped.name + "_shape", + ) + + attr = ts.TosaSerializerAttribute() + attr.ReshapeAttribute() + tosa_graph.addOperator( + ts.TosaOp.Op().RESHAPE, + [weight.name, shape.name], + [weight_reshaped.name], + attr, + ) + + tosa_op = ts.TosaOp.Op().DEPTHWISE_CONV2D + weight_name = weight_reshaped.name + + attr.DepthwiseConv2dAttribute( + pad=pad_attr, + stride=stride_attr, + dilation=dilation_attr, + local_bound=False, + acc_type=acc_type, + ) + else: + """Regular convolution case""" + tosa_op = ts.TosaOp.Op().CONV2D + weight_name = weight.name + + attr.Conv2dAttribute( + pad=pad_attr, + stride=stride_attr, + dilation=dilation_attr, + local_bound=False, + acc_type=acc_type, + ) + + tosa_graph.addOperator( + tosa_op, + [ + input.name, + weight_name, + bias.name, + f"{node.name}_input_zp", + f"{node.name}_weight_zp", + ], + [conv2d_output_name], + attr, + ) + + # For quantized convolution, rescale the output value back to the same + # integer value domain of the next op. Otherwise return float32 output. + if inputs[0].dtype == ts.DType.INT8: + # Get scale_factor from input, weight, and output. + input_scale = input_qparams[0].scale # type: ignore[possibly-undefined] # pyre-ignore [61] + weight_scale = input_qparams[1].scale # pyre-ignore [61] + output_qargs = get_output_qparams(node) + post_conv2d_scale = [ + (inp * w) / out + for inp, w, out in zip( + [input_scale], [weight_scale], [output_qargs[0].scale] + ) + ] + build_rescale( + tosa_fb=tosa_graph, + scale=post_conv2d_scale, + input_node=conv2d_res, # type: ignore[possibly-undefined] + output_name=output.name, + output_type=output.dtype, + input_zp=0, + output_zp=output_qargs[0].zp, + per_channel=isinstance(weight_scale, torch.Tensor), + rounding_mode=RoundingMode.SINGLE_ROUND, ) diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_eq.py index 7f87fb5a81d..a36a1f1b0cd 100644 --- a/backends/arm/operators/op_eq.py +++ b/backends/arm/operators/op_eq.py @@ -5,34 +5,42 @@ # pyre-unsafe -from typing import List +from typing import Any, List import executorch.backends.arm.tosa_quant_utils as tqutils -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @register_node_visitor -class EqualVisitor(NodeVisitor): +class EqualVisitor_0_80(NodeVisitor): target = "aten.eq.Tensor" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + def __init__(self, *args): super().__init__(*args) def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator EQ but got " @@ -57,3 +65,51 @@ def define_node( output.name, None, ) + + +@register_node_visitor +class EqualVisitor(NodeVisitor): + target = "aten.eq.Tensor" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + if inputs[0].dtype != inputs[1].dtype: + raise TypeError( + "All inputs need to have the same data type for operator EQ but got " + f"{inputs[0].dtype=}, {inputs[1].dtype=}" + ) + + input_nodes = inputs + # Handle quantization + if inputs[0].dtype == ts.DType.INT8: + # Rescale inputs to 32 bit + rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node, self.tosa_specs + ) + + # Update IO + input_nodes = rescaled_inputs + + # Do the equal comparison + tosa_graph.addOperator( + ts.TosaOp.Op().EQUAL, + [input_nodes[0].name, input_nodes[1].name], + output.name, + None, + ) diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_ge.py index b2193a2e7ed..c929f5f9c87 100644 --- a/backends/arm/operators/op_ge.py +++ b/backends/arm/operators/op_ge.py @@ -5,34 +5,42 @@ # pyre-unsafe -from typing import List +from typing import Any, List import executorch.backends.arm.tosa_quant_utils as tqutils -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @register_node_visitor -class GreaterEqualVisitor(NodeVisitor): +class GreaterEqualVisitor_0_80(NodeVisitor): target = "aten.ge.Tensor" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + def __init__(self, *args): super().__init__(*args) def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator GE but got " @@ -56,3 +64,50 @@ def define_node( [output.name], None, ) + + +@register_node_visitor +class GreaterEqualVisitor(NodeVisitor): + target = "aten.ge.Tensor" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + if inputs[0].dtype != inputs[1].dtype: + raise TypeError( + "All inputs need to have the same data type for operator GE but got " + f"{inputs[0].dtype=}, {inputs[1].dtype=}" + ) + + input_nodes = inputs + # Handle quantization + if inputs[0].dtype == ts.DType.INT8: + # Rescale inputs to 32 bit + rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node, self.tosa_specs + ) + + # Update IO + input_nodes = rescaled_inputs + + tosa_graph.addOperator( + ts.TosaOp.Op().GREATER_EQUAL, + [input_nodes[0].name, input_nodes[1].name], + [output.name], + None, + ) diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_gt.py index 06f29e4505c..53196a0d03c 100644 --- a/backends/arm/operators/op_gt.py +++ b/backends/arm/operators/op_gt.py @@ -5,34 +5,42 @@ # pyre-unsafe -from typing import List +from typing import Any, List import executorch.backends.arm.tosa_quant_utils as tqutils -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @register_node_visitor -class GreaterThanVisitor(NodeVisitor): +class GreaterThanVisitor_0_80(NodeVisitor): target = "aten.gt.Tensor" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + def __init__(self, *args): super().__init__(*args) def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator GT but got " @@ -56,3 +64,50 @@ def define_node( [output.name], None, ) + + +@register_node_visitor +class GreaterThanVisitor(NodeVisitor): + target = "aten.gt.Tensor" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + if inputs[0].dtype != inputs[1].dtype: + raise TypeError( + "All inputs need to have the same data type for operator GT but got " + f"{inputs[0].dtype=}, {inputs[1].dtype=}" + ) + + input_nodes = inputs + # Handle quantization + if inputs[0].dtype == ts.DType.INT8: + # Rescale inputs to 32 bit + rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node, self.tosa_specs + ) + + # Update IO + input_nodes = rescaled_inputs + + tosa_graph.addOperator( + ts.TosaOp.Op().GREATER, + [input_nodes[0].name, input_nodes[1].name], + [output.name], + None, + ) diff --git a/backends/arm/operators/op_le.py b/backends/arm/operators/op_le.py index fadf4848359..d927c1ba0db 100644 --- a/backends/arm/operators/op_le.py +++ b/backends/arm/operators/op_le.py @@ -5,34 +5,42 @@ # pyre-unsafe -from typing import List +from typing import Any, List import executorch.backends.arm.tosa_quant_utils as tqutils -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @register_node_visitor -class LessEqualVisitor(NodeVisitor): +class LessEqualVisitor_0_80(NodeVisitor): target = "aten.le.Tensor" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + def __init__(self, *args): super().__init__(*args) def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator LE but got " @@ -56,3 +64,50 @@ def define_node( [output.name], None, ) + + +@register_node_visitor +class LessEqualVisitor(NodeVisitor): + target = "aten.le.Tensor" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + if inputs[0].dtype != inputs[1].dtype: + raise TypeError( + "All inputs need to have the same data type for operator LE but got " + f"{inputs[0].dtype=}, {inputs[1].dtype=}" + ) + + input_nodes = inputs + # Handle quantization + if inputs[0].dtype == ts.DType.INT8: + # Rescale inputs to 32 bit + rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node, self.tosa_specs + ) + + # Update IO + input_nodes = rescaled_inputs + + tosa_graph.addOperator( + ts.TosaOp.Op().GREATER_EQUAL, + [input_nodes[1].name, input_nodes[0].name], + [output.name], + None, + ) diff --git a/backends/arm/operators/op_lt.py b/backends/arm/operators/op_lt.py index a261cd2db9f..2e49eda7d98 100644 --- a/backends/arm/operators/op_lt.py +++ b/backends/arm/operators/op_lt.py @@ -5,34 +5,42 @@ # pyre-unsafe -from typing import List +from typing import Any, List import executorch.backends.arm.tosa_quant_utils as tqutils -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @register_node_visitor -class LessThanVisitor(NodeVisitor): +class LessThanVisitor_0_80(NodeVisitor): target = "aten.lt.Tensor" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + def __init__(self, *args): super().__init__(*args) def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator LT but got " @@ -56,3 +64,50 @@ def define_node( [output.name], None, ) + + +@register_node_visitor +class LessThanVisitor(NodeVisitor): + target = "aten.lt.Tensor" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + if inputs[0].dtype != inputs[1].dtype: + raise TypeError( + "All inputs need to have the same data type for operator LT but got " + f"{inputs[0].dtype=}, {inputs[1].dtype=}" + ) + + input_nodes = inputs + # Handle quantization + if inputs[0].dtype == ts.DType.INT8: + # Rescale inputs to 32 bit + rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node, self.tosa_specs + ) + + # Update IO + input_nodes = rescaled_inputs + + tosa_graph.addOperator( + ts.TosaOp.Op().GREATER, + [input_nodes[1].name, input_nodes[0].name], + [output.name], + None, + ) diff --git a/backends/arm/operators/op_maximum.py b/backends/arm/operators/op_maximum.py index ee52e5276cd..983ac5ded6d 100644 --- a/backends/arm/operators/op_maximum.py +++ b/backends/arm/operators/op_maximum.py @@ -5,37 +5,47 @@ # pyre-unsafe -from typing import List +from typing import Any, List import executorch.backends.arm.tosa_quant_utils as tqutils -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, ) + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) + from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import tosa_shape from torch.fx import Node @register_node_visitor -class MaxVisitor(NodeVisitor): +class MaxVisitor_0_80(NodeVisitor): target = "aten.maximum.default" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + def __init__(self, *args): super().__init__(*args) def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + if inputs[0].dtype != inputs[1].dtype and inputs[0].dtype != output.dtype: raise TypeError( f"Data type of inputs and output must be the same. Got input 0 dtype: " @@ -78,3 +88,77 @@ def define_node( if output.dtype == ts.DType.INT8: # insert RESCALE from int32 back to int8 tqutils.insert_rescale_op_to_int8(tosa_graph, max_output, scale_back, node) + + +@register_node_visitor +class MaxVisitor(NodeVisitor): + target = "aten.maximum.default" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + from tosa.NanPropagationMode import NanPropagationMode # type: ignore + + if inputs[0].dtype != inputs[1].dtype and inputs[0].dtype != output.dtype: + raise TypeError( + f"Data type of inputs and output must be the same. Got input 0 dtype: " + f"{inputs[0].dtype}, input 1 dtype: {inputs[1].dtype} and output " + f"dtype: {output.dtype}" + ) + + scale_back = 1.0 + max_output = output + if inputs[0].dtype == ts.DType.INT8: + input_qparams = get_input_qparams(node) + if len(input_qparams) != 2: + raise ValueError( + f"Both inputs need to have quantization information for {node}" + ) + if input_qparams[0] != input_qparams[1]: + raise ValueError( + "Both inputs must have the same quantization parameters for MAX" + ) + + operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node, self.tosa_specs + ) + + output.shape = tosa_shape(output.shape, output.dim_order) + max_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) + else: + operand_inputs = inputs + + attr_maximum = ts.TosaSerializerAttribute() + + # Set to PROPOGATE as default + attr_maximum.MaximumAttribute(nan_mode=NanPropagationMode.PROPAGATE) + + tosa_graph.addOperator( + ts.TosaOp.Op().MAXIMUM, + [ + operand_inputs[0].name, + operand_inputs[1].name, + ], + [max_output.name], + attr_maximum, + ) + + if output.dtype == ts.DType.INT8: + # insert RESCALE from int32 back to int8 + tqutils.insert_rescale_op_to_int8( + tosa_graph, max_output, scale_back, node, self.tosa_specs + ) diff --git a/backends/arm/operators/op_minimum.py b/backends/arm/operators/op_minimum.py index 88cb8d376fe..f39e2ce6d61 100644 --- a/backends/arm/operators/op_minimum.py +++ b/backends/arm/operators/op_minimum.py @@ -5,12 +5,10 @@ # pyre-unsafe -from typing import List +from typing import Any, List import executorch.backends.arm.tosa_quant_utils as tqutils -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, ) @@ -19,24 +17,33 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import tosa_shape from torch.fx import Node @register_node_visitor -class MinVisitor(NodeVisitor): +class MinVisitor_0_80(NodeVisitor): target = "aten.minimum.default" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + def __init__(self, *args): super().__init__(*args) def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + if inputs[0].dtype != inputs[1].dtype and inputs[0].dtype != output.dtype: raise TypeError( f"Data type of inputs and output must be the same. Got input 0 dtype: " @@ -79,3 +86,77 @@ def define_node( if output.dtype == ts.DType.INT8: # insert RESCALE from int32 back to int8 tqutils.insert_rescale_op_to_int8(tosa_graph, min_output, scale_back, node) + + +@register_node_visitor +class MinVisitor(NodeVisitor): + target = "aten.minimum.default" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + from tosa.NanPropagationMode import NanPropagationMode # type: ignore + + if inputs[0].dtype != inputs[1].dtype and inputs[0].dtype != output.dtype: + raise TypeError( + f"Data type of inputs and output must be the same. Got input 0 dtype: " + f"{inputs[0].dtype}, input 1 dtype: {inputs[1].dtype} and output " + f"dtype: {output.dtype}" + ) + + scale_back = 1.0 + min_output = output + if inputs[0].dtype == ts.DType.INT8: + input_qparams = get_input_qparams(node) + if len(input_qparams) != 2: + raise ValueError( + f"Both inputs need to have quantization information for {node}" + ) + if input_qparams[0] != input_qparams[1]: + raise ValueError( + "Both inputs must have the same quantization parameters for MIN" + ) + + operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node, self.tosa_specs + ) + + output.shape = tosa_shape(output.shape, output.dim_order) + min_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) + else: + operand_inputs = inputs + + attr_minimum = ts.TosaSerializerAttribute() + + # Set to PROPOGATE as default + attr_minimum.MinimumAttribute(nan_mode=NanPropagationMode.PROPAGATE) + + tosa_graph.addOperator( + ts.TosaOp.Op().MINIMUM, + [ + operand_inputs[0].name, + operand_inputs[1].name, + ], + [min_output.name], + attr_minimum, + ) + + if output.dtype == ts.DType.INT8: + # insert RESCALE from int32 back to int8 + tqutils.insert_rescale_op_to_int8( + tosa_graph, min_output, scale_back, node, self.tosa_specs + ) diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index 45dcb974ea4..6c5b94f1a2b 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -5,14 +5,12 @@ # pyre-unsafe -from typing import List +from typing import Any, List import executorch.backends.arm.tosa_quant_utils as tqutils import executorch.backends.arm.tosa_utils as tutils import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, ) @@ -37,10 +35,13 @@ class MulVisitor_080_BI(NodeVisitor): def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + if ( inputs[0].dtype != ts.DType.INT8 or inputs[1].dtype != ts.DType.INT8 @@ -114,10 +115,13 @@ class MulVisitor_080_MI(MulVisitor_080_BI): def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + if inputs[0].dtype == ts.DType.INT8: return super().define_node(node, tosa_graph, inputs, output) @@ -128,3 +132,100 @@ def define_node( tosa_graph.addOperator( ts.TosaOp.Op().MUL, [input1.name, input2.name], [output.name], attr ) + + +@register_node_visitor +class MulVisitor_INT(NodeVisitor): + target = "aten.mul.Tensor" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + ] + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + if ( + inputs[0].dtype != ts.DType.INT8 + or inputs[1].dtype != ts.DType.INT8 + or output.dtype != ts.DType.INT8 + ): + raise ValueError( + f"Inputs and output for {self.target} need to be INT8, got " + f"{inputs[0].dtype=}, {inputs[1].dtype=} and {output.dtype=}" + ) + + input_A = inputs[0] + input_B = inputs[1] + input_qparams = get_input_qparams(node) + input_A_qargs = input_qparams[0] + input_B_qargs = input_qparams[1] + input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order) + input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order) + + # Rescale inputs to INT32 with zp=0 + input_A_rescaled = tqutils.build_rescale_to_int32( + tosa_graph, + input_A, + input_A_qargs.zp, + [1.0], + tosa_spec=self.tosa_specs, + ) + input_B_rescaled = tqutils.build_rescale_to_int32( + tosa_graph, + input_B, + input_B_qargs.zp, + [1.0], + tosa_spec=self.tosa_specs, + ) + + output_shape = tutils.tosa_shape(output.shape, output.dim_order) + mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32) + + # Do the INT32 Mul + tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift") + tosa_graph.addOperator( + ts.TosaOp.Op().MUL, + [input_A_rescaled.name, input_B_rescaled.name, f"{node.name}_shift"], + [mul_output.name], + ) + output_scale = input_A_qargs.scale * input_B_qargs.scale + tqutils.insert_rescale_op_to_int8( + tosa_graph, mul_output, output_scale, node, self.tosa_specs + ) + + +@register_node_visitor +class MulVisitor_FP(MulVisitor_INT): + # inheriting 'target' from INT class + + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + if inputs[0].dtype == ts.DType.INT8: + return super().define_node(node, tosa_graph, inputs, output) + + input1, input2 = inputs + + tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift") + tosa_graph.addOperator( + ts.TosaOp.Op().MUL, + [input1.name, input2.name, f"{node.name}_shift"], + [output.name], + ) diff --git a/backends/arm/operators/op_rescale.py b/backends/arm/operators/op_rescale.py index c59015dcc14..33bf6b8fb69 100644 --- a/backends/arm/operators/op_rescale.py +++ b/backends/arm/operators/op_rescale.py @@ -5,32 +5,35 @@ # pyre-unsafe -from typing import cast, List +from typing import Any, cast, List import executorch.backends.arm.tosa_quant_utils as tosa_quant_utils import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - -import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg +from executorch.backends.arm.tosa_quant_utils import create_const_ops_for_rescale + +from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @register_node_visitor -class RescaleVisitor(NodeVisitor): +class RescaleVisitor_0_80(NodeVisitor): target = "_rescale.default" + tosa_specs = NodeVisitor.tosa_specs_0_80 + def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore input_dtype = inputs[0].dtype output_dtype = cast(torch.dtype, node.args[1]) @@ -68,5 +71,73 @@ def define_node( ) tosa_graph.addOperator( - TosaOp.Op().RESCALE, [inputs[0].name], [output.name], attr_rescale + ts.TosaOp.Op().RESCALE, [inputs[0].name], [output.name], attr_rescale + ) + + +@register_node_visitor +class RescaleVisitor_INT(NodeVisitor): + target = "_rescale.default" + + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+INT")] + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts # type: ignore + from tosa.RoundingMode import RoundingMode # type: ignore + + input_dtype = inputs[0].dtype + output_dtype = cast(torch.dtype, node.args[1]) + scale = cast(float, node.args[2]) + input_zp = cast(int, node.args[3]) + output_zp = cast(int, node.args[4]) + + if input_dtype != map_dtype(torch.int8) and input_zp != 0: + raise ValueError( + f"If input dtype is not int8, input_zp must be 0. Got input_dtype{ts.DTypeNames[input_dtype]}, {input_zp=}" + ) + if output_dtype != torch.int8 and output_zp != 0: + raise ValueError( + f"If output dtype is not int8, output_zp must be 0. Got {output_dtype=}, {output_zp=}" + ) + + # scale32 gives higher accuracy but for a higher HW cost. + # For now, always go for scale32. + scale_32 = True + scale_width = 32 if scale_32 else 16 + multipliers, shifts = tosa_quant_utils.compute_multiplier_and_shift( + [scale], scale_width + ) + + rescale_inputs = create_const_ops_for_rescale( + tosa_graph, + input_dtype, + inputs[0].name, + multipliers, + shifts, + input_zp, + output_zp, + ts, + ) + + attr_rescale = ts.TosaSerializerAttribute() + + attr_rescale.RescaleAttribute( + scale32=scale_32, + rounding_mode=RoundingMode.SINGLE_ROUND, + per_channel=False, + input_unsigned=False, + output_unsigned=False, + ) + + tosa_graph.addOperator( + ts.TosaOp.Op().RESCALE, + [inputs[0].name, *rescale_inputs], + [output.name], + attr_rescale, ) diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index ef9ed31c88d..65126f4d4dc 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -5,12 +5,11 @@ # pyre-unsafe -from typing import List +from typing import Any, List import executorch.backends.arm.tosa_quant_utils as tqutils import executorch.backends.arm.tosa_utils as tutils -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -34,10 +33,13 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + # Specification (0.80) states that input and output types # should all be the same if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: @@ -54,6 +56,7 @@ def define_node( f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"' ) + scale_back = 1.0 if inputs[0].dtype == ts.DType.INT8: rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( tosa_graph, inputs, node @@ -84,7 +87,9 @@ def define_node( if output.dtype == ts.DType.INT8: # Scale output back to 8 bit # pyre-ignore - tqutils.insert_rescale_op_to_int8(tosa_graph, sub_output, scale_back, node) # type: ignore[possibly-undefined] + tqutils.insert_rescale_op_to_int8( + tosa_graph, sub_output, scale_back, node + ) # type: ignore[possibly-undefined] @register_node_visitor @@ -101,10 +106,13 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + # Specification (0.80) states that input and output types # should all be the same if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: @@ -136,3 +144,106 @@ def define_node( [output.name], None, ) + + +@register_node_visitor +class SubVisitor_INT(NodeVisitor): + target = "aten.sub.Tensor" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + # Specification (1.0) states that input and output types + # should all be the same + assert inputs[0].dtype == inputs[1].dtype == output.dtype + # Handle int8 (quantized) and int32 + assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32] + + scale_back = 1.0 + if inputs[0].dtype == ts.DType.INT8: + rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node, self.tosa_specs + ) + else: + # input[0].dtype == ts.DType.INT32 + # Non quantized input, natively support by TOSA.SUB + rescaled_inputs = inputs + + if output.dtype == ts.DType.INT8: + broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) + sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) + else: + # output.dtype == ts.DType.INT32 + sub_output = output + + # Do the INT32 Sub + tosa_graph.addOperator( + ts.TosaOp.Op().SUB, + [ + rescaled_inputs[0].name, + rescaled_inputs[1].name, + ], + [sub_output.name], + None, + ) + + if output.dtype == ts.DType.INT8: + # Scale output back to 8 bit + # pyre-ignore + tqutils.insert_rescale_op_to_int8( + tosa_graph, sub_output, scale_back, node, self.tosa_specs + ) # type: ignore[possibly-undefined] + + +@register_node_visitor +class SubVisitor_FP(SubVisitor_INT): + # inheriting 'target' from INT class + + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + # Specification (1.0) states that input and output types + # should all be the same + assert inputs[0].dtype == inputs[1].dtype == output.dtype + + if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: + # Call the inherited define_node for handling integers + super().define_node(node, tosa_graph, inputs, output) + else: + # FP32 Sub lowering + assert inputs[0].dtype == ts.DType.FP32 + assert output.dtype == ts.DType.FP32 + + # MI lowering + tosa_graph.addOperator( + ts.TosaOp.Op().SUB, + [inputs[0].name, inputs[1].name], + [output.name], + None, + ) diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py index 135566e48ac..b898eb6cb67 100644 --- a/backends/arm/operators/op_sum.py +++ b/backends/arm/operators/op_sum.py @@ -5,12 +5,11 @@ # pyre-unsafe -from typing import cast, List +from typing import Any, cast, List import executorch.backends.arm.tosa_quant_utils as tqutils import executorch.backends.arm.tosa_utils as tutils -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -34,10 +33,13 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + input_shape = list(inputs[0].shape) dim_list = cast(list[int], inputs[1].special) dim_list = [dim % len(input_shape) for dim in dim_list] @@ -89,10 +91,13 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + if inputs[0].dtype == ts.DType.INT8: return super().define_node(node, tosa_graph, inputs, output) input_name = inputs[0].name @@ -123,3 +128,115 @@ def define_node( ) input_name = output_name + + +@register_node_visitor +class SumVisitor_INT(NodeVisitor): + target = "aten.sum.dim_IntList" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + input_shape = list(inputs[0].shape) + dim_list = cast(list[int], inputs[1].special) + dim_list = [dim % len(input_shape) for dim in dim_list] + keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False) + assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass" + + # Rescale input to 32 bit + rescaled_inputs, scale = tqutils.insert_rescale_ops_to_int32( + tosa_graph, + [inputs[0]], + node, + self.tosa_specs, + ) + + prev_node = rescaled_inputs[0] + reduced_shape = input_shape + + # Reduce all dims in dim_list one-by-one. + for dim in dim_list: + # When reduced, the size of the dim becomes 1. + reduced_shape[dim] = 1 + + attr = ts.TosaSerializerAttribute() + attr.ReduceSumAttribute(inputs[0].dim_order.index(dim)) + + next_node = tosa_graph.addIntermediate( + tutils.tosa_shape(reduced_shape, inputs[0].dim_order), + dtype=ts.DType.INT32, + ) + + tosa_graph.addOperator( + ts.TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr + ) + + prev_node = next_node + tqutils.insert_rescale_op_to_int8( + tosa_graph, prev_node, scale, node, self.tosa_specs + ) + + +@register_node_visitor +class SumVisitor_FP(SumVisitor_INT): + # inheriting 'target' from INT class + + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + if inputs[0].dtype == ts.DType.INT8: + return super().define_node(node, tosa_graph, inputs, output) + input_name = inputs[0].name + reduced_shape = list(inputs[0].shape) + dim_list = cast(list[int], inputs[1].special) + dim_list = [dim % len(reduced_shape) for dim in dim_list] + keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False) + assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass" + + # Reduce all dims in dim_list one-by-one. + for dim in dim_list: + # When reduced, the size of the dim becomes 1 + reduced_shape[dim] = 1 + + attr = ts.TosaSerializerAttribute() + attr.ReduceSumAttribute(inputs[0].dim_order.index(dim)) + + if dim == dim_list[-1]: + output_name = output.name + else: + output_name = tosa_graph.addIntermediate( + tutils.tosa_shape(reduced_shape, inputs[0].dim_order), + dtype=ts.DType.FP32, + ).name + + tosa_graph.addOperator( + ts.TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr + ) + + input_name = output_name diff --git a/backends/arm/operators/op_upsample_bilinear2d.py b/backends/arm/operators/op_upsample_bilinear2d.py index 3a238709223..1c0fbc11d24 100644 --- a/backends/arm/operators/op_upsample_bilinear2d.py +++ b/backends/arm/operators/op_upsample_bilinear2d.py @@ -4,25 +4,23 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import List +from typing import Any, List import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import build_rescale +from executorch.backends.arm.tosa_quant_utils import build_rescale, build_rescale_v0_80 from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape -from tosa_tools.v0_80.tosa.ResizeMode import ResizeMode # type: ignore @register_node_visitor class UpsampleBilinear2dVisitor_0_80(NodeVisitor): target = "aten.upsample_bilinear2d.vec" + tosa_specs = NodeVisitor.tosa_specs_0_80 def __init__(self, *args): super().__init__(*args) @@ -30,10 +28,14 @@ def __init__(self, *args): def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + from tosa_tools.v0_80.tosa.ResizeMode import ResizeMode # type: ignore + if inputs[0].shape is None or output.shape is None: raise ValueError("Only static shapes are supported") @@ -85,13 +87,12 @@ def in_int16_range(x): final_output_scale = float(1 / (scale_n_yx[0] * scale_n_yx[1])) - build_rescale( + build_rescale_v0_80( tosa_fb=tosa_graph, scale=[final_output_scale], input_node=intermediate, output_name=output.name, output_type=ts.DType.INT8, - output_shape=output.shape, input_zp=0, output_zp=0, is_double_round=False, @@ -100,3 +101,110 @@ def in_int16_range(x): raise ValueError( "Input/output dtype not in {float32, int8}: {input_dtype=} {output.dtype=}" ) + + +@register_node_visitor +class UpsampleBilinear2dVisitor(NodeVisitor): + + target = "aten.upsample_bilinear2d.vec" + tosa_specs = NodeVisitor.tosa_specs_1_00 + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts + from tosa.ResizeMode import ResizeMode # type: ignore + from tosa.RoundingMode import RoundingMode # type: ignore + + if inputs[0].shape is None or output.shape is None: + raise ValueError("Only static shapes are supported") + + input_dtype = inputs[0].dtype + + # tosa_shape output is NHWC, take HW + input_size_yx = torch.tensor( + tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3] + ) + # Ignore scale and size parameters, directly use the output size as + # we only support static shapes currently + output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3]) + + scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters( + input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True + ) + + def in_int16_range(x): + return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1) + + assert in_int16_range(scale_n_yx) + assert in_int16_range(scale_d_yx) + assert in_int16_range(border_yx) + + scales = [scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]] + + attr = ts.TosaSerializerAttribute() + attr.ResizeAttribute(mode=ResizeMode.BILINEAR) + + scales_tensor = tosa_graph.addConst( + [len(scales)], ts.DType.SHAPE, scales, node.name + "_scales" + ) + offset = offset_yx.tolist() + offset_tensor = tosa_graph.addConst( + [len(offset)], ts.DType.SHAPE, offset, node.name + "_offset" + ) + border = border_yx.tolist() + border_tensor = tosa_graph.addConst( + [len(border)], ts.DType.SHAPE, border, node.name + "_border" + ) + if input_dtype == output.dtype == ts.DType.FP32: + tosa_graph.addOperator( + ts.TosaOp.Op().RESIZE, + [ + inputs[0].name, + scales_tensor.name, + offset_tensor.name, + border_tensor.name, + ], + [output.name], + attr, + ) + return + elif input_dtype == output.dtype == ts.DType.INT8: + intermediate = tosa_graph.addIntermediate( + tosa_shape(output.shape, output.dim_order), ts.DType.INT32 + ) + tosa_graph.addOperator( + ts.TosaOp.Op().RESIZE, + [ + inputs[0].name, + scales_tensor.name, + offset_tensor.name, + border_tensor.name, + ], + [intermediate.name], + attr, + ) + + final_output_scale = float(1 / (scale_n_yx[0] * scale_n_yx[1])) + + build_rescale( + tosa_fb=tosa_graph, + scale=[final_output_scale], + input_node=intermediate, + output_name=output.name, + output_type=ts.DType.INT8, + input_zp=0, + output_zp=0, + rounding_mode=RoundingMode.SINGLE_ROUND, + ) + else: + raise ValueError( + "Input/output dtype not in {float32, int8}: {input_dtype=} {output.dtype=}" + ) diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index afdbf78422e..de73c194a39 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -8,17 +8,19 @@ # Utiliy functions for TOSA quantized lowerings import math -from typing import cast, List, NamedTuple, Tuple + +from typing import Any, cast, NamedTuple, Tuple + +import executorch.backends.arm.tosa_specification as tosa_specification import torch.fx import torch.fx.node -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore -import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore from executorch.backends.arm.tosa_mapping import TosaArg from executorch.exir.dialects._ops import ops as exir_ops from torch import Tensor from torch.fx import Node +from tosa.RoundingMode import RoundingMode # type: ignore q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default @@ -27,8 +29,11 @@ def insert_rescale_ops_to_int32( - tosa_graph: ts.TosaSerializer, inputs: list[TosaArg], node: Node -) -> tuple[list[ts.TosaSerializerTensor], float]: + tosa_graph: Any, + inputs: list[TosaArg], + node: Node, + tosa_spec=tosa_specification.Tosa_0_80, +) -> tuple[list[Any], float]: """Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'. The scales are adjusted using the smallest scale of all 'nodes'. @@ -59,24 +64,22 @@ def insert_rescale_ops_to_int32( min_scale = min([qarg.scale for qarg in qargs]) scales = [qarg.scale / min_scale for qarg in qargs] - rescaled_nodes: list[ts.TosaSerializerTensor] = [] + rescaled_nodes: list[Any] = [] for tensor, qarg, scale in zip(tensors, qargs, scales): rescaled_nodes.append( build_rescale_to_int32( - tosa_graph, - tensor, - qarg.zp, - [scale], + tosa_graph, tensor, qarg.zp, [scale], tosa_spec=tosa_spec ) ) return rescaled_nodes, min_scale def insert_rescale_op_to_int8( - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, last_tensor: TosaArg, scale: float, node: Node, + tosa_spec=tosa_specification.Tosa_0_80, ) -> None: """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'. Parameters: @@ -102,10 +105,11 @@ def insert_rescale_op_to_int8( # Rescale Back to INT8 build_rescale_from_int32( tosa_graph, - last_tensor.name, + last_tensor, node.name, qargs_out.zp, [output_rescale_scale], + tosa_spec=tosa_spec, ) @@ -143,11 +147,6 @@ def from_operator(cls, op, args): raise NotImplementedError -# Check if scale32 mode is used for given output element type -def is_scale32(type: int) -> ts.DType: - return type == ts.DType.INT8 - - # TOSA uses the RESCALE operation to scale between values with differing precision. # The RESCALE operator is defined using an integer multiply, add, and shift. # This utility function is for calculating the multier and shift given a scale. @@ -195,19 +194,23 @@ def compute_multiplier_and_shift( return multipliers, shifts -def build_rescale( - tosa_fb: ts.TosaSerializer, +def build_rescale_v0_80( + tosa_fb: Any, scale: list[float], - input_node: ts.TosaSerializerTensor, + input_node: Any, output_name: str, - output_type: ts.DType, - output_shape: List[int], + output_type: Any, input_zp: int, output_zp: int, is_double_round: bool = False, per_channel=False, ): - scale_width = 32 if is_scale32(output_type) else 16 + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore + + # Check if scale32 mode is used for given output element type + is_scale32 = output_type == ts.DType.INT8 + scale_width = 32 if is_scale32 else 16 multipliers, shifts = compute_multiplier_and_shift(scale, scale_width) attr_rescale = ts.TosaSerializerAttribute() @@ -216,7 +219,7 @@ def build_rescale( output_zp=output_zp, multiplier=multipliers, shift=shifts, - scale32=is_scale32(output_type), + scale32=is_scale32, double_round=is_double_round, per_channel=per_channel, input_unsigned=False, @@ -230,67 +233,168 @@ def build_rescale( return -def build_rescale_to_int32( - tosa_fb: ts.TosaSerializer, - input_arg: TosaArg, +# For TOSA spec v1.0 RESCALE operator requires multipler, shifts, input_zp and output_zp to be +# const inputs. Create constant operators from the data already initialized. +def create_const_ops_for_rescale( + tosa_fb, input_dtype, input_name, multipliers, shifts, input_zp, output_zp, ts +): + output_dtype = ts.DType.INT32 if input_dtype == ts.DType.INT8 else ts.DType.INT8 + + multipliers = tosa_fb.addConst( + (len(multipliers),), + ts.DType.INT32, + multipliers, + name=input_name + "_multipliers", + ) + shifts = tosa_fb.addConst( + (len(shifts),), ts.DType.INT8, shifts, name=input_name + "_shifts" + ) + input_zp = tosa_fb.addConst( + [1], input_dtype, [input_zp], name=input_name + "_input_zp" + ) + output_zp = tosa_fb.addConst( + [1], output_dtype, [output_zp], name=input_name + "_output_zp" + ) + + return [multipliers.name, shifts.name, input_zp.name, output_zp.name] + + +def build_rescale( + tosa_fb: Any, + scale: list[float], + input_node: Any, + output_name: str, + output_type: Any, input_zp: int, - rescale_scale: list[float], - is_scale32: bool = True, - is_double_round: bool = False, - per_channel: bool = False, -) -> ts.TosaSerializerTensor: - multipliers, shifts = compute_multiplier_and_shift(rescale_scale) + output_zp: int, + rounding_mode: RoundingMode, + per_channel=False, +): + import serializer.tosa_serializer as ts # type: ignore + import tosa.Op as TosaOp # type: ignore + + input_name = input_node.name + + multipliers, shifts = compute_multiplier_and_shift(scale, 32) + rescale_inputs = create_const_ops_for_rescale( + tosa_fb, + input_node.dtype, + input_name, + multipliers, + shifts, + input_zp, + output_zp, + ts, + ) attr_rescale = ts.TosaSerializerAttribute() attr_rescale.RescaleAttribute( - input_zp=input_zp, - output_zp=0, - multiplier=multipliers, - shift=shifts, - scale32=is_scale32, - double_round=is_double_round, + scale32=True, + rounding_mode=rounding_mode, per_channel=per_channel, input_unsigned=False, output_unsigned=False, ) - input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input_arg.shape, ts.DType.INT32) + tosa_fb.addOperator( TosaOp.Op().RESCALE, - [input_arg.name], - [input_A_rescaled_to_int32.name], + [input_node.name, *rescale_inputs], + [output_name], attr_rescale, ) + return + + +def build_rescale_to_int32( + tosa_fb: Any, + input_arg: TosaArg, + input_zp: int, + rescale_scale: list[float], + is_scale32: bool = True, + is_double_round: bool = False, + per_channel: bool = False, + tosa_spec=tosa_specification.Tosa_0_80, +) -> Any: + input_A_rescaled_to_int32 = None + if tosa_spec == tosa_specification.Tosa_0_80: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + input_A_rescaled_to_int32 = tosa_fb.addIntermediate( + input_arg.shape, ts.DType.INT32 + ) + + build_rescale_v0_80( + tosa_fb=tosa_fb, + scale=rescale_scale, + input_node=input_arg, + output_name=input_A_rescaled_to_int32.name, + output_type=ts.DType.INT32, + input_zp=input_zp, + output_zp=0, + ) # type: ignore[call-arg] + + elif isinstance(tosa_spec[0], tosa_specification.Tosa_1_00): + # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs + # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale + import serializer.tosa_serializer as ts # type: ignore + + input_A_rescaled_to_int32 = tosa_fb.addIntermediate( + input_arg.shape, ts.DType.INT32 + ) + + build_rescale( + tosa_fb, + rescale_scale, + input_arg, + input_A_rescaled_to_int32.name, + ts.DType.INT32, + input_zp, + 0, + rounding_mode=RoundingMode.SINGLE_ROUND, + ) # type: ignore[call-arg] + return input_A_rescaled_to_int32 def build_rescale_from_int32( - tosa_fb: ts.TosaSerializer, - input_name: str, + tosa_fb: Any, + input_node: TosaArg, output_name: str, output_zp: int, rescale_scale: list[float], is_scale32: bool = True, is_double_round: bool = False, per_channel: bool = False, + tosa_spec=tosa_specification.Tosa_0_80, ) -> None: - multipliers, shifts = compute_multiplier_and_shift(rescale_scale) - attr_rescale_output = ts.TosaSerializerAttribute() - attr_rescale_output.RescaleAttribute( - input_zp=0, - output_zp=output_zp, - multiplier=multipliers, - shift=shifts, - scale32=is_scale32, - double_round=is_double_round, - per_channel=per_channel, - input_unsigned=False, - output_unsigned=False, - ) - - tosa_fb.addOperator( - TosaOp.Op().RESCALE, [input_name], [output_name], attr_rescale_output - ) - + if tosa_spec == tosa_specification.Tosa_0_80: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + build_rescale_v0_80( + tosa_fb=tosa_fb, + scale=rescale_scale, + input_node=input_node, + output_name=output_name, + output_type=ts.DType.INT8, + input_zp=0, + output_zp=output_zp, + ) # type: ignore[call-arg] + + elif isinstance(tosa_spec[0], tosa_specification.Tosa_1_00): + import serializer.tosa_serializer as ts # type: ignore + + # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs + # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale + build_rescale( + tosa_fb, + rescale_scale, + input_node, + output_name=output_name, + output_type=ts.DType.INT8, + input_zp=0, + output_zp=output_zp, + rounding_mode=RoundingMode.SINGLE_ROUND, + ) # type: ignore[call-arg] return @@ -298,14 +402,15 @@ def build_rescale_from_int32( def build_rescale_conv_output( - tosa_fb: ts.TosaSerializer, - op: ts.TosaSerializerTensor, + tosa_fb: Any, + op: Any, output_name: str, - output_type: ts.DType, + output_type: Any, input_scale: list[float], weight_scale: list[float], output_scale: list[float], output_zp: int, + tosa_spec=tosa_specification.Tosa_0_80, ): # TODO add check to verify if this is a Per-channel quantization. post_conv2d_scale = [ @@ -313,16 +418,29 @@ def build_rescale_conv_output( ] # Since we assume the input tensor that is being rescaled is int32 date type, zero point must be 0. - build_rescale( - tosa_fb, - post_conv2d_scale, - op, - output_name, - output_type, - op.shape, - 0, - output_zp, - False, - isinstance(weight_scale, torch.Tensor), - ) + if tosa_spec == tosa_specification.Tosa_0_80: + build_rescale_v0_80( + tosa_fb=tosa_fb, + scale=post_conv2d_scale, + input_node=op, + output_name=output_name, + output_type=output_type, + input_zp=0, + output_zp=output_zp, + per_channel=isinstance(weight_scale, torch.Tensor), + ) # type: ignore[call-arg] + elif isinstance(tosa_spec[0], tosa_specification.Tosa_1_00): + # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs + # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale + build_rescale( + tosa_fb=tosa_fb, + scale=post_conv2d_scale, + input_node=op, + output_name=output_name, + output_type=output_type, + input_zp=0, + output_zp=output_zp, + rounding_mode=RoundingMode.SINGLE_ROUND, + per_channel=isinstance(weight_scale, torch.Tensor), + ) # type: ignore[call-arg] return