2 changes: 2 additions & 0 deletions backends/vulkan/partitioner/supported_ops.py
@@ -84,6 +84,8 @@ def __contains__(self, op):
# Convolution
exir_ops.edge.aten.convolution.default,
exir_ops.edge.et_vk.conv_with_clamp.default,
+# Custom ops
+"llama::sdpa_with_kv_cache",
]

NO_DYNAMIC_SHAPE = [
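
Note that the new entry is a schema-name string rather than an exir_ops OpOverload: mutating custom ops reach the partitioner wrapped in auto_functionalized and are looked up by the wrapped op's name (see the partitioner hunk below). A minimal sketch of what that string corresponds to, under the assumption that the llama custom-op library has already been loaded so the op is registered:

import torch

# Hypothetical check, assuming torch.ops.llama.sdpa_with_kv_cache is registered.
op = torch.ops.llama.sdpa_with_kv_cache.default
# For a default overload, OpOverload.name() is just the "namespace::opname"
# schema string, which is the form used for the table entry above.
assert op.name() == "llama::sdpa_with_kv_cache"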
10 changes: 8 additions & 2 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -144,16 +144,22 @@ def is_node_supported(
def _is_node_supported(
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> bool:
+target = node.target
+if node.target == torch.ops.higher_order.auto_functionalized:
+first_arg = node.args[0]
+assert isinstance(first_arg, torch._ops.OpOverload)
+target = first_arg.name()
+
if self.is_linear_permute(node):
return True

if self.is_in_local_scalar_dense_chain(node):
return True

-if node.target not in VulkanSupportedOperators._ops:
+if target not in VulkanSupportedOperators._ops:
return False

-features = VulkanSupportedOperators._ops[node.target]
+features = VulkanSupportedOperators._ops[target]

if self.require_dynamic_shapes and not features.supports_dynamic_shape:
return False
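
For reference, a standalone sketch of the unwrapping added above: exported programs wrap mutating custom ops in the torch.ops.higher_order.auto_functionalized higher-order op, so node.target alone never matches the supported-op table; the wrapped OpOverload is passed as the first argument and its schema name is what gets looked up. The helper name below is hypothetical, not part of the diff.

import torch

def resolve_lookup_target(node: torch.fx.Node):
    # Hypothetical helper mirroring the logic in _is_node_supported.
    target = node.target
    if node.target == torch.ops.higher_order.auto_functionalized:
        # The functionalized wrapper carries the original OpOverload as its
        # first positional argument; its schema name (e.g.
        # "llama::sdpa_with_kv_cache") is the key stored in supported_ops.py.
        first_arg = node.args[0]
        assert isinstance(first_arg, torch._ops.OpOverload)
        target = first_arg.name()
    return target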
6 changes: 6 additions & 0 deletions backends/vulkan/vulkan_preprocess.py
@@ -40,6 +40,10 @@

from executorch.exir.program._program import _copy_module

+from torch.export._remove_auto_functionalized_pass import (
+unsafe_remove_auto_functionalized_pass,
+)
+
DEFAULT_DEBUG_HANDLE = 65535


@@ -52,6 +56,8 @@ def preprocess( # noqa: C901
program: ExportedProgram,
module_compile_spec: List[CompileSpec],
) -> PreprocessResult:
+program = unsafe_remove_auto_functionalized_pass(program)
+
passes = [
RemoveCloneOpsTransform(),
AddmmToLinearTransform(),
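
A minimal sketch of why the new pass runs before everything else in preprocess(): unsafe_remove_auto_functionalized_pass rewrites auto_functionalized nodes back into direct (mutating) calls to the wrapped op, so the Vulkan graph transforms and serializer that follow see ops like llama::sdpa_with_kv_cache as ordinary call_function nodes. The wrapper function below is a hypothetical illustration, not part of the diff.

from torch.export import ExportedProgram
from torch.export._remove_auto_functionalized_pass import (
    unsafe_remove_auto_functionalized_pass,
)

def strip_auto_functionalized(ep: ExportedProgram) -> ExportedProgram:
    # Replaces torch.ops.higher_order.auto_functionalized nodes with in-place
    # calls to the underlying op; "unsafe" because it reintroduces mutation
    # into the graph, which is acceptable right before backend lowering.
    return unsafe_remove_auto_functionalized_pass(ep)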