# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy

from typing import Callable, cast, Iterable

from executorch.backends.arm.tosa_quant_utils import QuantArgs

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule, Node

q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
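
# These edge-dialect per-tensor Q/DQ ops are the nodes the passes below fold
# away. For reference, standard per-tensor affine quantization computes
# (general semantics, not code from this file):
#
#     q(x)  = clamp(round(x / scale) + zero_point, qmin, qmax)
#     dq(q) = (q - zero_point) * scale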


def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
    """
    Get the input quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
    Raises a ValueError if the node doesn't have any parameters set.
    """
    if "input_qparams" not in node.meta:
        raise ValueError(f"No input quantization parameter found in node {node}")
    input_qparams = cast(dict[int, QuantArgs], node.meta["input_qparams"])
    if len(input_qparams) == 0:
        raise ValueError(f"No input quantization parameter found in node {node}")
    return input_qparams


def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
    """
    Get the output quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
    Raises a ValueError if the node doesn't have any parameters set.
    """
    if "output_qparams" not in node.meta:
        raise ValueError(f"No output quantization parameter found in node {node}")
    output_qparams = cast(dict[int, QuantArgs], node.meta["output_qparams"])
    if len(output_qparams) == 0:
        raise ValueError(f"No output quantization parameter found in node {node}")
    return output_qparams
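
# A hedged usage sketch for the two helpers above: a later lowering or
# serialization pass would read the folded parameters back roughly like this.
# The 'scale' and 'zp' attribute names on QuantArgs are assumptions, not
# verified here; 'dtype' and 'quantize_value' are used elsewhere in this file.
#
#     in_qargs = get_input_qparams(node)[0]    # QuantArgs for tensor input 0
#     out_qargs = get_output_qparams(node)[0]  # QuantArgs for the output
#     rescale = in_qargs.scale / out_qargs.scale  # assumed field names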


class FoldAndAnnotateQParamsPass(ExportPass):
    """
    A pass that walks the graph and removes any DQ and Q nodes before and after the target
    node in the supplied list of operators.
    The quantization parameters from the DQ/Q nodes are stored as meta values to be
    accessible for later lowering and serialization passes.
    The assumption is that the quantization annotation adds DQ nodes for all tensor
    inputs to the target node and one Q node to the output.

    Example ('executorch_exir_dialects_edge__ops_' prefix removed from operators for readability):

        x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)

        x_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(x_q, 0.05487706884741783, -128, -128, 127, torch.int8)
        aten_add_tensor: "f32[5]" = ops_aten_add_Tensor(x_dq, x_dq)
        aten_add_tensor_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)

        output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor_q, 0.05487706884741783, -128, -128, 127, torch.int8)

    Becomes:
        x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)

        aten_add_tensor: "i8[5]" = aten_add_Tensor(x_q, x_q)

        output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)

    The quantization parameters for x_dq and aten_add_tensor_q are stored in the meta of the aten_add_tensor node.
    """

    def __init__(self, targeted_ops: Iterable[Callable]):
        super().__init__()
        self.targeted_ops = targeted_ops
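
    # A hedged usage sketch (the op list is illustrative; any edge ops the
    # backend lowers as quantized could appear here):
    #
    #     pass_ = FoldAndAnnotateQParamsPass(
    #         [exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.convolution.default]
    #     )
    #     graph_module = pass_(graph_module).graph_module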

    def call(self, graph_module: GraphModule) -> PassResult:

        # Loop over the graph nodes and find any node in the 'targeted_ops' list.
        for n in graph_module.graph.nodes:
            n = cast(Node, n)
            if n.op != "call_function" or n.target not in self.targeted_ops:
                continue

            # Make sure we haven't already set qparams meta information on the node.
            assert "input_qparams" not in n.meta.keys()
            assert "output_qparams" not in n.meta.keys()

            # For the inputs and outputs, search the graph for quantization info and
            # store it in a dict keyed by the order of the _tensor_ inputs,
            # ignoring any other arguments to the target node.
            n.meta["input_qparams"] = {}
            n.meta["output_qparams"] = {}
            for i, arg in enumerate(n.args):
                if not isinstance(arg, Node):
                    continue

                # Make sure arg has requires_grad set to False.
                # For parameters that are not quantized, sometimes (e.g. convolution)
                # the Parameter(FakeTensor(...)) has requires_grad set to True, which
                # causes the retracing of the graph to fail with:
                #
                # E       RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
                # E
                # E       While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
                # E       Original traceback:
                # E         File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
                # E           x = conv(x)
                #
                if arg.op == "placeholder":
                    arg.meta["val"].requires_grad = False

                if arg.target != dq_op:
                    continue

                # arg is a dequant node: extract its quantization parameters.
                n.meta["input_qparams"][i] = QuantArgs.from_operator(
                    arg.target, arg.args
                )

                # arg.args[0] is the tensor input; rewire n to use it directly
                # and remove the now-unused dequant node.
                n.replace_input_with(arg, arg.args[0])
                graph_module.graph.erase_node(arg)

            # Copy the users, since we are modifying the underlying dict while iterating.
            users_copy = copy.copy(n.users)
            for i, user in enumerate(users_copy):
                if user.target != q_op:
                    continue

                # Quantization node found here; store its parameters in meta.
                n.meta["output_qparams"][i] = QuantArgs.from_operator(
                    user.target, user.args
                )

                user.replace_all_uses_with(n)
                graph_module.graph.erase_node(user)

        # Retrace the graph to update the fake tensor types.
        graph_module = super().call(graph_module).graph_module

        graph_module.recompile()
        return PassResult(graph_module, True)


class QuantizeFullArgument(ExportPass):
    """
    Make sure the fill_value for full.default is quantized. This pass needs to be run before
    the folding pass above to make sure that the retraced output of the full.default op has
    the right dtype.
    """
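
    # Illustrative before/after (the scale 1/127, zero point 0, and fill value
    # are made up for this sketch):
    #
    #     before: full = full.default([2, 2], 0.5)
    #             full_q = quantize_per_tensor(full, 1/127, 0, -128, 127, torch.int8)
    #     after:  full = full.default([2, 2], 64, dtype=torch.int8)  # round(0.5 / (1/127)) == 64
    #             full_q = quantize_per_tensor(full, 1/127, 0, -128, 127, torch.int8)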

    def call(self, graph_module: GraphModule) -> PassResult:
        modified = False
        # Loop over the graph nodes and find any full.default node.
        for n in graph_module.graph.nodes:
            n = cast(Node, n)
            if n.target != exir_ops.edge.aten.full.default:
                continue

            # Make sure the full node feeds a quantize operator.
            if len(n.users) == 0:
                continue
            user = list(n.users)[0]
            if user.target != q_op:
                continue

            qargs = QuantArgs.from_operator(user.target, user.args)
            if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
                # Replace the fill value with its quantized counterpart and also set
                # the dtype to get the right output according to the Edge IR
                # specification: exir/dialects/edge/edge.yaml:3596
                quantized_full_value = qargs.quantize_value(n.args[1]).item()
                n.update_arg(1, quantized_full_value)
                n.update_kwarg("dtype", qargs.dtype)
                modified = True

        return PassResult(graph_module, modified)
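

# A sketch of how the two passes compose (ordering per the docstring above:
# QuantizeFullArgument must run before the folding pass). The op list passed
# to FoldAndAnnotateQParamsPass is illustrative.
#
#     passes = [
#         QuantizeFullArgument(),
#         FoldAndAnnotateQParamsPass(
#             [exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.full.default]
#         ),
#     ]
#     for p in passes:
#         graph_module = p(graph_module).graph_module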