From 33782582d7499f02a2f5848c4d4dd5f10dac9d53 Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Mon, 4 Nov 2024 17:59:52 -0600 Subject: [PATCH] Revert "Search graph for quantization nodes (#6452)" This reverts commit 63017e4035950c1f9c59388c0c89ee6e543a53fa. --- .../annotate_channels_last_dim_order_pass.py | 5 +- .../_passes/insert_squeeze_after_sum_pass.py | 14 +- .../arm/_passes/size_adjust_conv2d_pass.py | 4 +- backends/arm/operators/op_addmm.py | 38 ++-- backends/arm/operators/op_bmm.py | 16 +- backends/arm/operators/op_conv2d.py | 22 +- backends/arm/operators/op_exp.py | 7 +- backends/arm/operators/op_full.py | 11 +- backends/arm/operators/op_hardtanh.py | 13 +- backends/arm/operators/op_log.py | 7 +- backends/arm/operators/op_mm.py | 16 +- backends/arm/operators/op_mul.py | 4 +- backends/arm/operators/op_placeholder.py | 17 +- backends/arm/operators/op_reciprocal.py | 7 +- backends/arm/operators/op_relu.py | 2 +- backends/arm/operators/op_rsqrt.py | 7 +- backends/arm/operators/op_sigmoid.py | 7 +- backends/arm/operators/op_tanh.py | 7 +- .../generic_annotator.py | 3 - .../quantization_annotation/mm_annotator.py | 4 +- backends/arm/test/ops/test_bmm.py | 20 +- backends/arm/test/ops/test_linear.py | 2 +- backends/arm/tosa_quant_utils.py | 214 ++++++------------ backends/arm/tosa_utils.py | 20 +- 24 files changed, 175 insertions(+), 292 deletions(-) diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py index 786117e6457..77def9e7cd3 100644 --- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py +++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py @@ -14,7 +14,7 @@ get_first_fake_tensor, insert_q_dq_pair, ) -from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op +from executorch.backends.arm.tosa_quant_utils import dq_op, q_op from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -42,9 +42,6 @@ def _transpose_impl(*args, **kwargs): return args[0] -register_passable_op(torch.ops.passthrough_to_tosa._transpose) - - class AnnotateChannelsLastDimOrder(ExportPass): """ Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order diff --git a/backends/arm/_passes/insert_squeeze_after_sum_pass.py b/backends/arm/_passes/insert_squeeze_after_sum_pass.py index adf2b4f491c..152d5c95f6d 100644 --- a/backends/arm/_passes/insert_squeeze_after_sum_pass.py +++ b/backends/arm/_passes/insert_squeeze_after_sum_pass.py @@ -8,7 +8,9 @@ import torch import torch.fx -from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair + +from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -26,6 +28,8 @@ class InsertSqueezeAfterSumPass(ExportPass): sum(dims, keep_dim = False) After pass: sum(dims, keep_dim = True) + (q) + (dq) squeeze(dim = dims) """ @@ -41,6 +45,12 @@ def call(self, graph_module: torch.fx.GraphModule): continue dim_list = cast(list[int], sum_node.args[1]) + quantized = is_quant_node(sum_node) + if quantized: + qparams = get_quant_node_args(sum_node.all_input_nodes[0]) + qparams = qparams + (torch.int8,) + else: + qparams = None # Add keep_dim = True arg to sum node. sum_node.args = sum_node.args[0:2] + (True,) @@ -51,6 +61,8 @@ def call(self, graph_module: torch.fx.GraphModule): ) sum_node.replace_all_uses_with(squeeze_node) squeeze_node.args = (sum_node, dim_list) + if quantized: + sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams) graph_module.graph.eliminate_dead_code() graph_module.recompile() graph_module = super().call(graph_module).graph_module diff --git a/backends/arm/_passes/size_adjust_conv2d_pass.py b/backends/arm/_passes/size_adjust_conv2d_pass.py index c7bd27dcce0..980ab09e597 100644 --- a/backends/arm/_passes/size_adjust_conv2d_pass.py +++ b/backends/arm/_passes/size_adjust_conv2d_pass.py @@ -9,7 +9,7 @@ from typing import cast, Optional import torch.fx -from executorch.backends.arm.tosa_quant_utils import is_node_quantized +from executorch.backends.arm.tosa_quant_utils import is_quant_node from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch._ops import OpOverload @@ -113,7 +113,7 @@ def call(self, graph_module: torch.fx.GraphModule): slice_node = graph.create_node( "call_function", self.slice_op, (last_node,) + args ) - if is_node_quantized(last_node): + if is_quant_node(last_node): q_params = last_node.args[1:] dq_node = insert_q_dq_pair( graph_module.graph, slice_node, q_params diff --git a/backends/arm/operators/op_addmm.py b/backends/arm/operators/op_addmm.py index 64de62767e2..b4f782db4a3 100644 --- a/backends/arm/operators/op_addmm.py +++ b/backends/arm/operators/op_addmm.py @@ -14,13 +14,10 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - build_rescale, - search_quant_arg_downstream, - search_quant_arg_upstream, -) +from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args from executorch.backends.arm.tosa_utils import build_reshape +from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaOp @@ -70,7 +67,12 @@ def define_node( input_zp = 0 if is_quant_node: input_node = node.all_input_nodes[1] - input_zp = search_quant_arg_upstream(input_node).zp + # rank > 2 linear layer + if input_node.target == exir_ops.edge.aten.view_copy.default: + quant_node = input_node.all_input_nodes[0] + else: + quant_node = input_node + input_zp = get_quant_node_args(quant_node).zp attr.ConvAttribute( pad=pad_attr, stride=stride_attr, @@ -105,16 +107,24 @@ def define_node( # Read inputs' parent nodes _, input_node, weight_node = node.all_input_nodes - qargs = search_quant_arg_upstream(input_node) - input_scale = qargs.scale - consumer_node = list(node.users)[0] - quant_args = search_quant_arg_downstream(consumer_node) - - consumer_node_scale = quant_args.scale - consumer_node_node_zp = quant_args.zp + # rank > 2 linear layer + if input_node.target == exir_ops.edge.aten.view_copy.default: + quant_node = input_node.all_input_nodes[0] + input_scale = get_quant_node_args(quant_node).scale + consumer_node = list(node.users)[0] + consumer_consumer_node = list(consumer_node.users)[0] + quant_args = get_quant_node_args(consumer_consumer_node) + consumer_node_scale = quant_args.scale + consumer_node_node_zp = quant_args.zp + else: + input_scale = get_quant_node_args(input_node).scale + consumer_node = list(node.users)[0] + quant_args = get_quant_node_args(consumer_node) + consumer_node_scale = quant_args.scale + consumer_node_node_zp = quant_args.zp weight_node_q_node = weight_node.all_input_nodes[0] - weight_scale = search_quant_arg_upstream(weight_node_q_node).scale + weight_scale = get_quant_node_args(weight_node_q_node).scale output_rescale_scale = (input_scale * weight_scale) / consumer_node_scale diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py index c4067e5a7c7..161b5d22396 100644 --- a/backends/arm/operators/op_bmm.py +++ b/backends/arm/operators/op_bmm.py @@ -14,11 +14,7 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - build_rescale, - search_quant_arg_downstream, - search_quant_arg_upstream, -) +from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args from executorch.backends.arm.tosa_utils import get_two_inputs from serializer.tosa_serializer import TosaOp @@ -46,10 +42,8 @@ def define_node( # For INT8, we need to get the zero points and add an intermediate tensor # for a later rescale. if is_quant_node: - input0_q_params = search_quant_arg_upstream(input0) - input1_q_params = search_quant_arg_upstream(input1) - input0_zp = input0_q_params.zp - input1_zp = input1_q_params.zp + input0_zp = get_quant_node_args(input0).zp + input1_zp = get_quant_node_args(input1).zp bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) bmm_output_name = bmm_result.name else: @@ -69,7 +63,9 @@ def define_node( # As INT8 accumulates into INT32, we need to rescale it back to INT8 if is_quant_node: - output_q_params = search_quant_arg_downstream(list(node.users)[0]) + input0_q_params = get_quant_node_args(input0) + input1_q_params = get_quant_node_args(input1) + output_q_params = get_quant_node_args(list(node.users)[0]) final_output_scale = ( input0_q_params.scale * input1_q_params.scale diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index 8b2627ceda0..64cde0724f5 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import List +from typing import cast, List import serializer.tosa_serializer as ts import torch @@ -15,10 +15,9 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( build_rescale_conv_output, - search_quant_arg_downstream, - search_quant_arg_upstream, + get_quant_node_args, ) -from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape +from executorch.backends.arm.tosa_utils import build_reshape, getNodeArgs, tosa_shape from serializer.tosa_serializer import TosaOp @@ -83,9 +82,7 @@ def define_node( ) input_zp = ( - search_quant_arg_upstream(node.all_input_nodes[0]).zp - if is_quant_node - else 0 + get_quant_node_args(node.all_input_nodes[0]).zp if is_quant_node else 0 ) attr.ConvAttribute( @@ -161,10 +158,9 @@ def define_node( # integer value domain of the next op. Otherwise return float32 output. if is_quant_node: # Get scale_factor from input, weight, and output. - input_scale = search_quant_arg_upstream(node.all_input_nodes[0]).scale - weight_scale = search_quant_arg_upstream(node.all_input_nodes[1]).scale - output_qargs = search_quant_arg_downstream(list(node.users)[0]) - + _, input_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[0])) + _, weight_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[1])) + _, output_scale, output_zp, _, _, _ = getNodeArgs(list(node.users)[0]) build_rescale_conv_output( tosa_graph, # pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined. @@ -173,6 +169,6 @@ def define_node( actual_out_type, input_scale, weight_scale, - output_qargs.scale, - output_qargs.zp, + output_scale, + output_zp, ) diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py index 115ee4606c1..0e0a75dcc47 100644 --- a/backends/arm/operators/op_exp.py +++ b/backends/arm/operators/op_exp.py @@ -17,10 +17,9 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, + get_quant_node_args, QuantArgs, quantize_value, - search_quant_arg_downstream, - search_quant_arg_upstream, ) from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -49,9 +48,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = search_quant_arg_upstream(input_node) + in_quantargs = get_quant_node_args(input_node) output_node = list(node.users)[0] - out_quantargs = search_quant_arg_downstream(output_node) + out_quantargs = get_quant_node_args(output_node) table = exp_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_full.py b/backends/arm/operators/op_full.py index b2c14e4d465..cf67975e0d9 100644 --- a/backends/arm/operators/op_full.py +++ b/backends/arm/operators/op_full.py @@ -14,10 +14,7 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - quantize_value, - search_quant_arg_downstream, -) +from executorch.backends.arm.tosa_quant_utils import get_quant_node_args from executorch.backends.arm.tosa_utils import tosa_shape from torch.fx import Node @@ -42,8 +39,10 @@ def define_node( value = inputs[1].number if is_quant_node: - qargs = search_quant_arg_downstream(list(node.users)[0]) - qvalue = quantize_value(value, qargs) + qargs = get_quant_node_args(list(node.users)[0]) + qvalue = np.clip( + np.round(value / qargs.scale) + qargs.zp, qargs.qmin, qargs.qmax + ) dtype = ts.DType.INT8 data = np.full(shape, qvalue, dtype=np.int8) else: diff --git a/backends/arm/operators/op_hardtanh.py b/backends/arm/operators/op_hardtanh.py index 184bb8173df..62c0a27f05f 100644 --- a/backends/arm/operators/op_hardtanh.py +++ b/backends/arm/operators/op_hardtanh.py @@ -14,10 +14,7 @@ ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - quantize_value, - search_quant_arg_upstream, -) +from executorch.backends.arm.tosa_quant_utils import get_quant_node_args from serializer.tosa_serializer import TosaOp @@ -40,10 +37,12 @@ def define_node( if is_quant_node: # Get quant parameters - qargs = search_quant_arg_upstream(node.all_input_nodes[0]) + scale, zp, qmin, qmax = get_quant_node_args(node.all_input_nodes[0]) # Convert to quantized representation - clamp_min_qs = quantize_value(inputs[1].number, qargs) - clamp_max_qs = quantize_value(inputs[2].number, qargs) + clamp_min_qs = round((inputs[1].number / scale) + zp) + clamp_min_qs = max(clamp_min_qs, qmin) + clamp_max_qs = round((inputs[2].number / scale) + zp) + clamp_max_qs = min(clamp_max_qs, qmax) # Set fp values to 0.0 since they are not used clamp_min_fp = 0.0 clamp_max_fp = 0.0 diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index 8512e3eb300..5276173efa3 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -17,10 +17,9 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, + get_quant_node_args, QuantArgs, quantize_value, - search_quant_arg_downstream, - search_quant_arg_upstream, ) from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -50,9 +49,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = search_quant_arg_upstream(input_node) + in_quantargs = get_quant_node_args(input_node) output_node = list(node.users)[0] - out_quantargs = search_quant_arg_downstream(output_node) + out_quantargs = get_quant_node_args(output_node) table = log_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_mm.py b/backends/arm/operators/op_mm.py index b59baed69a8..ebddb3a40e2 100644 --- a/backends/arm/operators/op_mm.py +++ b/backends/arm/operators/op_mm.py @@ -14,11 +14,7 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - build_rescale, - search_quant_arg_downstream, - search_quant_arg_upstream, -) +from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args from executorch.backends.arm.tosa_utils import ( build_reshape, expand_dims, @@ -58,8 +54,8 @@ def define_node( # For INT8, we need to get the zero point, otherwise it is 0 input0_zp, input1_zp = 0, 0 if is_quant_node: - input0_zp = search_quant_arg_upstream(input0).zp - input1_zp = search_quant_arg_upstream(input1).zp + input0_zp = get_quant_node_args(input0).zp + input1_zp = get_quant_node_args(input1).zp mat_mul_result = tosa_graph.addIntermediate( output_new_shape, ts.DType.INT32 if is_quant_node else output.dtype @@ -90,9 +86,9 @@ def define_node( # As INT8 accumulates into INT32, we need to rescale it back to INT8 if is_quant_node: - input0_q_params = search_quant_arg_upstream(input0) - input1_q_params = search_quant_arg_upstream(input1) - output_q_params = search_quant_arg_downstream(list(node.users)[0]) + input0_q_params = get_quant_node_args(input0) + input1_q_params = get_quant_node_args(input1) + output_q_params = get_quant_node_args(list(node.users)[0]) final_output_scale = ( input0_q_params.scale * input1_q_params.scale diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index 8d507567114..c152e8759ef 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -37,10 +37,10 @@ def define_node( if is_quant_node: input_A = inputs[0] input_B = inputs[1] - input_A_qargs = tqutils.search_quant_arg_upstream( + input_A_qargs = tqutils.get_quant_node_args( cast(torch.fx.Node, node.args[0]) ) - input_B_qargs = tqutils.search_quant_arg_upstream( + input_B_qargs = tqutils.get_quant_node_args( cast(torch.fx.Node, node.args[1]) ) diff --git a/backends/arm/operators/op_placeholder.py b/backends/arm/operators/op_placeholder.py index 00bebba09d7..2618c9e71d3 100644 --- a/backends/arm/operators/op_placeholder.py +++ b/backends/arm/operators/op_placeholder.py @@ -10,14 +10,13 @@ import torch.fx from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( - get_quantized_node_output_dtype, - is_node_quantized, - search_quant_arg_upstream, + get_quant_arg_dtype, + get_quant_node_args, + is_quant_arg, ) from executorch.backends.arm.tosa_utils import ( is_bias_node_for_quantized_addmm, is_bias_node_for_quantized_conv, - map_dtype, tosa_shape, ) from executorch.exir.dialects._ops import ops as exir_ops @@ -42,11 +41,7 @@ def process_inputs( tensor = ts.TosaSerializerTensor( inputs[0].name, tosa_shape(input_shape, input_dim_order), - ( - map_dtype(get_quantized_node_output_dtype(node)) - if is_node_quantized(node) - else inputs[0].dtype - ), + get_quant_arg_dtype(node) if is_quant_arg(node) else inputs[0].dtype, data=None, placeholderFilename=inputs[0].name + ".npy", ) @@ -80,8 +75,8 @@ def process_quantized_bias( _, ) = consumer_node.all_input_nodes - input_node_scale = search_quant_arg_upstream(input_node).scale - weight_node_scale = search_quant_arg_upstream(weight_node).scale + input_node_scale = get_quant_node_args(input_node).scale + weight_node_scale = get_quant_node_args(weight_node).scale bias_values_quantized = ( (parameter_values / (input_node_scale * weight_node_scale)) .round() diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py index 051d8bf4d7a..3d43fd8f7da 100644 --- a/backends/arm/operators/op_reciprocal.py +++ b/backends/arm/operators/op_reciprocal.py @@ -15,10 +15,9 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, + get_quant_node_args, QuantArgs, quantize_value, - search_quant_arg_downstream, - search_quant_arg_upstream, ) from serializer.tosa_serializer import TosaOp @@ -42,8 +41,8 @@ def define_node( if is_quant_node: input = inputs[0] - input_qargs = search_quant_arg_upstream(node.all_input_nodes[0]) - output_qargs = search_quant_arg_downstream(list(node.users)[0]) + input_qargs = get_quant_node_args(node.all_input_nodes[0]) + output_qargs = get_quant_node_args(list(node.users)[0]) div_table = div_table_8bit(input_qargs, output_qargs) diff --git a/backends/arm/operators/op_relu.py b/backends/arm/operators/op_relu.py index afc2fd88d6c..20bba3f6545 100644 --- a/backends/arm/operators/op_relu.py +++ b/backends/arm/operators/op_relu.py @@ -38,7 +38,7 @@ def define_node( clamp_min_qs = 0 clamp_max_qs = 0 if is_quant_node: - out_qargs = tqutils.search_quant_arg_downstream(list(node.users)[0]) + out_qargs = tqutils.get_quant_node_args(list(node.users)[0]) clamp_min_qs = tqutils.quantize_value(0, out_qargs) clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs) diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index d256a1c633a..9225c7d938f 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -16,10 +16,9 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, + get_quant_node_args, QuantArgs, quantize_value, - search_quant_arg_downstream, - search_quant_arg_upstream, ) from serializer.tosa_serializer import TosaOp @@ -40,9 +39,9 @@ def define_node( # Assume quantized input is 8 bit. # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = search_quant_arg_upstream(input_node) + in_quantargs = get_quant_node_args(input_node) output_node = list(node.users)[0] - out_quantargs = search_quant_arg_downstream(output_node) + out_quantargs = get_quant_node_args(output_node) table = rsqrt_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() table_attr.TableAttribute(table) diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index d0e321f6fd9..0087b1f7a81 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -17,10 +17,9 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, + get_quant_node_args, QuantArgs, quantize_value, - search_quant_arg_downstream, - search_quant_arg_upstream, ) from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -50,9 +49,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = search_quant_arg_upstream(input_node) + in_quantargs = get_quant_node_args(input_node) output_node = list(node.users)[0] - out_quantargs = search_quant_arg_downstream(output_node) + out_quantargs = get_quant_node_args(output_node) table = sigmoid_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py index 7a556a53799..20f343a7f1b 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tanh.py @@ -17,10 +17,9 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, + get_quant_node_args, QuantArgs, quantize_value, - search_quant_arg_downstream, - search_quant_arg_upstream, ) from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -50,9 +49,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = search_quant_arg_upstream(input_node) + in_quantargs = get_quant_node_args(input_node) output_node = list(node.users)[0] - out_quantargs = search_quant_arg_downstream(output_node) + out_quantargs = get_quant_node_args(output_node) table = tanh_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/quantizer/quantization_annotation/generic_annotator.py b/backends/arm/quantizer/quantization_annotation/generic_annotator.py index a35f5c0fdae..f91df1398e8 100644 --- a/backends/arm/quantizer/quantization_annotation/generic_annotator.py +++ b/backends/arm/quantizer/quantization_annotation/generic_annotator.py @@ -27,9 +27,6 @@ torch.ops.aten.unsqueeze.default, torch.ops.aten.unsqueeze_copy.default, torch.ops.aten.reshape.default, - torch.ops.aten.repeat.default, - torch.ops.aten.expand_copy.default, - torch.ops.aten.expand.default, # Disabling these as there seems to be an issue with support for complex # datatypes in torch: # torch.ops.aten.view_as_complex.default, diff --git a/backends/arm/quantizer/quantization_annotation/mm_annotator.py b/backends/arm/quantizer/quantization_annotation/mm_annotator.py index 60d9adb1c3c..b48c6d59905 100644 --- a/backends/arm/quantizer/quantization_annotation/mm_annotator.py +++ b/backends/arm/quantizer/quantization_annotation/mm_annotator.py @@ -24,9 +24,7 @@ def _annotate_mm( quantization_config: QuantizationConfig, filter_fn: Optional[Callable[[Node], bool]] = None, ) -> Optional[List[List[Node]]]: - mm_partitions = get_source_partitions( - gm.graph, [torch.mm, torch.bmm, torch.matmul], filter_fn - ) + mm_partitions = get_source_partitions(gm.graph, [torch.mm, torch.bmm], filter_fn) mm_partitions = list(itertools.chain.from_iterable(mm_partitions.values())) annotated_partitions = [] for mm_partition in mm_partitions: diff --git a/backends/arm/test/ops/test_bmm.py b/backends/arm/test/ops/test_bmm.py index a61cc7f1c8d..e4e6abb7bb3 100644 --- a/backends/arm/test/ops/test_bmm.py +++ b/backends/arm/test/ops/test_bmm.py @@ -32,12 +32,6 @@ class BMM(torch.nn.Module): def forward(self, x, y): return torch.bmm(x, y) - class MatMul(torch.nn.Module): - test_parameters = [(torch.rand(2, 3, 5), torch.rand(2, 5, 2))] - - def forward(self, x, y): - return torch.matmul(x, y) - class BMMSingleInput(torch.nn.Module): test_parameters = [ (torch.rand(20, 3, 3),), @@ -59,9 +53,9 @@ def _test_bmm_tosa_MI_pipeline( compile_spec=common.get_tosa_compile_spec(), ) .export() + .check_count({"torch.ops.aten.bmm.default": 1}) .check_not(["torch.ops.quantized_decomposed"]) .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_bmm_default": 1}) .partition() .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"]) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) @@ -80,9 +74,9 @@ def _test_bmm_tosa_BI_pipeline( ) .quantize() .export() + .check_count({"torch.ops.aten.bmm.default": 1}) .check(["torch.ops.quantized_decomposed"]) .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_bmm_default": 1}) .partition() .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"]) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) @@ -122,16 +116,6 @@ def test_bmm_single_input_tosa_MI(self, operand1: torch.Tensor): test_data = (operand1,) self._test_bmm_tosa_MI_pipeline(self.BMMSingleInput(), test_data) - @parameterized.expand(MatMul.test_parameters) - def test_matmul_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): - test_data = (operand1, operand2) - self._test_bmm_tosa_MI_pipeline(self.MatMul(), test_data) - - @parameterized.expand(MatMul.test_parameters) - def test_matmul_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): - test_data = (operand1, operand2) - self._test_bmm_tosa_BI_pipeline(self.MatMul(), test_data) - @parameterized.expand(BMM.test_parameters) def test_bmm_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): test_data = (operand1, operand2) diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py index 7d463545887..3f68ab0251a 100644 --- a/backends/arm/test/ops/test_linear.py +++ b/backends/arm/test/ops/test_linear.py @@ -151,7 +151,7 @@ def _test_linear_tosa_BI_pipeline( .partition() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() - .run_method_and_compare_outputs(inputs=test_data, qtol=1) + .run_method_and_compare_outputs(inputs=test_data, qtol=True) ) def _test_linear_tosa_ethosu_BI_pipeline( diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index d195c7f4464..fe408e41b3a 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -15,31 +15,14 @@ import serializer.tosa_serializer as ts import torch.fx import tosa.Op as TosaOp -from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaSerializerTensor from torch.fx import Node - q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default -dq_q_ops = (q_op, dq_op) -passable_ops = [ - exir_ops.edge.aten.view_copy.default, - exir_ops.edge.aten.permute_copy.default, - exir_ops.edge.aten.squeeze_copy.dims, - exir_ops.edge.aten.unsqueeze_copy.default, - exir_ops.edge.aten.split_with_sizes_copy.default, - exir_ops.edge.aten.repeat.default, - exir_ops.edge.aten.clone.default, - exir_ops.edge.aten.slice_copy.Tensor, - exir_ops.edge.aten.cat.default, -] - - -def register_passable_op(op): - """We need to be able to add custom ops such as tosa_transpose to the passable_op list after they have been created""" - passable_ops.append(op) +dq_q_ops = [q_op, dq_op] class QuantArgs(NamedTuple): @@ -47,19 +30,6 @@ class QuantArgs(NamedTuple): zp: int qmin: int qmax: int - dtype: torch.dtype - - def quantize_value(self, x): - if not isinstance(x, torch.Tensor): - x = torch.Tensor([x]) - return torch.clip( - torch.round(x / self.scale) + self.zp, - self.qmin, - self.qmax, - ).to(self.dtype) - - def dequantize_value(self, qx): - return (qx - self.zp) * self.scale def quantize_value(x, qargs: QuantArgs, dtype=np.int8): @@ -74,135 +44,81 @@ def dequantize_value(qx, qargs: QuantArgs): return (qx - qargs.zp) * qargs.scale -def qargs_from_qnode(node: torch.fx.Node): - assert node.target in dq_q_ops, f"Op {node} is not a quant node." +def is_quant_node(node: torch.fx.Node): - return QuantArgs(*node.args[1:]) + consumer_node_condition = False + if len(list(node.users)) > 0: + consumer_node = list(node.users)[0] + # For Rank > 2 Linear layers, the quant node is after the view_copy + if ( + node.target == exir_ops.edge.aten.addmm.default + and consumer_node.target == exir_ops.edge.aten.view_copy.default + ): + consumer_consumer_node = list(consumer_node.users)[0] + return True if consumer_consumer_node.target == q_op else False + consumer_node_condition = consumer_node.target == q_op -def get_neighbour_quant_args( - node: torch.fx.Node, -) -> tuple[list[QuantArgs], list[QuantArgs]]: - user_q_args = [] + input_node_condition = False + if len(node.all_input_nodes) > 0: + input = node.all_input_nodes[0] + input_node_condition = input.target in dq_q_ops - for user in node.users: - q_args = search_quant_arg_downstream(user) - if q_args: - user_q_args.append(q_args) + return node.target in dq_q_ops or consumer_node_condition or input_node_condition - input_q_nodes = [] - for input_node in node.all_input_nodes: - q_args = search_quant_arg_upstream(input_node) - if q_args: - input_q_nodes.append(q_args) - return user_q_args, input_q_nodes +def get_quant_node_dtype(node: torch.fx.Node): + # pyre-ignore[16]: Undefined attribute. + if "tosa" in node.target.__name__: + return node.meta["val"].dtype -def all_q_args_equal(q_arg_list: list[QuantArgs]) -> bool: - first_q_arg = q_arg_list[0] - for q_arg in q_arg_list: - if q_arg != first_q_arg: - return False - return True + if node.target in dq_q_ops: + return node.args[5] + # if not a tosa node, nor a q/dq op, walk the graph until we find a q op + consumer_node = list(node.users)[0] + while True: + if consumer_node.target in dq_q_ops: + return consumer_node.args[5] -def is_node_quantized(node: torch.fx.Node) -> bool: - if node.target in dq_q_ops: - return True + # Try to move on to the next node + if len(consumer_node.users) == 0: + raise RuntimeError(f"No quantized node found in graph for node {node}") + consumer_node = list(consumer_node.users)[0] - user_q_args, input_q_args = get_neighbour_quant_args(node) - # If we did not find any neighbouring quant nodes, we are not quantized. - if len(input_q_args) == 0 and len(user_q_args) == 0: - return False +def is_quant_arg(arg): + consumer_node = list(arg.users)[0] + return consumer_node.target == q_op - if node.target in passable_ops: - assert all_q_args_equal( - user_q_args + input_q_args - ), f"Node {node} needs same quantization parameters on all inputs and outputs." - return True +def get_quant_arg_dtype(node: torch.fx.Node): + consumer_node = list(node.users)[0] + # Get type of quant node, args differ from per_tensor and per_channel. + if consumer_node.target == q_op: + if is_quant_arg(node): + return map_dtype(consumer_node.args[5]) + else: + raise RuntimeError("Quantization argument not found") -def search_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs | None: - """ - Iterates downward in the graph passing through 'passable_ops' to find and return a quantization node, - starting with 'node'. - If a passable node with multiple consumers is encountered, - find QuantArgs for all consumers and assert that they are equal. - If a node not in passable_ops is encountered, return None. - If a node without consumers is encountered, return None. - """ - if node.target in dq_q_ops: - return qargs_from_qnode(node) - if node.target not in passable_ops: - return None - consumer_nodes = list(node.users) - if len(consumer_nodes) == 0: - return None - elif len(consumer_nodes) == 1: - return search_quant_arg_downstream(consumer_nodes[0]) - else: - consumer_qargs: list[QuantArgs] = [] - for input in consumer_nodes: - quant_args = search_quant_arg_downstream(input) - if quant_args: - consumer_qargs.append(quant_args) - if len(quant_args) == 0: - return None - assert all_q_args_equal( - consumer_qargs - ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different consumers." - return consumer_qargs[0] - - -def search_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs | None: - """ - Iterates upward in the graph passing through 'passable_ops' to find and return a quantization node, - starting with 'node'. - If a passable node with multiple inputs is encountered, - find QuantArgs for all inputs and assert that they are equal. - If a node not in passable_ops is encountered, return None. - If a node without inputs is encountered, return None. - """ - if node.target in dq_q_ops: - return qargs_from_qnode(node) - if node.target not in passable_ops: - return None - input_nodes = list(node.all_input_nodes) - if len(input_nodes) == 0: - return None - elif len(input_nodes) == 1: - return search_quant_arg_upstream(input_nodes[0]) - else: - input_qargs: list[QuantArgs] = [] - for input in input_nodes: - quant_args = search_quant_arg_upstream(input) - if quant_args: - input_qargs.append(quant_args) - if len(quant_args) == 0: - return None - assert all_q_args_equal( - input_qargs - ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different inputs." - return input_qargs[0] - - -def get_quantized_node_output_dtype(node: torch.fx.Node): - if hasattr(node.target, "__name__") and "tosa" in node.target.__name__: - return node.meta["val"].dtype - if node.target in dq_q_ops: - return node.args[5] +def get_quant_node_args(node: torch.fx.Node): + """ + Get the quantization parameters from a quant node. - # if not a tosa node, nor a q/dq op, walk the graph until we find a q op - user_q_args, input_q_args = get_neighbour_quant_args(node) - if len(user_q_args) > 0: - return user_q_args[0].dtype - elif node.target in passable_ops and len(input_q_args): - return input_q_args[0].dtype - else: - raise RuntimeError("No quantized node found in graph") + Args: + node: The quant node. + Returns: + QuantArgs: scale, zp, qmin, qmax + """ + quant_args = [TosaArg(arg) for arg in node.args] + return QuantArgs( + quant_args[1].number, + quant_args[2].number, + quant_args[3].number, + quant_args[4].number, + ) # Check if scale32 mode is used for given output element type @@ -351,14 +267,14 @@ def rescale_nodes_to_int32( needed by rescale_node_back_to_int8. """ - tensors = [TosaArg(node) for node in nodes] + tensors = [TosaArg(node.args[0]) for node in nodes] # Reshape tensor according to tosa dim order for tensor in tensors: dim_order = tensor.dim_order tensor.shape = [tensor.shape[i] for i in dim_order] - qargs = [search_quant_arg_upstream(node) for node in nodes] + qargs = [get_quant_node_args(node) for node in nodes] # Scale the int8 quantized input to a common scale in the integer # domain @@ -391,7 +307,7 @@ def rescale_node_back_to_int8( scale: the scaling factor used to rescale to int32, from the function 'rescale_nodes_to_int32' tosa_graph: the tosa_graph to manipulate. """ - qargs_out = search_quant_arg_downstream(list(node.users)[0]) + qargs_out = get_quant_node_args(list(node.users)[0]) output_rescale_scale = scale / qargs_out.scale # Rescale Back to INT8 @@ -418,7 +334,7 @@ def build_rescale_conv_output( output_zp, ): # TODO add check to verify if this is a Per-channel quantization. - post_conv2d_scale = (input_scale * weight_scale) / output_scale + post_conv2d_scale = (input_scale.number * weight_scale.number) / output_scale.number # Since we assume the input tensor that is being rescaled is int32 date type, zero point must be 0. build_rescale( @@ -429,6 +345,6 @@ def build_rescale_conv_output( output_type, op.shape, 0, - output_zp, + output_zp.number, ) return diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index b3e9f4e1c3c..cfafac16760 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -16,11 +16,10 @@ from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg from executorch.backends.arm.tosa_quant_utils import ( - get_quantized_node_output_dtype, - is_node_quantized, + get_quant_node_args, + get_quant_node_dtype, + is_quant_node, q_op, - search_quant_arg_downstream, - search_quant_arg_upstream, ) from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaOp @@ -238,8 +237,8 @@ def build_avg_pool_2d_common( output_zp = 0 if is_quant_node: - input_zp = search_quant_arg_upstream(cast(torch.fx.Node, node.args[0])).zp - output_zp = search_quant_arg_downstream(list(node.users)[0]).zp + input_zp = get_quant_node_args(cast(torch.fx.Node, node.args[0])).zp + output_zp = get_quant_node_args(list(node.users)[0]).zp attr = ts.TosaSerializerAttribute() attr.PoolAttribute( @@ -298,11 +297,6 @@ def process_call_function( # Convert output (this node itself) output = TosaArg(node) - is_quant_node = is_node_quantized(node) - if is_quant_node: - output_dtype = map_dtype(get_quantized_node_output_dtype(node)) - else: - output_dtype = output.dtype tosa_graph.currRegion.currBasicBlock.addTensor( output.name, ( @@ -310,7 +304,7 @@ def process_call_function( if is_permute_node_before_addmm(node) else tosa_shape(output.shape, output.dim_order) ), - output_dtype, + map_dtype(get_quant_node_dtype(node)) if is_quant_node(node) else output.dtype, ) # Visiting each Node @@ -322,7 +316,7 @@ def process_call_function( tosa_graph, inputs, output, - is_quant_node, + is_quant_node(node), ) else: raise RuntimeError(f"Unknown operator {node.target}")