     QuantizationConfig,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
+from torch.ao.quantization.observer import MinMaxObserver
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
     SharedQuantizationSpec,
 )
 from torch.fx import Node


+def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:
+    """
+    Annotate matmul ops for 16-bit activation / 8-bit weight (16a8w) quantization.
+    """
+
+    def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
+        input_qspec_map = {}
+        # First matmul input: 16-bit activation spec.
+        input_act = node.args[0]
+        input_spec = quantization_config.input_activation
+        input_qspec_map[input_act] = input_spec
+
+        # Second matmul input: 8-bit weight spec.
+        input_act1 = node.args[1]
+        input_spec1 = quantization_config.weight
+        input_qspec_map[input_act1] = input_spec1
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    def annotate_cat(node: Node, quantization_config: QuantizationConfig):
+        input_nodes = node.args[0]
+
+        first_input_node = input_nodes[0]
+        input_qspec_map = {}
+        input_qspec_map[first_input_node] = quantization_config.input_activation
+        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
+            (first_input_node, node)
+        )
+
+        # The remaining cat inputs and the cat output share quantization
+        # parameters with the first input.
+        for input_node in input_nodes[1:]:
+            if input_node not in input_qspec_map:
+                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=share_qparams_with_input_act0_qspec,
+            _annotated=True,
+        )
+
+    def annotate_single_in_single_out(
+        node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_qspec_map[input_act] = quantization_config.input_activation
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    def annotate_matmul_input1(node: Node):
+        # Walk up the chain feeding the second matmul input, annotating
+        # permute/transpose/cat nodes with an 8-bit symmetric activation
+        # config until a node that is not a call_function is reached.
+        quantization_config_8a8w = get_default_8bit_qnn_ptq_config(
+            act_symmetric=True, act_observer=MinMaxObserver
+        )
+        while isinstance(node, Node) and node.op == "call_function":
+            if node.target in [
+                torch.ops.aten.permute.default,
+                torch.ops.aten.transpose.int,
+            ]:
+                annotate_single_in_single_out(node, quantization_config_8a8w)
+                node = node.args[0]
+            elif node.target == torch.ops.aten.cat.default:
+                annotate_cat(node, quantization_config_8a8w)
+                node = node.args[0][0]
+            else:
+                node = node.args[0]
+
+    quantization_config_16a8w = get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver)
+
+    # Annotate every matmul with the 16a8w config, then annotate the ops
+    # feeding its second input with the 8a8w config above.
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
+            annotate_matmul(node, quantization_config_16a8w)
+            annotate_matmul_input1(node.args[1])
+
+
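For reference, the new annotator is not meant to be called directly on an arbitrary model; like the existing `custom_annotate_llama_matmul_16a8w` below, it is registered with the quantizer before PT2E preparation. The sketch below is illustrative only: it assumes `QnnQuantizer.add_custom_quant_annotations` accepts a tuple of annotation callables (as in the Qualcomm llama example), and `graph_module` / `example_inputs` are placeholders for the exported model and its sample inputs.

```python
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

quantizer = QnnQuantizer()
# Register the matmul-specific 16a8w annotation on top of the default QNN annotations.
quantizer.add_custom_quant_annotations((annotate_matmul_16a8w,))

# `graph_module` is assumed to be the pre-autograd-exported model used elsewhere
# in the Qualcomm examples; `example_inputs` is a matching tuple of sample inputs.
prepared = prepare_pt2e(graph_module, quantizer)
prepared(*example_inputs)  # calibration pass
quantized = convert_pt2e(prepared)
```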
 def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
     """
     This function is specific for llama matmul op 16a8w.