From 2bac617cee9f36fb481cdddfc6a9ead115d3494a Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Fri, 11 Oct 2024 09:53:12 -0700 Subject: [PATCH] [ET-VK] Include `FuseDequantLinearPass()` in `vulkan_preprocess` ## Context Include `FuseDequantLinearPass` as a part of `vulkan_preprocess`, so that fusing the quant/dequant nodes added by `VulkanQuantizer` can be done as part of the lowering process. Differential Revision: [D64249613](https://our.internmc.facebook.com/intern/diff/D64249613/) [ghstack-poisoned] --- backends/vulkan/TARGETS | 1 + backends/vulkan/partitioner/supported_ops.py | 1 + backends/vulkan/vulkan_preprocess.py | 2 ++ 3 files changed, 4 insertions(+) diff --git a/backends/vulkan/TARGETS b/backends/vulkan/TARGETS index 25a2d4846f8..dd49512c08c 100644 --- a/backends/vulkan/TARGETS +++ b/backends/vulkan/TARGETS @@ -25,6 +25,7 @@ runtime.python_library( "//executorch/backends/transforms:addmm_mm_to_linear", "//executorch/backends/transforms:fuse_batch_norm_with_conv", "//executorch/backends/transforms:fuse_conv_with_clamp", + "//executorch/backends/transforms:fuse_dequant_linear", "//executorch/backends/transforms:fuse_view_copy", "//executorch/backends/transforms:mean_to_sum_div", "//executorch/backends/transforms:remove_clone_ops", diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py index 587485a5c99..7013a068805 100644 --- a/backends/vulkan/partitioner/supported_ops.py +++ b/backends/vulkan/partitioner/supported_ops.py @@ -45,6 +45,7 @@ def __contains__(self, op): PRIM_OPS = [ operator.getitem, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, ] SUPPORTS_DYNAMIC_SHAPE = [ diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index d2378274f6d..ed566a30ccc 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -13,6 +13,7 @@ FuseBatchNormWithConvPass, ) from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass +from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform from executorch.backends.transforms.mean_to_sum_div import MeanToSumDiv from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform @@ -59,6 +60,7 @@ def preprocess( # noqa: C901 passes = [ RemoveCloneOpsTransform(), AddmmToLinearTransform(), + FuseDequantLinearPass(), FuseViewCopyTransform(), FuseBatchNormWithConvPass(program), FuseClampPass(),