Commit bda9a8d
Update on "[ET-VK] Introduce AOT operator registry"
## Changes

Move the following files to the root directory of the Vulkan backend:

* `backends/vulkan/partitioner/supported_ops.py` -> `backends/vulkan/op_registry.py`
* `backends/vulkan/_passes/custom_ops_defs.py` -> `backends/vulkan/custom_ops_lib.py`

In the new `op_registry.py` file, the way operator features are specified is reworked to provide much more detail about the features of an operator's Vulkan implementation. See the new `OpFeatures` class for more details. An example of registering a new operator to the export flow is:

```
@update_features(
    [
        exir_ops.edge.aten._log_softmax.default,
        exir_ops.edge.aten._softmax.default,
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.sum.dim_IntList,
        exir_ops.edge.aten.amax.default,
        exir_ops.edge.aten.amin.default,
    ]
)
def register_reduce_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        uses_packed_dim=True,
    )
    features.resize_fn = True

    def check_reduce_node(node: torch.fx.Node) -> bool:
        dim_list = node.args[1]
        assert isinstance(dim_list, list)
        if len(dim_list) != 1:
            return False

        keepdim = node.args[2]
        assert isinstance(keepdim, bool)
        if not keepdim:
            return False

        return True

    features.check_node_fn = check_reduce_node
    return features
```

## Rationale

The purpose of these changes is to centralize operator definitions so that there is a common source of truth about the capabilities of operator implementations in Vulkan. This way, the partitioner does not have to implement ad-hoc functions for specific operators (e.g. `is_valid_to_copy`), and graph transforms do not have to maintain their own operator metadata (`USES_WEIGHTS` in `insert_prepack_nodes`).

Differential Revision: [D64915640](https://our.internmc.facebook.com/intern/diff/D64915640/)

[ghstack-poisoned]
2 parents: 253285a + fdb7392
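For orientation, here is a minimal, hypothetical sketch of the registry machinery the example above assumes. The names mirror this commit (`update_features`, `OpFeatures`, `TextureImplFeatures`, `vulkan_supported_ops`), but the real `OpFeatures` in `backends/vulkan/op_registry.py` carries more fields, and `OpKey` is simplified to `Any` here; the duplicate-registration guard matches the behavior introduced in the `op_registry.py` hunk below.

```python
# Hypothetical, trimmed-down sketch of the registry in op_registry.py.
from typing import Any, Callable, Dict, List, Optional, Union

vulkan_supported_ops: Dict[Any, "OpFeatures"] = {}


class TextureImplFeatures:
    def __init__(self, uses_packed_dim: bool = False) -> None:
        self.uses_packed_dim = uses_packed_dim


class OpFeatures:
    def __init__(self) -> None:
        # Set when the op has a texture-based implementation.
        self.texture_impl: Optional[TextureImplFeatures] = None
        # True if the Vulkan implementation supports dynamic shapes.
        self.resize_fn: bool = False
        # Per-node check used by the partitioner to reject unsupported args.
        self.check_node_fn: Callable[[Any], bool] = lambda node: True


def update_features(aten_op: Union[Any, List[Any]]):
    """Decorator: build an OpFeatures, let `fn` fill it in, record it per op key."""

    def features_decorator(fn: Callable):
        def update_features_impl(op: Any) -> None:
            if op in vulkan_supported_ops:
                raise RuntimeError(f"[Vulkan delegate] duplicate registration of {op}!")
            vulkan_supported_ops[op] = fn(OpFeatures())

        if isinstance(aten_op, list):
            for op in aten_op:
                update_features_impl(op)
        else:
            update_features_impl(aten_op)
        return fn

    return features_decorator
```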

7 files changed (+113, -95 lines)

backends/cadence/aot/ops_registrations.py (12 additions, 2 deletions)

@@ -132,7 +132,11 @@ def quantized_conv_meta(
     out_shift: torch.Tensor,
     channel_last: bool = False,
 ) -> torch.Tensor:
-    out_channels, _in_channels, *kernel_size = weight.shape
+    if channel_last:
+        out_channels, *kernel_size, _ = weight.shape
+    else:
+        out_channels, _, *kernel_size = weight.shape
     in_size = input.shape
     # Assert that the input tensor has at least 3 dimensions, and at most 6
     assert len(in_size) > 2
@@ -141,7 +145,13 @@ def quantized_conv_meta(
     # Compute the output tensor size
     output_size = (
         get_conv1d_output_size(
-            in_size, out_channels, stride[1], padding[1], dilation[1], kernel_size[0]
+            in_size,
+            out_channels,
+            stride[1],
+            padding[1],
+            dilation[1],
+            kernel_size[0],
+            channel_last,
         )
         if len(in_size) == 3
         else get_conv2d_output_size(
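As a quick standalone check of the star-unpacking above (weight shapes are illustrative): a channel-last weight is laid out `(out_channels, *kernel_size, in_channels)`, while a contiguous one is `(out_channels, in_channels, *kernel_size)`.

```python
# Illustrative weight shapes for a 2D conv with 16 output channels,
# 8 input channels, and a 3x3 kernel.
weight_shape_channel_last = (16, 3, 3, 8)  # (out_channels, *kernel_size, in_channels)
weight_shape_contiguous = (16, 8, 3, 3)    # (out_channels, in_channels, *kernel_size)

out_channels, *kernel_size, _ = weight_shape_channel_last
assert (out_channels, kernel_size) == (16, [3, 3])

out_channels, _, *kernel_size = weight_shape_contiguous
assert (out_channels, kernel_size) == (16, [3, 3])
```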

backends/cadence/aot/utils.py (10 additions, 3 deletions)

@@ -43,14 +43,20 @@ def get_conv1d_output_size(
     padding: int,
     dilation: int,
     kernel_size: int,
+    channel_last: bool,
 ) -> torch.Size:
     assert len(in_size) == 3
-    N, C, L = in_size
+    if channel_last:
+        N, L, C = in_size
+    else:
+        N, C, L = in_size
 
     # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
     lout = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
 
-    return torch.Size((in_size[0], out_channels, lout))
+    if channel_last:
+        return torch.Size((N, lout, out_channels))
+    return torch.Size((N, out_channels, lout))
 
 
 # Get the output size of a 2D convolution given the input size and parameters
@@ -76,7 +82,8 @@ def get_conv2d_output_size(
     wout = (W + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[
         1
     ] + 1
-
+    if channel_last:
+        return torch.Size((N, hout, wout, out_channels))
     return torch.Size((in_size[0], out_channels, hout, wout))
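To sanity-check the formula, here is a self-contained mirror of the patched `get_conv1d_output_size` (a sketch for illustration, not an import of the Cadence helper), with one worked call per layout:

```python
import torch


def conv1d_output_size(in_size, out_channels, stride, padding, dilation,
                       kernel_size, channel_last):
    # Mirrors the patched get_conv1d_output_size above, for illustration only.
    assert len(in_size) == 3
    if channel_last:
        N, L, C = in_size
    else:
        N, C, L = in_size
    # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
    lout = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    if channel_last:
        return torch.Size((N, lout, out_channels))
    return torch.Size((N, out_channels, lout))


# NCL input (1, 8, 32), 16 filters, stride 2, kernel 3 -> (1, 16, 15)
print(conv1d_output_size((1, 8, 32), 16, 2, 0, 1, 3, channel_last=False))
# The same tensor in NLC layout (1, 32, 8) -> (1, 15, 16)
print(conv1d_output_size((1, 32, 8), 16, 2, 0, 1, 3, channel_last=True))
```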

backends/vulkan/op_registry.py (3 additions, 10 deletions)

@@ -92,8 +92,9 @@ def __init__(
 def update_features(aten_op):
     def features_decorator(fn: Callable):
         def update_features_impl(op: OpKey):
-            if op not in vulkan_supported_ops:
-                vulkan_supported_ops[op] = OpFeatures()
+            if op in vulkan_supported_ops:
+                raise RuntimeError(f"[Vulkan delegate] duplicate registration of {op}!")
+            vulkan_supported_ops[op] = OpFeatures()
             vulkan_supported_ops[op] = fn(vulkan_supported_ops[op])
 
         if isinstance(aten_op, list):
@@ -165,7 +166,6 @@ def register_binary_op(features: OpFeatures):
         exir_ops.edge.aten.sqrt.default,
         exir_ops.edge.aten.rsqrt.default,
         exir_ops.edge.aten.tanh.default,
-        exir_ops.edge.aten._to_copy.default,
     ]
 )
 def register_unary_op(features: OpFeatures):
@@ -216,8 +216,6 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
         exir_ops.edge.aten.mm.default,
         exir_ops.edge.aten.addmm.default,
         exir_ops.edge.aten.linear.default,
-        exir_ops.edge.et_vk.linear_weight_int4.default,
-        exir_ops.edge.aten._weight_int8pack_mm.default,
     ]
 )
 def register_mm_op(features: OpFeatures):
@@ -276,8 +274,6 @@ def register_softmax_op(features: OpFeatures):
 
 @update_features(
     [
-        exir_ops.edge.aten._log_softmax.default,
-        exir_ops.edge.aten._softmax.default,
         exir_ops.edge.aten.mean.dim,
         exir_ops.edge.aten.sum.dim_IntList,
         exir_ops.edge.aten.amax.default,
@@ -366,9 +362,6 @@ def register_view_op(features: OpFeatures):
 # packed tensors only and do not have a resize function.
 @update_features(
     [
-        # Normalization
-        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
-        exir_ops.edge.aten.native_layer_norm.default,
         # Shape Manipulation
         exir_ops.edge.aten.squeeze_copy.dims,
         exir_ops.edge.aten.unsqueeze_copy.default,
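The net effect of the first hunk: re-registering an op key is now a hard error instead of a silent reset of the earlier entry. A hypothetical illustration (the import paths assume this commit's new file layout; the exir_ops import is the standard ExecuTorch spelling):

```python
from executorch.backends.vulkan.op_registry import OpFeatures, update_features
from executorch.exir.dialects._ops import ops as exir_ops


# Hypothetical: aten.tanh.default is already registered by register_unary_op
# above, so this second registration now raises at import time with
# "[Vulkan delegate] duplicate registration of <op>!".
@update_features(exir_ops.edge.aten.tanh.default)
def register_tanh_twice(features: OpFeatures):
    features.resize_fn = True
    return features
```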

backends/vulkan/partitioner/vulkan_partitioner.py (9 additions, 2 deletions)

@@ -139,12 +139,16 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool:
 
         return False
 
+    def log_skip(self, node: torch.fx.Node, reason: str) -> None:
+        if node.op == "call_function":
+            logger.info(
+                f"[Vulkan Partitioner] Due to [{reason}], skipping {node.format_node()}"
+            )
+
     def is_node_supported(
         self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
     ) -> bool:
         r = self._is_node_supported(submodules, node)
-        if not r and node.op == "call_function":
-            logger.info(f"Skipping node in Vulkan partitioning: {node.format_node()}")
         return r
 
     def _is_node_supported(
@@ -163,14 +167,17 @@ def _is_node_supported(
             return True
 
         if target not in vulkan_supported_ops:
+            self.log_skip(node, "not in vulkan_supported_ops")
             return False
 
         features = vulkan_supported_ops[target]
 
         if not features.check_node_fn(node):
+            self.log_skip(node, "op args not supported")
             return False
 
         if self.require_dynamic_shapes and not features.resize_fn:
+            self.log_skip(node, "no dynamic shape support")
             return False
 
         return self.all_args_compatible(node)
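Note that the new skip messages go through `logger.info`, so they only appear when the exporting script has INFO-level logging enabled; a minimal sketch:

```python
import logging

# Surface the partitioner's skip reasons, which now look like:
#   [Vulkan Partitioner] Due to [<reason>], skipping <node.format_node() output>
logging.basicConfig(level=logging.INFO)
```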
1 addition, 1 deletion

@@ -1,4 +1,4 @@
 load(":targets.bzl", "define_common_targets")
 oncall("executorch")
 
-define_common_targets()
+define_common_targets(is_fbcode = True)

backends/vulkan/serialization/targets.bzl (24 additions, 23 deletions)

@@ -1,28 +1,6 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
-def define_common_targets():
-    runtime.python_library(
-        name = "lib",
-        srcs = [
-            "vulkan_graph_builder.py",
-            "vulkan_graph_schema.py",
-            "vulkan_graph_serialize.py",
-        ],
-        resources = [
-            "schema.fbs",
-        ],
-        visibility = [
-            "//executorch/...",
-            "//executorch/vulkan/...",
-            "@EXECUTORCH_CLIENTS",
-        ],
-        deps = [
-            "//executorch/exir:graph_module",
-            "//executorch/exir/_serialize:_bindings",
-            "//executorch/exir/_serialize:lib",
-        ],
-    )
-
+def define_common_targets(is_fbcode = False):
     runtime.genrule(
         name = "gen_vk_delegate_schema",
         srcs = ["schema.fbs"],
@@ -57,3 +35,26 @@ def define_common_targets():
             "flatbuffers-api",
         ],
     )
+
+    if is_fbcode:
+        runtime.python_library(
+            name = "lib",
+            srcs = [
+                "vulkan_graph_builder.py",
+                "vulkan_graph_schema.py",
+                "vulkan_graph_serialize.py",
+            ],
+            resources = [
+                "schema.fbs",
+            ],
+            visibility = [
+                "//executorch/...",
+                "//executorch/vulkan/...",
+                "@EXECUTORCH_CLIENTS",
+            ],
+            deps = [
+                "//executorch/exir:graph_module",
+                "//executorch/exir/_serialize:_bindings",
+                "//executorch/exir/_serialize:lib",
+            ],
+        )

backends/vulkan/targets.bzl (54 additions, 54 deletions)

@@ -203,59 +203,59 @@ def define_common_targets(is_fbcode = False):
     ##
     ## AOT targets
     ##
+    if is_fbcode:
+        runtime.python_library(
+            name = "custom_ops_lib",
+            srcs = [
+                "custom_ops_lib.py"
+            ],
+            visibility = [
+                "//executorch/...",
+                "//executorch/vulkan/...",
+                "@EXECUTORCH_CLIENTS",
+            ],
+            deps = [
+                "//caffe2:torch",
+            ]
+        )
 
-    runtime.python_library(
-        name = "custom_ops_lib",
-        srcs = [
-            "custom_ops_lib.py"
-        ],
-        visibility = [
-            "//executorch/...",
-            "//executorch/vulkan/...",
-            "@EXECUTORCH_CLIENTS",
-        ],
-        deps = [
-            "//caffe2:torch",
-        ]
-    )
+        runtime.python_library(
+            name = "op_registry",
+            srcs = [
+                "op_registry.py",
+            ],
+            visibility = [
+                "//executorch/...",
+                "//executorch/vulkan/...",
+                "@EXECUTORCH_CLIENTS",
+            ],
+            deps = [
+                ":custom_ops_lib",
+                "//caffe2:torch",
+                "//executorch/exir/dialects:lib",
+                "//executorch/backends/vulkan/serialization:lib",
+            ]
+        )
 
-    runtime.python_library(
-        name = "op_registry",
-        srcs = [
-            "op_registry.py",
-        ],
-        visibility = [
-            "//executorch/...",
-            "//executorch/vulkan/...",
-            "@EXECUTORCH_CLIENTS",
-        ],
-        deps = [
-            ":custom_ops_lib",
-            "//caffe2:torch",
-            "//executorch/exir/dialects:lib",
-            "//executorch/backends/vulkan/serialization:lib",
-        ]
-    )
-
-    runtime.python_library(
-        name = "vulkan_preprocess",
-        srcs = [
-            "vulkan_preprocess.py",
-        ],
-        visibility = [
-            "//executorch/...",
-            "//executorch/vulkan/...",
-            "@EXECUTORCH_CLIENTS",
-        ],
-        deps = [
-            "//executorch/backends/transforms:addmm_mm_to_linear",
-            "//executorch/backends/transforms:fuse_batch_norm_with_conv",
-            "//executorch/backends/transforms:fuse_conv_with_clamp",
-            "//executorch/backends/transforms:fuse_dequant_linear",
-            "//executorch/backends/transforms:fuse_view_copy",
-            "//executorch/backends/transforms:remove_clone_ops",
-            "//executorch/backends/vulkan/_passes:vulkan_passes",
-            "//executorch/backends/vulkan/serialization:lib",
-            "//executorch/exir/backend:backend_details",
-        ],
-    )
+        runtime.python_library(
+            name = "vulkan_preprocess",
+            srcs = [
+                "vulkan_preprocess.py",
+            ],
+            visibility = [
+                "//executorch/...",
+                "//executorch/vulkan/...",
+                "@EXECUTORCH_CLIENTS",
+            ],
+            deps = [
+                "//executorch/backends/transforms:addmm_mm_to_linear",
+                "//executorch/backends/transforms:fuse_batch_norm_with_conv",
+                "//executorch/backends/transforms:fuse_conv_with_clamp",
+                "//executorch/backends/transforms:fuse_dequant_linear",
+                "//executorch/backends/transforms:fuse_view_copy",
+                "//executorch/backends/transforms:remove_clone_ops",
+                "//executorch/backends/vulkan/_passes:vulkan_passes",
+                "//executorch/backends/vulkan/serialization:lib",
+                "//executorch/exir/backend:backend_details",
+            ],
+        )
