Skip to content

Commit 7104ff3

Browse files
committed
Update on "[ET-VK] Add pass to remove local_scalar_dense"
## Context Scalar tensors (i.e. tensors with only 1 element) are often passed in to functions as scalars via ``` scalar_tensor[0].item() ``` This translates to the following chain in the graph ``` index_select = index_select(scalar_tensor, ...) scalar = local_scalar_dense(index_select) ``` This diff introduces a pass to remove the `local_scalar_dense` "chain" in favor of passing in the input tensor directly. Note that this replacement only occurs if the original tensor is a scalar tensor. In the Vulkan backend, these scalar tensors will be represented as symbolic integers instead of actual tensors, which is why this replacement is valid. However, it may not a valid replacement for other backends. Differential Revision: [D63913432](https://our.internmc.facebook.com/intern/diff/D63913432/) [ghstack-poisoned]
2 parents eca11ff + 58554e2 commit 7104ff3

File tree

3 files changed

+2
-16
lines changed

3 files changed

+2
-16
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,6 @@ def __contains__(self, op):
8484
# Convolution ops
8585
exir_ops.edge.aten.convolution.default,
8686
exir_ops.edge.et_vk.conv_with_clamp.default,
87-
# Custom ops
88-
"llama::sdpa_with_kv_cache",
8987
]
9088

9189
NO_DYNAMIC_SHAPE = [

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -143,22 +143,16 @@ def is_node_supported(
143143
def _is_node_supported(
144144
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
145145
) -> bool:
146-
target = node.target
147-
if node.target == torch.ops.higher_order.auto_functionalized:
148-
first_arg = node.args[0]
149-
assert isinstance(first_arg, torch._ops.OpOverload)
150-
target = first_arg.name()
151-
152146
if self.is_linear_permute(node):
153147
return True
154148

155149
if self.is_in_local_scalar_dense_chain(node):
156150
return True
157151

158-
if target not in VulkanSupportedOperators._ops:
152+
if node.target not in VulkanSupportedOperators._ops:
159153
return False
160154

161-
features = VulkanSupportedOperators._ops[target]
155+
features = VulkanSupportedOperators._ops[node.target]
162156

163157
if self.require_dynamic_shapes and not features.supports_dynamic_shape:
164158
return False

backends/vulkan/vulkan_preprocess.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,6 @@
4040

4141
from executorch.exir.program._program import _copy_module
4242

43-
from torch.export._remove_auto_functionalized_pass import (
44-
unsafe_remove_auto_functionalized_pass,
45-
)
46-
4743
DEFAULT_DEBUG_HANDLE = 65535
4844

4945

@@ -56,8 +52,6 @@ def preprocess( # noqa: C901
5652
program: ExportedProgram,
5753
module_compile_spec: List[CompileSpec],
5854
) -> PreprocessResult:
59-
program = unsafe_remove_auto_functionalized_pass(program)
60-
6155
passes = [
6256
RemoveCloneOpsTransform(),
6357
AddmmToLinearTransform(),

0 commit comments

Comments
 (0)