diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py index 054b6947517..e9ed1439cd2 100644 --- a/backends/vulkan/partitioner/supported_ops.py +++ b/backends/vulkan/partitioner/supported_ops.py @@ -84,6 +84,8 @@ def __contains__(self, op): # Convolution exir_ops.edge.aten.convolution.default, exir_ops.edge.et_vk.conv_with_clamp.default, + # Custom ops + "llama::sdpa_with_kv_cache", ] NO_DYNAMIC_SHAPE = [ diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index b1046aa01f0..b7e74bcaf42 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -144,16 +144,22 @@ def is_node_supported( def _is_node_supported( self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node ) -> bool: + target = node.target + if node.target == torch.ops.higher_order.auto_functionalized: + first_arg = node.args[0] + assert isinstance(first_arg, torch._ops.OpOverload) + target = first_arg.name() + if self.is_linear_permute(node): return True if self.is_in_local_scalar_dense_chain(node): return True - if node.target not in VulkanSupportedOperators._ops: + if target not in VulkanSupportedOperators._ops: return False - features = VulkanSupportedOperators._ops[node.target] + features = VulkanSupportedOperators._ops[target] if self.require_dynamic_shapes and not features.supports_dynamic_shape: return False diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 562457ac54f..18b5ad872f5 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -40,6 +40,10 @@ from executorch.exir.program._program import _copy_module +from torch.export._remove_auto_functionalized_pass import ( + unsafe_remove_auto_functionalized_pass, +) + DEFAULT_DEBUG_HANDLE = 65535 @@ -52,6 +56,8 @@ def preprocess( # noqa: C901 program: ExportedProgram, module_compile_spec: List[CompileSpec], ) -> PreprocessResult: + program = unsafe_remove_auto_functionalized_pass(program) + passes = [ RemoveCloneOpsTransform(), AddmmToLinearTransform(),