Commit 94967e7

Update base for Update on "[ET-VK][ez] Support exporting of custom operator calls via higher_order_auto_functionalized, checkpoint"

As title. This diff adds the ability to partition custom op calls to the Vulkan delegate.

Differential Revision: D63913434 (https://our.internmc.facebook.com/intern/diff/D63913434/)

[ghstack-poisoned]
1 parent 20a157f commit 94967e7
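
For context, a minimal sketch of the export flow this commit targets, assuming the standard ExecuTorch to_edge/to_backend APIs; the module and inputs are hypothetical:

    import torch

    from executorch.backends.vulkan.partitioner.vulkan_partitioner import (
        VulkanPartitioner,
    )
    from executorch.exir import to_edge


    class ScaleByScalarTensor(torch.nn.Module):
        # Hypothetical module: `s[0].item()` lowers to a chain of
        # select_copy.int followed by _local_scalar_dense in the edge graph.
        def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
            return x * s[0].item()


    model = ScaleByScalarTensor()
    sample_inputs = (torch.randn(4), torch.tensor([2.0]))

    # Export to the Edge dialect, then let the Vulkan partitioner claim the
    # nodes it supports (after this commit, that includes the
    # select_copy -> _local_scalar_dense chain).
    edge = to_edge(torch.export.export(model, sample_inputs))
    edge = edge.to_backend(VulkanPartitioner())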

File tree

6 files changed: +115 -59 lines changed

  backends/vulkan/TARGETS
  backends/vulkan/partitioner/supported_ops.py
  backends/vulkan/partitioner/vulkan_partitioner.py
  backends/vulkan/passes/TARGETS
  backends/vulkan/passes/remove_local_scalar_dense_ops.py
  backends/vulkan/vulkan_preprocess.py

backends/vulkan/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ runtime.python_library(
         "//executorch/backends/transforms:fuse_view_copy",
         "//executorch/backends/transforms:mean_to_sum_div",
         "//executorch/backends/transforms:remove_clone_ops",
+        "//executorch/backends/vulkan/passes:remove_local_scalar_dense",
         "//executorch/exir:graph_module",
         "//executorch/exir/_serialize:_bindings",
         "//executorch/exir/_serialize:lib",

backends/vulkan/partitioner/supported_ops.py

Lines changed: 25 additions & 59 deletions
@@ -47,17 +47,16 @@ def __contains__(self, op):
     operator.getitem,
 ]

-BINARY_OPS = [
+SUPPORTS_DYNAMIC_SHAPE = [
+    # Binary broadcasting operators
     exir_ops.edge.aten.add.Tensor,
     exir_ops.edge.aten.sub.Tensor,
     exir_ops.edge.aten.minimum.default,
     exir_ops.edge.aten.mul.Tensor,
     exir_ops.edge.aten.div.Tensor,
     exir_ops.edge.aten.div.Tensor_mode,
     exir_ops.edge.aten.pow.Tensor_Tensor,
-]
-
-UNARY_OPS = [
+    # Unary elementwise operators
     exir_ops.edge.aten.abs.default,
     exir_ops.edge.aten.clamp.default,
     exir_ops.edge.aten.cos.default,
@@ -71,60 +70,46 @@ def __contains__(self, op):
     exir_ops.edge.aten.sin.default,
     exir_ops.edge.aten.sqrt.default,
     exir_ops.edge.aten.tanh.default,
-]
-
-MATMUL_OPS = [
+    # Matrix Multiplication Operators
     exir_ops.edge.aten.bmm.default,
     exir_ops.edge.aten.mm.default,
     exir_ops.edge.aten.addmm.default,
     exir_ops.edge.aten.linear.default,
-]
-
-POOLING_OPS = [
+    # Reduction operators
+    exir_ops.edge.aten._log_softmax.default,
+    exir_ops.edge.aten._softmax.default,
+    # 2D Pooling ops
     exir_ops.edge.aten.avg_pool2d.default,
     exir_ops.edge.aten.max_pool2d_with_indices.default,
-]
-
-CONVOLUTION_OPS = [
+    # Convolution ops
     exir_ops.edge.aten.convolution.default,
     exir_ops.edge.et_vk.conv_with_clamp.default,
 ]

-REDUCTION_OPS = [
+NO_DYNAMIC_SHAPE = [
+    # Reduction operators
     exir_ops.edge.aten.mean.dim,
     exir_ops.edge.aten.sum.dim_IntList,
-    exir_ops.edge.aten._log_softmax.default,
-    exir_ops.edge.aten._softmax.default,
-]
-
-NORMALIZATION_OPS = [
+    # Normalization operators
     exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
     exir_ops.edge.aten.native_layer_norm.default,
-]
-
-SHAPE_MANIPULATION_OPS = [
+    # Shape Manipulation operators
     exir_ops.edge.aten.squeeze_copy.dims,
     exir_ops.edge.aten.unsqueeze_copy.default,
     exir_ops.edge.aten.view_copy.default,
     exir_ops.edge.aten.permute_copy.default,
     exir_ops.edge.aten.t_copy.default,
-]
-
-INDEXING_OPS = [
+    # Indexing and lookup operators
     exir_ops.edge.aten.embedding.default,
     exir_ops.edge.aten.index_select.default,
     exir_ops.edge.aten.select_copy.int,
     exir_ops.edge.aten.slice_copy.Tensor,
-]
-
-ORCHESTRATION_OPS = [
+    # Tensor combination operators
     exir_ops.edge.aten.cat.default,
     exir_ops.edge.aten.split_with_sizes_copy.default,
     exir_ops.edge.aten.split.Tensor,
     exir_ops.edge.aten.repeat.default,
-]
-
-CREATION_OPS = [
+    # Tensor creation operators
     exir_ops.edge.aten.arange.start_step,
     exir_ops.edge.aten.clone.default,
     exir_ops.edge.aten.constant_pad_nd.default,
@@ -139,39 +124,20 @@ def __contains__(self, op):
 ]


-def register_prim_ops(ops: OpList):
-    for op in PRIM_OPS:
-        ops[op].supports_texture = True
-        ops[op].supports_buffer = True
-        ops[op].supports_dynamic_shape = True
+def enumerate_supported_ops():
+    ops = OpList()

+    # Register in order of least to most capabilities

-def register_no_dynamic_shape_ops(ops: OpList):
-    for op in [
-        *REDUCTION_OPS,
-        *NORMALIZATION_OPS,
-        *SHAPE_MANIPULATION_OPS,
-        *INDEXING_OPS,
-        *ORCHESTRATION_OPS,
-        *CREATION_OPS,
-    ]:
+    for op in NO_DYNAMIC_SHAPE:
         ops[op].supports_dynamic_shape = False

-
-def register_dynamic_shape_ops(ops: OpList):
-    for op in [
-        *BINARY_OPS,
-        *UNARY_OPS,
-        *MATMUL_OPS,
-        *POOLING_OPS,
-        *CONVOLUTION_OPS,
-    ]:
+    for op in SUPPORTS_DYNAMIC_SHAPE:
         ops[op].supports_dynamic_shape = True

+    for op in PRIM_OPS:
+        ops[op].supports_texture = True
+        ops[op].supports_buffer = True
+        ops[op].supports_dynamic_shape = True

-def enumerate_supported_ops():
-    ops = OpList()
-    register_prim_ops(ops)
-    register_no_dynamic_shape_ops(ops)
-    register_dynamic_shape_ops(ops)
     return ops
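
The flattened registry above is consumed as a capability lookup. A hedged sketch of how a caller might query it, assuming OpList indexing and the flag names shown in the diff:

    from executorch.backends.vulkan.partitioner.supported_ops import (
        enumerate_supported_ops,
    )
    from executorch.exir.dialects._ops import ops as exir_ops

    ops = enumerate_supported_ops()

    # Ops listed in SUPPORTS_DYNAMIC_SHAPE were registered with the flag on,
    # while ops in NO_DYNAMIC_SHAPE remain static-shape only.
    assert ops[exir_ops.edge.aten.add.Tensor].supports_dynamic_shape
    assert not ops[exir_ops.edge.aten.mean.dim].supports_dynamic_shape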

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 27 additions & 0 deletions
@@ -108,6 +108,30 @@ def is_linear_permute(self, node: torch.fx.Node) -> bool:

         return False

+    def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool:
+        """
+        Scalar tensors are usually converted to scalar values in the graph via
+        `scalar_tensor[0].item()` in Python, which translates to a chain of
+        `local_scalar_dense(torch.select.int(scalar_tensor, 0, 0))` in the graph.
+        This function marks the entire chain as supported by the Vulkan delegate.
+
+        Later, within vulkan_preprocess there will be a graph transform which
+        replaces the chain with passing in the scalar tensor directly.
+        """
+        if node.target == exir_ops.edge.aten.select_copy.int:
+            if len(node.users) != 1:
+                return False
+            if node.args[0].meta["val"].numel() != 1:
+                return False
+
+            user = list(node.users.keys())[0]
+            return user.target == torch.ops.aten._local_scalar_dense.default
+
+        if node.target == torch.ops.aten._local_scalar_dense.default:
+            return True
+
+        return False
+
     def is_node_supported(
         self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
     ) -> bool:
@@ -122,6 +146,9 @@ def _is_node_supported(
         if self.is_linear_permute(node):
             return True

+        if self.is_in_local_scalar_dense_chain(node):
+            return True
+
         if node.target not in VulkanSupportedOperators._ops:
             return False
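To see the chain that is_in_local_scalar_dense_chain matches, a small sketch with a hypothetical module; the traced graph is indicative, not verbatim:

    import torch


    class ItemChain(torch.nn.Module):
        def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
            # `s[0].item()` is traced as a select node followed by
            # _local_scalar_dense, the exact pattern the partitioner
            # now accepts.
            return x + s[0].item()


    ep = torch.export.export(ItemChain(), (torch.randn(2), torch.tensor([1.5])))
    # Expect roughly: select.int -> _local_scalar_dense -> add.Tensor
    print(ep.graph)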

backends/vulkan/passes/TARGETS

Lines changed: 13 additions & 0 deletions
@@ -27,3 +27,16 @@ python_unittest(
         "//caffe2:torch",
     ],
 )
+
+runtime.python_library(
+    name = "remove_local_scalar_dense",
+    srcs = ["remove_local_scalar_dense_ops.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+    ],
+)
backends/vulkan/passes/remove_local_scalar_dense_ops.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+def remove_local_scalar_dense_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
+    """
+    Remove local_scalar_dense op nodes and replace uses with parent node, or the
+    original scalar tensor.
+    """
+    target_op = torch.ops.aten._local_scalar_dense.default
+    for node in graph.nodes:
+        if node.op == "call_function" and node.target == target_op:
+            replace_node = node.args[0]
+            # If the argument to the local_scalar_dense op is a select op with only
+            # one user, and the argument to the select op is a tensor with only one
+            # element (i.e. a scalar tensor), then replace the entire pattern with the
+            # scalar tensor.
+            if (
+                replace_node.op == "call_function"
+                and replace_node.target == exir_ops.edge.aten.select_copy.int
+            ):
+                if replace_node.args[0].meta["val"].numel() == 1:
+                    replace_node = replace_node.args[0]
+
+            with graph.inserting_after(node):
+                node.replace_all_uses_with(replace_node)
+
+    graph.eliminate_dead_code()
+    return graph
+
+
+class RemoveLocalScalarDenseOpsTransform(ExportPass):
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        graph_module.graph = remove_local_scalar_dense_ops(graph_module.graph)
+        return PassResult(graph_module, True)
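
A hedged usage sketch of the new pass in isolation, assuming it can be invoked directly like other PassBase subclasses (graph_module is a hypothetical torch.fx.GraphModule containing the pattern):

    # Pass instances are callable and return a PassResult wrapping the
    # (mutated) graph module.
    result = RemoveLocalScalarDenseOpsTransform()(graph_module)
    transformed_gm = result.graph_module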

backends/vulkan/vulkan_preprocess.py

Lines changed: 5 additions & 0 deletions
@@ -17,6 +17,10 @@
 from executorch.backends.transforms.mean_to_sum_div import MeanToSumDiv
 from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform

+from executorch.backends.vulkan.passes.remove_local_scalar_dense_ops import (
+    RemoveLocalScalarDenseOpsTransform,
+)
+
 from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
 from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
     serialize_vulkan_graph,
@@ -57,6 +61,7 @@ def preprocess(  # noqa: C901
         MeanToSumDiv(),
         SpecPropPass(),
         ConstraintBasedSymShapeEvalPass(),
+        RemoveLocalScalarDenseOpsTransform(),
         MemoryPlanningPass(),
     ]