From c55ef19bd542d546f7a669eaf4eab9ece7387b66 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Mon, 28 Apr 2025 12:41:16 -0700 Subject: [PATCH] [ET-VK] Introduce generic export pass for fusing Q/DQ nodes ## Context When quantizing models with the PT2E quantization flow, quantize/dequantize nodes will be inserted into the graph. However, these quantize/dequantize nodes must be fused with operators such as `aten.linear.default` to produce nodes corresponding to quantized operators (e.g. `weight_int8pack_mm`) in order for quantized operator implementations to be called at runtime. Currently, the op fusion is done by the `fuse_dequant_linear.py` pass, however, this only handles one specific fusion pattern to generate a `weight_int8pack_mm` operator. As more quantized operators are to be supported in ET-VK via the PT2E quantization flow, a more generic fusion pass is needed that can handle a variety of fusion patterns. ## Changes Introduce the `FuseQuantizedOpsTransform()` pass. I elected to introduce a new pass under the `backends/vulkan/_passes` directory, as opposed to modifying the existing pass because I anticipate the majority of the fusion patterns to be specific to ET-VK. Remove the existing `FuseDequantLinearPass()` Switch to using the `FuseQuantizedOpsTransform` pass instead of the old `FuseDequantLinear` pass. Add `test_vulkan_passes` Python test to test export passes. Some small refactors to `test_vulkan_delegate` Python test to improve code organizations. Differential Revision: [D73794042](https://our.internmc.facebook.com/intern/diff/D73794042/) [ghstack-poisoned] --- backends/transforms/fuse_dequant_linear.py | 77 ------- backends/transforms/targets.bzl | 15 -- backends/vulkan/_passes/TARGETS | 17 ++ backends/vulkan/_passes/__init__.py | 4 + backends/vulkan/_passes/fuse_quantized_ops.py | 110 +++++++++ backends/vulkan/targets.bzl | 2 +- backends/vulkan/test/TARGETS | 13 ++ backends/vulkan/test/test_vulkan_delegate.py | 213 ++++++++++++------ backends/vulkan/test/test_vulkan_passes.py | 131 +++++++++++ backends/vulkan/utils.py | 34 +++ backends/vulkan/vulkan_preprocess.py | 4 +- 11 files changed, 452 insertions(+), 168 deletions(-) delete mode 100644 backends/transforms/fuse_dequant_linear.py create mode 100644 backends/vulkan/_passes/fuse_quantized_ops.py create mode 100644 backends/vulkan/test/test_vulkan_passes.py diff --git a/backends/transforms/fuse_dequant_linear.py b/backends/transforms/fuse_dequant_linear.py deleted file mode 100644 index 235715ac74f..00000000000 --- a/backends/transforms/fuse_dequant_linear.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -import torch - -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult - - -class FuseDequantLinearPass(ExportPass): - """ - Fuses weight dequantize_per_channel nodes with linear nodes into - weight_int8pack_mm nodes, for 8-bit weight-only quantization. - - Replaces dq(weight) -> linear(activation, dq) with weight_int8pack_mm - Replaces dq(weight) -> linear(activation, dq, bias) with weight_int8pack_mm -> add - """ - - def fuse_dequant_with_linear( - self, - graph_module: torch.fx.GraphModule, - dequant_node: torch.fx.Node, - linear_node: torch.fx.Node, - ) -> None: - activations = linear_node.args[0] - bias = None - if len(linear_node.args) > 2: - bias = linear_node.args[2] - quant_weight = dequant_node.args[0] - scale = dequant_node.args[1] - - with graph_module.graph.inserting_before(linear_node): - weight_int8pack_mm_node = graph_module.graph.create_node( - "call_function", - exir_ops.edge.aten._weight_int8pack_mm.default, - (activations, quant_weight, scale), - ) - if bias: - add_node = graph_module.graph.create_node( - "call_function", - exir_ops.edge.aten.add.Tensor, - (weight_int8pack_mm_node, bias), - ) - linear_node.replace_all_uses_with(add_node) - else: - linear_node.replace_all_uses_with(weight_int8pack_mm_node) - graph_module.graph.erase_node(linear_node) - graph_module.graph.erase_node(dequant_node) - - def is_node_target( - self, node: torch.fx.Node, target: torch._ops.OperatorBase - ) -> bool: - return node.op == "call_function" and node.target == target - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - for node in graph_module.graph.nodes: - if self.is_node_target(node, exir_ops.edge.aten.linear.default): - weight_node = node.args[1] - if self.is_node_target( - weight_node, - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, - ): - # only fuse if weight tensor is int8 packed - quant_weight = weight_node.args[0] - if quant_weight.meta["val"].dtype != torch.int8: - continue - self.fuse_dequant_with_linear(graph_module, weight_node, node) - - graph_module.recompile() - graph_module = super().call(graph_module).graph_module - - return PassResult(graph_module, True) diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl index 66ff9111f52..71980195962 100644 --- a/backends/transforms/targets.bzl +++ b/backends/transforms/targets.bzl @@ -77,21 +77,6 @@ def define_common_targets(): ], ) - runtime.python_library( - name = "fuse_dequant_linear", - srcs = ["fuse_dequant_linear.py"], - visibility = [ - "//executorch/backends/...", - ], - deps = [ - ":utils", - "//caffe2:torch", - "//executorch/exir:pass_base", - "//executorch/exir:sym_util", - "//executorch/exir/dialects:lib", - ], - ) - runtime.python_library( name = "view_copy_to_squeeze_unsqueeze", srcs = ["view_copy_to_squeeze_unsqueeze.py"], diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index 5478ad0eab6..449ff84edfb 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -3,6 +3,21 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") oncall("executorch") +runtime.python_library( + name = "fuse_quantized_ops", + srcs = ["fuse_quantized_ops.py"], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/vulkan:utils_lib", + "//executorch/exir:pass_base", + "//executorch/exir:sym_util", + "//executorch/exir/dialects:lib", + ], +) + runtime.python_library( name = "insert_prepack_nodes", srcs = ["insert_prepack_nodes.py"], @@ -13,6 +28,7 @@ runtime.python_library( "//caffe2:torch", "//executorch/exir:pass_base", "//executorch/backends/vulkan:utils_lib", + "//executorch/backends/vulkan:op_registry", ], ) @@ -110,6 +126,7 @@ runtime.python_library( "//executorch/examples/...", ], deps = [ + ":fuse_quantized_ops", ":insert_prepack_nodes", ":int4_weight_only_quantizer", ":remove_asserts", diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py index 220afa6a35c..7ff93a6ee38 100644 --- a/backends/vulkan/_passes/__init__.py +++ b/backends/vulkan/_passes/__init__.py @@ -6,6 +6,9 @@ # pyre-strict +from executorch.backends.vulkan._passes.fuse_quantized_ops import ( + FuseQuantizedOpsTransform, +) from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes from executorch.backends.vulkan._passes.int4_weight_only_quantizer import ( VkInt4WeightOnlyQuantizer, @@ -26,6 +29,7 @@ from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass __all__ = [ + "FuseQuantizedOpsTransform", "insert_prepack_nodes", "VkInt4WeightOnlyQuantizer", "remove_asserts", diff --git a/backends/vulkan/_passes/fuse_quantized_ops.py b/backends/vulkan/_passes/fuse_quantized_ops.py new file mode 100644 index 00000000000..52b16d70ec9 --- /dev/null +++ b/backends/vulkan/_passes/fuse_quantized_ops.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import executorch.backends.vulkan.utils as utils +import torch + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + +############################# +## aten.weight_int8pack_mm ## +############################# + + +def matches_int8pack_mm_pattern(node: torch.fx.Node) -> bool: + if not utils.is_linear_node(node): + return False + + input_node = node.args[0] + weight_node = node.args[1] + + # Type checking + if not isinstance(weight_node, torch.fx.Node): + return False + if not isinstance(input_node, torch.fx.Node): + return False + + # The weight arg should be a dequant node dequantizing the quantized weight + # Furthermore, the op expects per channel quantization of the weight + if not utils.is_dequant_per_channel_node(weight_node): + return False + + orig_weight = weight_node.args[0] + if not isinstance(orig_weight, torch.fx.Node): + return False + + # The quantized weight data should be a int8 tensor + if orig_weight.meta["val"].dtype != torch.int8: + return False + + # The input arg should not be a dequant node + if utils.is_dequant_node(input_node): + return False + + return True + + +def fuse_into_weight_int8pack_mm_node( + graph_module: torch.fx.GraphModule, + linear_node: torch.fx.Node, +) -> None: + """ + The weight_int8pack_mm operator represents a weight only quantized linear operator. + After the PT2E quantization flow, the expected graph pattern is + + dq_weight = dequantize(weight, scales) + out = linear(activation, dq_weight, bias?) + + The goal of this function is to condense that sequence into + + out = weight_int8pack_mm(activation, dq_weight, scales) + out = out + bias + """ + activation = linear_node.args[0] + dq_weight_node = linear_node.args[1] + assert isinstance(activation, torch.fx.Node) + assert isinstance(dq_weight_node, torch.fx.Node) + + bias = None + if len(linear_node.args) > 2: + bias = linear_node.args[2] + assert isinstance(bias, torch.fx.Node) + + orig_weight = dq_weight_node.args[0] + scale = dq_weight_node.args[1] + + with graph_module.graph.inserting_before(linear_node): + weight_int8pack_mm_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.aten._weight_int8pack_mm.default, + (activation, orig_weight, scale), + ) + if bias: + add_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.aten.add.Tensor, + (weight_int8pack_mm_node, bias), + ) + linear_node.replace_all_uses_with(add_node) + else: + linear_node.replace_all_uses_with(weight_int8pack_mm_node) + graph_module.graph.erase_node(linear_node) + graph_module.graph.erase_node(dq_weight_node) + + +class FuseQuantizedOpsTransform(ExportPass): + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + if matches_int8pack_mm_pattern(node): + fuse_into_weight_int8pack_mm_node(graph_module, node) + + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, True) diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index aafc87ad2c3..665fde103fc 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -280,6 +280,7 @@ def define_common_targets(is_fbcode = False): deps = [ "//caffe2:torch", "//executorch/exir:tensor", + "//executorch/exir/backend/canonical_partitioners:config_partitioner_lib", "//executorch/backends/vulkan/serialization:lib", ] ) @@ -332,7 +333,6 @@ def define_common_targets(is_fbcode = False): "//executorch/backends/transforms:addmm_mm_to_linear", "//executorch/backends/transforms:fuse_batch_norm_with_conv", "//executorch/backends/transforms:fuse_conv_with_clamp", - "//executorch/backends/transforms:fuse_dequant_linear", "//executorch/backends/transforms:fuse_view_copy", "//executorch/backends/transforms:remove_clone_ops", "//executorch/backends/transforms:view_copy_to_squeeze_unsqueeze", diff --git a/backends/vulkan/test/TARGETS b/backends/vulkan/test/TARGETS index 5ac87892762..8f07040d586 100644 --- a/backends/vulkan/test/TARGETS +++ b/backends/vulkan/test/TARGETS @@ -24,6 +24,19 @@ python_unittest( ], ) +python_unittest( + name = "test_vulkan_passes", + srcs = [ + "test_vulkan_passes.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/vulkan/_passes:vulkan_passes", + "//executorch/backends/vulkan/quantizer:vulkan_quantizer", + "//executorch/backends/vulkan:vulkan_preprocess", + ] +) + python_unittest( name = "test_vulkan_delegate_header", srcs = [ diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 5fba5ed54cf..b57710974e8 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -15,10 +15,19 @@ from executorch.backends.transforms.convert_dtype_pass import I64toI32 from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner + from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend -from executorch.exir import EdgeCompileConfig -from torch.export import Dim, export, ExportedProgram +from executorch.exir import ( + EdgeCompileConfig, + EdgeProgramManager, + ExecutorchProgramManager, +) + +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e + +from torch.ao.quantization.quantizer import Quantizer +from torch.export import Dim, export, export_for_training, ExportedProgram ctypes.CDLL("libvulkan.so.1") @@ -30,11 +39,66 @@ from executorch.extension.pytree import tree_flatten -class TestBackends(unittest.TestCase): - _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig( +def lower_module( + model: torch.nn.Module, sample_inputs: Tuple[torch.Tensor], dynamic_shapes=None +) -> EdgeProgramManager: + compile_options = {} + edge_compile_config = EdgeCompileConfig( + _skip_dim_order=False, # TODO(T182928844): Delegate dim order op to backend. + ) + + program: ExportedProgram = export( + model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True + ) + + edge_program = to_edge_transform_and_lower( + program, + compile_config=edge_compile_config, + transform_passes=[ + I64toI32(edge_compile_config._skip_dim_order), + ], + partitioner=[VulkanPartitioner(compile_options)], + ) + + return edge_program + + +def quantize_and_lower_module( + model: torch.nn.Module, + sample_inputs: Tuple[torch.Tensor], + quantizer: Quantizer, + dynamic_shapes=None, +) -> EdgeProgramManager: + compile_options = {} + edge_compile_config = EdgeCompileConfig( _skip_dim_order=False, # TODO(T182928844): Delegate dim order op to backend. ) + program = export_for_training( + model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True + ).module() + + program = prepare_pt2e(program, quantizer) # pyre-ignore + # Calibrate + program(*sample_inputs) + + program = convert_pt2e(program) + + program = export(program, sample_inputs, dynamic_shapes=dynamic_shapes) + + edge_program = to_edge_transform_and_lower( + program, + compile_config=edge_compile_config, + transform_passes=[ + I64toI32(edge_compile_config._skip_dim_order), + ], + partitioner=[VulkanPartitioner(compile_options)], + ) + + return edge_program + + +class TestVulkanBackend(unittest.TestCase): def assert_outputs_equal( self, model_output, @@ -88,6 +152,59 @@ def assert_outputs_equal( ) ) + def check_no_delegation(self, et_program: ExecutorchProgramManager): + self.assertEqual( + len(et_program.executorch_program.execution_plan[0].delegates), + 0, + ) + return + + def check_vk_delegation(self, et_program: ExecutorchProgramManager): + self.assertEqual( + et_program.executorch_program.execution_plan[0].delegates[0].id, + VulkanBackend.__name__, + ) + + def run_delegated_model_and_check_output( + self, + et_program: ExecutorchProgramManager, + model: torch.nn.Module, + sample_inputs: Tuple[torch.Tensor], + atol=1e-03, + rtol=1e-01, + test_inputs=None, + first_output_only=False, + ): + executorch_module = _load_for_executorch_from_buffer(et_program.buffer) + inputs_flattened, _ = tree_flatten(sample_inputs) + + model_output = executorch_module.run_method("forward", tuple(inputs_flattened)) + ref_output = model(*sample_inputs) + + self.assert_outputs_equal( + model_output, + ref_output, + atol=atol, + rtol=rtol, + first_output_only=first_output_only, + ) + + if test_inputs is not None: + for test_input in test_inputs: + test_inputs_flattened, _ = tree_flatten(test_input) + model_output = executorch_module.run_method( + "forward", tuple(test_inputs_flattened) + ) + ref_output = model(*test_input) + + self.assert_outputs_equal( + model_output, + ref_output, + atol=atol, + rtol=rtol, + first_output_only=first_output_only, + ) + def lower_module_and_test_output( self, model: torch.nn.Module, @@ -105,80 +222,29 @@ def lower_module_and_test_output( outputs with the outputs of the eager module. """ - def run_test(): - compile_options = {} + # Validate that the model can execute in eager mode + model.eval() + model(*sample_inputs) - # At least model should run in eager mode. - model.eval() - model(*sample_inputs) + edge_program = lower_module(model, sample_inputs, dynamic_shapes=dynamic_shapes) - program: ExportedProgram = export( - model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True - ) + et_program = edge_program.to_executorch() - edge_program = to_edge_transform_and_lower( - program, - compile_config=self._edge_compile_config, - transform_passes=[ - I64toI32(self._edge_compile_config._skip_dim_order), - ], - partitioner=[VulkanPartitioner(compile_options)], - ) - executorch_program = edge_program.to_executorch() - - if expect_no_delegates: - self.assertEqual( - len( - executorch_program.executorch_program.execution_plan[ - 0 - ].delegates - ), - 0, - ) - return - else: - self.assertEqual( - executorch_program.executorch_program.execution_plan[0] - .delegates[0] - .id, - VulkanBackend.__name__, - ) - - executorch_module = _load_for_executorch_from_buffer( - executorch_program.buffer - ) - inputs_flattened, _ = tree_flatten(sample_inputs) + if expect_no_delegates: + self.check_no_delegation(et_program) + return - model_output = executorch_module.run_method( - "forward", tuple(inputs_flattened) - ) - ref_output = model(*sample_inputs) - - self.assert_outputs_equal( - model_output, - ref_output, - atol=atol, - rtol=rtol, - first_output_only=first_output_only, - ) - - if test_inputs is not None: - for test_input in test_inputs: - test_inputs_flattened, _ = tree_flatten(test_input) - model_output = executorch_module.run_method( - "forward", tuple(test_inputs_flattened) - ) - ref_output = model(*test_input) + self.check_vk_delegation(et_program) - self.assert_outputs_equal( - model_output, - ref_output, - atol=atol, - rtol=rtol, - first_output_only=first_output_only, - ) - - run_test() + self.run_delegated_model_and_check_output( + et_program, + model, + sample_inputs, + atol, + rtol, + test_inputs=test_inputs, + first_output_only=first_output_only, + ) def test_vulkan_backend_add(self): # This test is the simplest test by manually lowering some submodules, we can use paritioner @@ -942,6 +1008,7 @@ def forward(self, x): sample_inputs, ) + @unittest.skip("layer norm compute shader not working with swiftshader") def test_vulkan_backend_native_layer_norm(self): class NativeLayerNormModule(torch.nn.Module): def __init__(self): diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py new file mode 100644 index 00000000000..478fef6f83f --- /dev/null +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -0,0 +1,131 @@ +import unittest +from typing import Optional, Tuple + +import torch + +from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform +from executorch.backends.vulkan._passes import FuseQuantizedOpsTransform + +from executorch.backends.vulkan.quantizer.vulkan_quantizer import ( + get_weight_quantization_config, + VulkanQuantizer, +) + +from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge + +from executorch.exir.backend.canonical_partitioners.config_partitioner import ( + format_target_name, +) + +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.ao.quantization.quantizer import Quantizer + +################### +## Common Models ## +################### + + +class SingleLinearModule(torch.nn.Module): + def __init__(self, K=256, N=128): + super().__init__() + self.K = K + self.N = N + self.linear = torch.nn.Linear(K, N, bias=False) + + def forward(self, x): + return self.linear(x) + + def get_sample_inputs(self): + sample_inputs = (torch.rand(size=(32, self.K), dtype=torch.float32),) + return sample_inputs + + +########### +## Tests ## +########### + + +def quantize_and_lower_module( + model: torch.nn.Module, + sample_inputs: Tuple[torch.Tensor], + quantizer: Quantizer, + dynamic_shapes=None, +) -> EdgeProgramManager: + edge_compile_config = EdgeCompileConfig( + _skip_dim_order=False, # TODO(T182928844): Delegate dim order op to backend. + _check_ir_validity=False, + ) + + program = torch.export.export_for_training( + model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True + ).module() + + program = prepare_pt2e(program, quantizer) # pyre-ignore + # Calibrate + program(*sample_inputs) + + program = convert_pt2e(program) + + program = torch.export.export(program, sample_inputs, dynamic_shapes=dynamic_shapes) + + edge_program = to_edge( + program, + compile_config=edge_compile_config, + ) + + return edge_program + + +def get_weight_int8_symmetric_vk_qconfig(): + qconfig = get_weight_quantization_config( + is_per_channel=True, + weight_qmin=-128, + weight_qmax=127, + ) + return qconfig + + +def get_target_canonical_name(node: torch.fx.Node) -> Optional[str]: + if node.op != "call_function": + return None + node_name = format_target_name(node.target.__name__) # pyre-ignore + return node_name + + +def op_node_count(graph_module: torch.fx.GraphModule, canonical_op_name: str) -> int: + count = 0 + for node in graph_module.graph.nodes: + canonical_name = get_target_canonical_name(node) + if canonical_name is not None and canonical_name == canonical_op_name: + count += 1 + return count + + +class TestVulkanPasses(unittest.TestCase): + + def test_fuse_int8pack_mm(self): + K = 256 + N = 256 + model = SingleLinearModule(K, N) + sample_inputs = model.get_sample_inputs() + + quantizer = VulkanQuantizer() + quantizer.set_global(get_weight_int8_symmetric_vk_qconfig()) + + edge_program = quantize_and_lower_module( + model, + sample_inputs, + quantizer, + ) + + edge_program.transform( + [ + AddmmToLinearTransform(), + FuseQuantizedOpsTransform(), + ] + ) + + gm = edge_program._edge_programs["forward"].graph_module + + self.assertEqual(op_node_count(gm, "_weight_int8pack_mm.default"), 1) + self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0) diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index fa032cd7b4f..c1b1e6d9db1 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -14,6 +14,10 @@ VkStorageType, ) +from executorch.exir.backend.canonical_partitioners.config_partitioner import ( + format_target_name, +) + from executorch.exir.tensor import TensorSpec from torch._export.utils import is_buffer, is_param @@ -22,11 +26,41 @@ from torch.export import ExportedProgram +_DQ_OPS = { + "dequantize_per_tensor.tensor", + "dequantize_per_tensor.default", + "dequantize_per_channel.default", + "dequantize_per_channel_group.default", + "dequantize_per_token.default", + "dequantize_affine.default", +} + ## ## Node type determination ## +def is_dequant_node(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + node_name = format_target_name(node.target.__name__) # pyre-ignore + return node_name in _DQ_OPS + + +def is_dequant_per_channel_node(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + node_name = format_target_name(node.target.__name__) # pyre-ignore + return node_name == "dequantize_per_channel.default" + + +def is_linear_node(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + node_name = format_target_name(node.target.__name__) # pyre-ignore + return node_name == "linear.default" + + def is_get_attr_node(node: torch.fx.Node) -> bool: return isinstance(node, torch.fx.Node) and node.op == "get_attr" diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 188311e5f2c..c9b67230ca5 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -17,12 +17,12 @@ FuseBatchNormWithConvPass, ) from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass -from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform from executorch.backends.transforms.view_copy_to_squeeze_unsqueeze import ( ViewCopyToSqueezeUnsqueezePass, ) from executorch.backends.vulkan._passes import ( + FuseQuantizedOpsTransform, insert_prepack_nodes, RemoveLocalScalarDenseOpsTransform, RemoveRedundantOpsTransform, @@ -152,7 +152,7 @@ def preprocess( # noqa: C901 [ RemoveRedundantOpsTransform(), AddmmToLinearTransform(), - FuseDequantLinearPass(), + FuseQuantizedOpsTransform(), SqueezeUnsqueezeInputs(), FuseViewCopyTransform(), ViewCopyToSqueezeUnsqueezePass(),