| 
 | 1 | +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.  | 
 | 2 | +import logging  | 
 | 3 | +from typing import Any, Dict, List, Optional, Union  | 
 | 4 | + | 
 | 5 | +import numpy as np  | 
 | 6 | + | 
 | 7 | +import torch  | 
 | 8 | + | 
 | 9 | +from executorch.exir import EdgeProgramManager  | 
 | 10 | +from executorch.exir.dialects._ops import ops as exir_ops  | 
 | 11 | + | 
 | 12 | +from executorch.exir.pass_base import ExportPass  | 
 | 13 | +from executorch.exir.tensor import scalar_type_enum  | 
 | 14 | +from torch.fx.passes.infra.pass_base import PassResult  | 
 | 15 | + | 
 | 16 | +logger = logging.getLogger(__name__)  | 
 | 17 | + | 
 | 18 | + | 
 | 19 | +def quantize_input(  | 
 | 20 | +    exported_program, input_index, qparams: Optional[Dict[str, Any]] = None  | 
 | 21 | +):  | 
 | 22 | +    """  | 
 | 23 | +    Modify the program to expect quantized input at given index. The input is expected  | 
 | 24 | +    to be quantizing this input as the first step. Must be called before  | 
 | 25 | +    permute_input_layout. Returns the scale, zero point, qmin, qmax, and dtype of the  | 
 | 26 | +    expected quantization.  | 
 | 27 | +    """  | 
 | 28 | +    graph = exported_program.graph_module.graph  | 
 | 29 | +    name = exported_program.graph_signature.user_inputs[input_index]  | 
 | 30 | +    placeholders = [n for n in graph.nodes if n.op == "placeholder" and n.name == name]  | 
 | 31 | +    assert placeholders  | 
 | 32 | +    target_placeholder = placeholders[0]  | 
 | 33 | + | 
 | 34 | +    if len(target_placeholder.users) != 1:  | 
 | 35 | +        raise ValueError(f"Input {input_index} has more than one users")  | 
 | 36 | +    quantize = next(iter(target_placeholder.users))  | 
 | 37 | +    if (  | 
 | 38 | +        quantize.target  | 
 | 39 | +        != exir_ops.edge.quantized_decomposed.quantize_per_tensor.default  | 
 | 40 | +    ):  | 
 | 41 | +        raise ValueError(f"Input {input_index} is not used by a quantize op")  | 
 | 42 | + | 
 | 43 | +    # If user specified qparams are different from args of quantize op, we do requantization instead of eliminating quantize op  | 
 | 44 | +    need_requant = False  | 
 | 45 | +    if qparams is not None:  | 
 | 46 | +        assert all(  | 
 | 47 | +            qparam in qparams for qparam in ["scale", "zp", "dtype"]  | 
 | 48 | +        ), "dtype/scale/zp must be specified in qparam for input requantization"  | 
 | 49 | +        if qparams["dtype"] != quantize.args[5]:  | 
 | 50 | +            if any(  | 
 | 51 | +                dtype  | 
 | 52 | +                not in [torch.int8, torch.uint8, torch.bool, torch.int16, torch.uint16]  | 
 | 53 | +                for dtype in [qparams["dtype"], quantize.args[5]]  | 
 | 54 | +            ):  | 
 | 55 | +                raise ValueError(  | 
 | 56 | +                    f"Only limited data types are supported for requantization, but got {qparams['dtype']} -> {quantize.args[5]}"  | 
 | 57 | +                )  | 
 | 58 | + | 
 | 59 | +            need_requant = True  | 
 | 60 | +        elif (  | 
 | 61 | +            not np.isclose(qparams["scale"], quantize.args[1])  | 
 | 62 | +            or qparams["zp"] != quantize.args[2]  | 
 | 63 | +        ):  | 
 | 64 | +            need_requant = True  | 
 | 65 | + | 
 | 66 | +    if need_requant:  | 
 | 67 | +        assert qparams is not None  | 
 | 68 | +        dtype = qparams["dtype"]  | 
 | 69 | +        qmin = torch.iinfo(dtype).min  | 
 | 70 | +        qmax = torch.iinfo(dtype).max  | 
 | 71 | +        scale = qparams["scale"]  | 
 | 72 | +        zero_point = qparams["zp"]  | 
 | 73 | +        quant_args = (scale, zero_point, qmin, qmax, dtype)  | 
 | 74 | +        logger.info(  | 
 | 75 | +            f"Modifying program to requantize quantized input at index {input_index}"  | 
 | 76 | +        )  | 
 | 77 | +        logger.info(f"Quantization parameters: {quant_args}")  | 
 | 78 | + | 
 | 79 | +        with exported_program.graph_module.graph.inserting_before(quantize):  | 
 | 80 | +            input_dequant = exported_program.graph_module.graph.call_function(  | 
 | 81 | +                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,  | 
 | 82 | +                args=(  | 
 | 83 | +                    target_placeholder,  | 
 | 84 | +                    *quant_args,  | 
 | 85 | +                ),  | 
 | 86 | +            )  | 
 | 87 | +            input_dequant.meta["input_qparams"] = [  | 
 | 88 | +                {  | 
 | 89 | +                    "scale": scale,  | 
 | 90 | +                    "zero_point": zero_point,  | 
 | 91 | +                    "qmin": qmin,  | 
 | 92 | +                    "qmax": qmax,  | 
 | 93 | +                    "dtype": dtype,  | 
 | 94 | +                }  | 
 | 95 | +            ]  | 
 | 96 | +            input_dequant.meta["val"] = quantize.meta["val"].to(torch.float32)  | 
 | 97 | +            target_placeholder.meta["val"] = target_placeholder.meta["val"].to(dtype)  | 
 | 98 | +            quantize.replace_input_with(target_placeholder, input_dequant)  | 
 | 99 | +    else:  | 
 | 100 | +        quant_args = quantize.args[1:]  | 
 | 101 | +        logger.info(f"Modifying program to take quantized input at index {input_index}")  | 
 | 102 | +        logger.info(f"Quantization parameters: {quant_args}")  | 
 | 103 | + | 
 | 104 | +        target_placeholder.meta["val"] = (  | 
 | 105 | +            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(  | 
 | 106 | +                target_placeholder.meta["val"], *quant_args  | 
 | 107 | +            )  | 
 | 108 | +        )  | 
 | 109 | +        quantize.replace_all_uses_with(quantize.args[0])  | 
 | 110 | + | 
 | 111 | +    exported_program.graph_module.graph.eliminate_dead_code()  | 
 | 112 | +    return quant_args  | 
 | 113 | + | 
 | 114 | + | 
 | 115 | +def quantize_output(exported_program, output_index):  | 
 | 116 | +    """  | 
 | 117 | +    Modify the program to produce quantized output at given index. The model is expected  | 
 | 118 | +    to be dequantizing this output as the last step. Must be called before  | 
 | 119 | +    permute_output_layout. Returns the scale, zero point, qmin, qmax, and dtype of the  | 
 | 120 | +    output quantization.  | 
 | 121 | +    """  | 
 | 122 | +    graph = exported_program.graph_module.graph  | 
 | 123 | +    outputs = [n for n in graph.nodes if n.op == "output"]  | 
 | 124 | +    if len(outputs) != 1:  | 
 | 125 | +        raise NotImplementedError("Only 1 output node is supported")  | 
 | 126 | + | 
 | 127 | +    output_node = outputs[0]  | 
 | 128 | +    output_list = list(output_node.args[0])  | 
 | 129 | +    if output_index >= len(output_list):  | 
 | 130 | +        raise ValueError(  | 
 | 131 | +            f"{len(output_list)} outputs available, "  | 
 | 132 | +            + f"output index out of bounds: {output_index}"  | 
 | 133 | +        )  | 
 | 134 | + | 
 | 135 | +    target_output = output_list[output_index]  | 
 | 136 | +    if (  | 
 | 137 | +        target_output.target  | 
 | 138 | +        != exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default  | 
 | 139 | +    ):  | 
 | 140 | +        raise ValueError("Output {output_index} is not a dequantize op")  | 
 | 141 | + | 
 | 142 | +    dequant = target_output  | 
 | 143 | +    output_list[output_index] = dequant.args[0]  | 
 | 144 | +    output_node.args = (output_list,)  | 
 | 145 | +    dequant_args = dequant.args[1:]  | 
 | 146 | +    graph.eliminate_dead_code()  | 
 | 147 | + | 
 | 148 | +    logger.info(  | 
 | 149 | +        f"Modifying program to produce quantized output at index {output_index}"  | 
 | 150 | +    )  | 
 | 151 | +    logger.info(f"Dequantization parameters: {dequant_args}")  | 
 | 152 | +    return dequant_args  | 
 | 153 | + | 
 | 154 | + | 
 | 155 | +def get_config_method_name(  | 
 | 156 | +    prefix: Optional[str] = "forward",  | 
 | 157 | +    arg_type: str = "input",  | 
 | 158 | +    index: int = 0,  | 
 | 159 | +    key: str = "scale",  | 
 | 160 | +):  | 
 | 161 | +    if prefix is None:  | 
 | 162 | +        prefix = ""  | 
 | 163 | +    else:  | 
 | 164 | +        prefix = prefix + "_"  | 
 | 165 | +    assert arg_type in ["input", "output"], "arg_type must be either input or output"  | 
 | 166 | +    assert index >= 0, "index must be non-negative"  | 
 | 167 | +    assert key in [  | 
 | 168 | +        "scale",  | 
 | 169 | +        "zp",  | 
 | 170 | +        "quant_min",  | 
 | 171 | +        "quant_max",  | 
 | 172 | +        "dtype",  | 
 | 173 | +    ], "key must be one of scale, zp, quant_min, quant_max, dtype"  | 
 | 174 | +    return f"{prefix}{arg_type}{index}_{key}"  | 
 | 175 | + | 
 | 176 | + | 
 | 177 | +class QuantizeInputs(ExportPass):  | 
 | 178 | +    def __init__(  | 
 | 179 | +        self,  | 
 | 180 | +        edge_program_manager: EdgeProgramManager,  | 
 | 181 | +        quantized_inputs_idx: Union[Dict[int, Dict[str, Any]], List[int]],  | 
 | 182 | +        method_name: Optional[str] = None,  | 
 | 183 | +    ):  | 
 | 184 | +        super().__init__()  | 
 | 185 | +        self.edge_program_manager = edge_program_manager  | 
 | 186 | + | 
 | 187 | +        self.quantized_inputs_idx_dict = {}  | 
 | 188 | +        if isinstance(quantized_inputs_idx, dict):  | 
 | 189 | +            self.quantized_inputs_idx_dict = quantized_inputs_idx  | 
 | 190 | +        else:  | 
 | 191 | +            for idx in quantized_inputs_idx:  | 
 | 192 | +                self.quantized_inputs_idx_dict[idx] = None  | 
 | 193 | +        self.param_prefix_name = method_name  | 
 | 194 | + | 
 | 195 | +    def call(self, graph_module: torch.fx.GraphModule):  | 
 | 196 | +        for i, qparams in self.quantized_inputs_idx_dict.items():  | 
 | 197 | +            quant_args = quantize_input(  | 
 | 198 | +                self.edge_program_manager.exported_program(), i, qparams  | 
 | 199 | +            )  | 
 | 200 | + | 
 | 201 | +            if not self.edge_program_manager._config_methods:  | 
 | 202 | +                self.edge_program_manager._config_methods = {}  | 
 | 203 | + | 
 | 204 | +            self.edge_program_manager._config_methods[  | 
 | 205 | +                get_config_method_name(self.param_prefix_name, "input", i, "scale")  | 
 | 206 | +            ] = quant_args[0]  | 
 | 207 | +            self.edge_program_manager._config_methods[  # pyre-ignore  | 
 | 208 | +                get_config_method_name(self.param_prefix_name, "input", i, "zp")  | 
 | 209 | +            ] = quant_args[1]  | 
 | 210 | +            self.edge_program_manager._config_methods[  | 
 | 211 | +                get_config_method_name(self.param_prefix_name, "input", i, "quant_min")  | 
 | 212 | +            ] = quant_args[2]  | 
 | 213 | +            self.edge_program_manager._config_methods[  | 
 | 214 | +                get_config_method_name(self.param_prefix_name, "input", i, "quant_max")  | 
 | 215 | +            ] = quant_args[3]  | 
 | 216 | +            self.edge_program_manager._config_methods[  | 
 | 217 | +                get_config_method_name(self.param_prefix_name, "input", i, "dtype")  | 
 | 218 | +            ] = scalar_type_enum(quant_args[4])  | 
 | 219 | +        return PassResult(graph_module, True)  | 
 | 220 | + | 
 | 221 | + | 
 | 222 | +class QuantizeOutputs(ExportPass):  | 
 | 223 | +    def __init__(  | 
 | 224 | +        self,  | 
 | 225 | +        edge_program_manager: EdgeProgramManager,  | 
 | 226 | +        quantized_outputs_idx_list: List[int],  | 
 | 227 | +        method_name: Optional[str] = None,  | 
 | 228 | +    ):  | 
 | 229 | +        super().__init__()  | 
 | 230 | +        self.edge_program_manager = edge_program_manager  | 
 | 231 | +        self.quantized_outputs_idx_list = quantized_outputs_idx_list  | 
 | 232 | +        self.param_prefix_name = method_name  | 
 | 233 | + | 
 | 234 | +    def call(self, graph_module: torch.fx.GraphModule):  | 
 | 235 | +        for i in self.quantized_outputs_idx_list:  | 
 | 236 | +            dequant_args = quantize_output(  | 
 | 237 | +                self.edge_program_manager.exported_program(), i  | 
 | 238 | +            )  # noqa F841  | 
 | 239 | + | 
 | 240 | +            if not self.edge_program_manager._config_methods:  | 
 | 241 | +                self.edge_program_manager._config_methods = {}  | 
 | 242 | + | 
 | 243 | +            self.edge_program_manager._config_methods[  | 
 | 244 | +                get_config_method_name(self.param_prefix_name, "output", i, "scale")  | 
 | 245 | +            ] = dequant_args[0]  | 
 | 246 | +            self.edge_program_manager._config_methods[  # pyre-ignore  | 
 | 247 | +                get_config_method_name(self.param_prefix_name, "output", i, "zp")  | 
 | 248 | +            ] = dequant_args[1]  | 
 | 249 | +            self.edge_program_manager._config_methods[  | 
 | 250 | +                get_config_method_name(self.param_prefix_name, "output", i, "quant_min")  | 
 | 251 | +            ] = dequant_args[2]  | 
 | 252 | +            self.edge_program_manager._config_methods[  | 
 | 253 | +                get_config_method_name(self.param_prefix_name, "output", i, "quant_max")  | 
 | 254 | +            ] = dequant_args[3]  | 
 | 255 | +            self.edge_program_manager._config_methods[  | 
 | 256 | +                get_config_method_name(self.param_prefix_name, "output", i, "dtype")  | 
 | 257 | +            ] = scalar_type_enum(dequant_args[4])  | 
 | 258 | + | 
 | 259 | +        return PassResult(graph_module, True)  | 
0 commit comments