# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import executorch.backends.vulkan.utils as utils
import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

#############################
## aten.weight_int8pack_mm ##
#############################


def matches_int8pack_mm_pattern(node: torch.fx.Node) -> bool:
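    """
    Check whether `node` is a linear node whose weight argument is produced by a
    per channel dequantize of an int8 tensor and whose activation is not itself
    dequantized. This is the weight only quantized linear pattern produced by
    the PT2E quantization flow that can be fused into aten._weight_int8pack_mm.
    """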
    if not utils.is_linear_node(node):
        return False

    input_node = node.args[0]
    weight_node = node.args[1]

    # Type checking
    if not isinstance(weight_node, torch.fx.Node):
        return False
    if not isinstance(input_node, torch.fx.Node):
        return False

    # The weight arg should be a dequant node dequantizing the quantized weight
    # Furthermore, the op expects per channel quantization of the weight
    if not utils.is_dequant_per_channel_node(weight_node):
        return False

    orig_weight = weight_node.args[0]
    if not isinstance(orig_weight, torch.fx.Node):
        return False

    # The quantized weight data should be an int8 tensor
    if orig_weight.meta["val"].dtype != torch.int8:
        return False

    # The input arg should not be a dequant node
    if utils.is_dequant_node(input_node):
        return False

    return True


def fuse_into_weight_int8pack_mm_node(
    graph_module: torch.fx.GraphModule,
    linear_node: torch.fx.Node,
) -> None:
    """
    The weight_int8pack_mm operator represents a weight only quantized linear
    operator. After the PT2E quantization flow, the expected graph pattern is

        dq_weight = dequantize(weight, scales)
        out = linear(activation, dq_weight, bias?)

    The goal of this function is to condense that sequence into

        out = weight_int8pack_mm(activation, weight, scales)
        out = out + bias  # only if a bias was present
    """
    activation = linear_node.args[0]
    dq_weight_node = linear_node.args[1]
    assert isinstance(activation, torch.fx.Node)
    assert isinstance(dq_weight_node, torch.fx.Node)

    bias = None
    if len(linear_node.args) > 2:
        bias = linear_node.args[2]
        assert isinstance(bias, torch.fx.Node)

    orig_weight = dq_weight_node.args[0]
    scale = dq_weight_node.args[1]

    with graph_module.graph.inserting_before(linear_node):
        weight_int8pack_mm_node = graph_module.graph.create_node(
            "call_function",
            exir_ops.edge.aten._weight_int8pack_mm.default,
            (activation, orig_weight, scale),
        )
        if bias is not None:
            add_node = graph_module.graph.create_node(
                "call_function",
                exir_ops.edge.aten.add.Tensor,
                (weight_int8pack_mm_node, bias),
            )
            linear_node.replace_all_uses_with(add_node)
        else:
            linear_node.replace_all_uses_with(weight_int8pack_mm_node)
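        # Erase the linear node before the dequant node so that the dequant node
        # has no remaining users by the time it is removed from the graph.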
        graph_module.graph.erase_node(linear_node)
        graph_module.graph.erase_node(dq_weight_node)


class FuseQuantizedOpsTransform(ExportPass):
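    """
    Fuse quantized op patterns (currently the weight only quantized linear
    pattern) into their corresponding fused operators.
    """
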
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            if matches_int8pack_mm_pattern(node):
                fuse_into_weight_int8pack_mm_node(graph_module, node)

        graph_module.recompile()
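        # Re-run the parent ExportPass over the mutated graph to retrace it and
        # regenerate node metadata.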
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
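

# Example usage (a minimal sketch; assumes the standard ExportPass calling
# convention, where invoking the pass instance on a GraphModule returns a
# PassResult, and `ep` is a hypothetical torch.export.ExportedProgram):
#
#     result = FuseQuantizedOpsTransform()(ep.graph_module)
#     fused_graph_module = result.graph_module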