Skip to content

Commit 58554e2

Browse files
committed
Update base for Update on "[ET-VK] Add pass to remove local_scalar_dense"
## Context Scalar tensors (i.e. tensors with only 1 element) are often passed in to functions as scalars via ``` scalar_tensor[0].item() ``` This translates to the following chain in the graph ``` index_select = index_select(scalar_tensor, ...) scalar = local_scalar_dense(index_select) ``` This diff introduces a pass to remove the `local_scalar_dense` "chain" in favor of passing in the input tensor directly. Note that this replacement only occurs if the original tensor is a scalar tensor. In the Vulkan backend, these scalar tensors will be represented as symbolic integers instead of actual tensors, which is why this replacement is valid. However, it may not be a valid replacement for other backends. Differential Revision: [D63913432](https://our.internmc.facebook.com/intern/diff/D63913432/) [ghstack-poisoned]
1 parent ae48f99 commit 58554e2

File tree

3 files changed

+2
-16
lines changed

3 files changed

+2
-16
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,6 @@ def __contains__(self, op):
8484
# Convolution ops
8585
exir_ops.edge.aten.convolution.default,
8686
exir_ops.edge.et_vk.conv_with_clamp.default,
87-
# Custom ops
88-
"llama::sdpa_with_kv_cache",
8987
]
9088

9189
NO_DYNAMIC_SHAPE = [

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,19 +119,13 @@ def is_node_supported(
119119
def _is_node_supported(
120120
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
121121
) -> bool:
122-
target = node.target
123-
if node.target == torch.ops.higher_order.auto_functionalized:
124-
first_arg = node.args[0]
125-
assert isinstance(first_arg, torch._ops.OpOverload)
126-
target = first_arg.name()
127-
128122
if self.is_linear_permute(node):
129123
return True
130124

131-
if target not in VulkanSupportedOperators._ops:
125+
if node.target not in VulkanSupportedOperators._ops:
132126
return False
133127

134-
features = VulkanSupportedOperators._ops[target]
128+
features = VulkanSupportedOperators._ops[node.target]
135129

136130
if self.require_dynamic_shapes and not features.supports_dynamic_shape:
137131
return False

backends/vulkan/vulkan_preprocess.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,6 @@
3636

3737
from executorch.exir.program._program import _copy_module
3838

39-
from torch.export._remove_auto_functionalized_pass import (
40-
unsafe_remove_auto_functionalized_pass,
41-
)
42-
4339
DEFAULT_DEBUG_HANDLE = 65535
4440

4541

@@ -52,8 +48,6 @@ def preprocess( # noqa: C901
5248
program: ExportedProgram,
5349
module_compile_spec: List[CompileSpec],
5450
) -> PreprocessResult:
55-
program = unsafe_remove_auto_functionalized_pass(program)
56-
5751
passes = [
5852
RemoveCloneOpsTransform(),
5953
AddmmToLinearTransform(),

0 commit comments

Comments
 (0)