From 5777ad3a531cc6fb916924041b2d95a7b8e286f0 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Fri, 4 Oct 2024 12:53:46 -0700 Subject: [PATCH] [ET-VK][ez] Support exporting of custom operator calls via `higher_order_auto_functionalized`, checkpoint As title. This diff adds the ability to partition custom op calls to the Vulkan delegate. Differential Revision: [D63913434](https://our.internmc.facebook.com/intern/diff/D63913434/) [ghstack-poisoned] --- backends/vulkan/partitioner/supported_ops.py | 7 +++++++ backends/vulkan/partitioner/vulkan_partitioner.py | 10 ++++++++-- backends/vulkan/vulkan_preprocess.py | 6 ++++++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py index ca7ce72caed..1fcc4725dc9 100644 --- a/backends/vulkan/partitioner/supported_ops.py +++ b/backends/vulkan/partitioner/supported_ops.py @@ -169,9 +169,16 @@ def register_dynamic_shape_ops(ops: OpList): ops[op].supports_dynamic_shape = True +def register_custom_ops(ops: OpList): + for op in CUSTOM_OPS: + ops[op].supports_dynamic_shape = True + ops[op].supports_texture = True + + def enumerate_supported_ops(): ops = OpList() register_prim_ops(ops) register_no_dynamic_shape_ops(ops) register_dynamic_shape_ops(ops) + register_custom_ops(ops) return ops diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 103297bc758..bcdb28475d5 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -119,13 +119,19 @@ def is_node_supported( def _is_node_supported( self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node ) -> bool: + target = node.target + if node.target == torch.ops.higher_order.auto_functionalized: + first_arg = node.args[0] + assert isinstance(first_arg, torch._ops.OpOverload) + target = first_arg.name() + if self.is_linear_permute(node): return True - if node.target not in VulkanSupportedOperators._ops: + if target not in VulkanSupportedOperators._ops: return False - features = VulkanSupportedOperators._ops[node.target] + features = VulkanSupportedOperators._ops[target] if self.require_dynamic_shapes and not features.supports_dynamic_shape: return False diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 7e85c25faee..d2e4acc7d51 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -36,6 +36,10 @@ from executorch.exir.program._program import _copy_module +from torch.export._remove_auto_functionalized_pass import ( + unsafe_remove_auto_functionalized_pass, +) + DEFAULT_DEBUG_HANDLE = 65535 @@ -48,6 +52,8 @@ def preprocess( # noqa: C901 program: ExportedProgram, module_compile_spec: List[CompileSpec], ) -> PreprocessResult: + program = unsafe_remove_auto_functionalized_pass(program) + passes = [ RemoveCloneOpsTransform(), AddmmToLinearTransform(),