| 
 | 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates.  | 
 | 2 | +# All rights reserved.  | 
 | 3 | +#  | 
 | 4 | +# This source code is licensed under the BSD-style license found in the  | 
 | 5 | +# LICENSE file in the root directory of this source tree.  | 
 | 6 | + | 
 | 7 | +from typing import Optional  | 
 | 8 | + | 
 | 9 | +import executorch.backends.vulkan.utils as utils  | 
 | 10 | + | 
 | 11 | +import torch  | 
 | 12 | + | 
 | 13 | +from executorch.backends.vulkan.patterns.pattern_registry import (  | 
 | 14 | +    PatternMatch,  | 
 | 15 | +    register_pattern_detector,  | 
 | 16 | +    register_pattern_replacement,  | 
 | 17 | +)  | 
 | 18 | + | 
 | 19 | +from executorch.exir import ExportedProgram  | 
 | 20 | +from executorch.exir.dialects._ops import ops as exir_ops  | 
 | 21 | + | 
 | 22 | + | 
 | 23 | +class QuantizedBinaryMatch(PatternMatch):  | 
 | 24 | +    def __init__(self, binary_node: torch.fx.Node) -> None:  | 
 | 25 | +        self.anchor_node = binary_node  | 
 | 26 | +        self.match_found = False  | 
 | 27 | +        self.all_nodes = [self.anchor_node]  | 
 | 28 | + | 
 | 29 | +        # Extract alpha parameter if it exists (for add operations)  | 
 | 30 | +        self.alpha = 1.0  | 
 | 31 | +        if len(binary_node.args) > 2 and binary_node.args[2] is not None:  | 
 | 32 | +            # Alpha is typically a scalar value  | 
 | 33 | +            if isinstance(binary_node.args[2], (int, float)):  | 
 | 34 | +                self.alpha = binary_node.args[2]  | 
 | 35 | + | 
 | 36 | +        # Identify input nodes - both should be dequantize nodes for static quantization  | 
 | 37 | +        if len(binary_node.args) < 2:  | 
 | 38 | +            return  | 
 | 39 | + | 
 | 40 | +        input_a_node = binary_node.args[0]  | 
 | 41 | +        assert isinstance(input_a_node, torch.fx.Node)  | 
 | 42 | +        input_b_node = binary_node.args[1]  | 
 | 43 | +        assert isinstance(input_b_node, torch.fx.Node)  | 
 | 44 | + | 
 | 45 | +        # Both arguments must be dequant nodes for static quantization  | 
 | 46 | +        if not utils.is_dequant_node(input_a_node) or not utils.is_dequant_node(  | 
 | 47 | +            input_b_node  | 
 | 48 | +        ):  | 
 | 49 | +            return  | 
 | 50 | + | 
 | 51 | +        self.dequantize_input_a_node = input_a_node  | 
 | 52 | +        self.dequantize_input_b_node = input_b_node  | 
 | 53 | + | 
 | 54 | +        # Extract quantization parameters for input A  | 
 | 55 | +        self.quantize_input_a_node = self.dequantize_input_a_node.args[0]  | 
 | 56 | +        self.input_a_scales_node = self.dequantize_input_a_node.args[1]  | 
 | 57 | +        self.input_a_zeros_node = self.dequantize_input_a_node.args[2]  | 
 | 58 | + | 
 | 59 | +        # Extract quantization parameters for input B  | 
 | 60 | +        self.quantize_input_b_node = self.dequantize_input_b_node.args[0]  | 
 | 61 | +        self.input_b_scales_node = self.dequantize_input_b_node.args[1]  | 
 | 62 | +        self.input_b_zeros_node = self.dequantize_input_b_node.args[2]  | 
 | 63 | + | 
 | 64 | +        self.all_nodes.extend(  | 
 | 65 | +            [self.dequantize_input_a_node, self.dequantize_input_b_node]  | 
 | 66 | +        )  | 
 | 67 | + | 
 | 68 | +        # Identify output node  | 
 | 69 | +        self.output_node = self.anchor_node  | 
 | 70 | + | 
 | 71 | +        # The binary operation output must have only one user; it will be either a relu node  | 
 | 72 | +        # or a quantize node.  | 
 | 73 | +        if len(self.output_node.users) != 1:  | 
 | 74 | +            return  | 
 | 75 | + | 
 | 76 | +        cur_node = list(self.output_node.users)[0]  | 
 | 77 | +        self.relu_node = None  | 
 | 78 | +        if cur_node.target == exir_ops.edge.aten.relu.default:  | 
 | 79 | +            self.relu_node = cur_node  | 
 | 80 | +            self.all_nodes.append(self.relu_node)  | 
 | 81 | +            # If there's a relu, get its user (should be the quantize node)  | 
 | 82 | +            if len(cur_node.users) != 1:  | 
 | 83 | +                return  | 
 | 84 | +            cur_node = list(cur_node.users)[0]  | 
 | 85 | + | 
 | 86 | +        if not utils.is_quant_node(cur_node):  | 
 | 87 | +            return  | 
 | 88 | + | 
 | 89 | +        self.quantize_output_node = cur_node  | 
 | 90 | +        self.output_scales_node = self.quantize_output_node.args[1]  | 
 | 91 | +        self.output_zeros_node = self.quantize_output_node.args[2]  | 
 | 92 | + | 
 | 93 | +        self.all_nodes.append(self.quantize_output_node)  | 
 | 94 | + | 
 | 95 | +        self.match_found = True  | 
 | 96 | + | 
 | 97 | + | 
# Define the binary operation anchor nodes that we support.
# Both the functional and in-place (add_) variants are matched; the
# replacement pass maps either one to the same fused Vulkan custom op.
binary_anchor_nodes = {
    exir_ops.edge.aten.add.Tensor,
    exir_ops.edge.aten.add_.Tensor,
}
 | 103 | + | 
 | 104 | + | 
 | 105 | +@register_pattern_detector("quantized_binary")  | 
 | 106 | +def find_quantized_binary_patterns(  | 
 | 107 | +    node: torch.fx.Node,  | 
 | 108 | +) -> Optional[QuantizedBinaryMatch]:  | 
 | 109 | +    if node.target not in binary_anchor_nodes:  | 
 | 110 | +        return None  | 
 | 111 | + | 
 | 112 | +    matched_pattern = QuantizedBinaryMatch(node)  | 
 | 113 | +    if matched_pattern.match_found:  | 
 | 114 | +        return matched_pattern  | 
 | 115 | + | 
 | 116 | +    return None  | 
 | 117 | + | 
 | 118 | + | 
 | 119 | +##  | 
 | 120 | +## Pattern Replacement  | 
 | 121 | +##  | 
 | 122 | + | 
 | 123 | + | 
 | 124 | +@register_pattern_replacement("quantized_binary")  | 
 | 125 | +def make_add_q8ta_q8ta_q8to_custom_op(  | 
 | 126 | +    ep: ExportedProgram,  | 
 | 127 | +    graph_module: torch.fx.GraphModule,  | 
 | 128 | +    match: QuantizedBinaryMatch,  | 
 | 129 | +):  | 
 | 130 | +    # Determine the operation type based on the anchor node  | 
 | 131 | +    op_target = None  | 
 | 132 | +    if match.anchor_node.target in {  | 
 | 133 | +        exir_ops.edge.aten.add.Tensor,  | 
 | 134 | +        exir_ops.edge.aten.add_.Tensor,  | 
 | 135 | +    }:  | 
 | 136 | +        op_target = exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default  | 
 | 137 | +    else:  | 
 | 138 | +        # For future binary operations, add more mappings here  | 
 | 139 | +        raise NotImplementedError(  | 
 | 140 | +            f"Unsupported binary operation: {match.anchor_node.target}"  | 
 | 141 | +        )  | 
 | 142 | + | 
 | 143 | +    with graph_module.graph.inserting_before(match.output_node):  | 
 | 144 | +        qbinary_node = graph_module.graph.create_node(  | 
 | 145 | +            "call_function",  | 
 | 146 | +            op_target,  | 
 | 147 | +            args=(  | 
 | 148 | +                match.quantize_input_a_node,  | 
 | 149 | +                match.quantize_input_b_node,  | 
 | 150 | +                match.input_a_scales_node,  | 
 | 151 | +                match.input_a_zeros_node,  | 
 | 152 | +                match.input_b_scales_node,  | 
 | 153 | +                match.input_b_zeros_node,  | 
 | 154 | +                match.output_scales_node,  | 
 | 155 | +                match.output_zeros_node,  | 
 | 156 | +                match.alpha,  # Alpha parameter for scaling  | 
 | 157 | +            ),  | 
 | 158 | +        )  | 
 | 159 | + | 
 | 160 | +    qbinary_node.meta["val"] = match.output_node.meta["val"]  | 
 | 161 | +    match.quantize_output_node.replace_all_uses_with(qbinary_node)  | 
0 commit comments