diff --git a/backends/vulkan/TARGETS b/backends/vulkan/TARGETS index 25a2d4846f8..dd49512c08c 100644 --- a/backends/vulkan/TARGETS +++ b/backends/vulkan/TARGETS @@ -25,6 +25,7 @@ runtime.python_library( "//executorch/backends/transforms:addmm_mm_to_linear", "//executorch/backends/transforms:fuse_batch_norm_with_conv", "//executorch/backends/transforms:fuse_conv_with_clamp", + "//executorch/backends/transforms:fuse_dequant_linear", "//executorch/backends/transforms:fuse_view_copy", "//executorch/backends/transforms:mean_to_sum_div", "//executorch/backends/transforms:remove_clone_ops", diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py index 587485a5c99..4d0858953bd 100644 --- a/backends/vulkan/partitioner/supported_ops.py +++ b/backends/vulkan/partitioner/supported_ops.py @@ -45,6 +45,13 @@ def __contains__(self, op): PRIM_OPS = [ operator.getitem, + # Quantization related ops will be fused via graph passes + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, ] SUPPORTS_DYNAMIC_SHAPE = [ diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index d2378274f6d..ed566a30ccc 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -13,6 +13,7 @@ FuseBatchNormWithConvPass, ) from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass +from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform from executorch.backends.transforms.mean_to_sum_div import MeanToSumDiv from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform @@ -59,6 +60,7 @@ def preprocess( # noqa: C901 passes = [ RemoveCloneOpsTransform(), AddmmToLinearTransform(), + FuseDequantLinearPass(), FuseViewCopyTransform(), FuseBatchNormWithConvPass(program), FuseClampPass(),