
Commit fdb7392

Update base for Update on "[ET-VK] Introduce AOT operator registry"
## Changes

Move the following files to the root directory of the Vulkan backend:

* `backends/vulkan/partitioner/supported_ops.py` -> `backends/vulkan/op_registry.py`
* `backends/vulkan/_passes/custom_ops_defs.py` -> `backends/vulkan/custom_ops_lib.py`

In the new `op_registry.py` file, the way operator features are specified has been reworked to provide much more detail about the capabilities of each operator's Vulkan implementation. See the new `OpFeatures` class for details. An example of registering a new operator in the export flow:

```
@update_features(
    [
        exir_ops.edge.aten._log_softmax.default,
        exir_ops.edge.aten._softmax.default,
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.sum.dim_IntList,
        exir_ops.edge.aten.amax.default,
        exir_ops.edge.aten.amin.default,
    ]
)
def register_reduce_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        uses_packed_dim=True,
    )
    features.resize_fn = True

    def check_reduce_node(node: torch.fx.Node) -> bool:
        dim_list = node.args[1]
        assert isinstance(dim_list, list)
        if len(dim_list) != 1:
            return False

        keepdim = node.args[2]
        assert isinstance(keepdim, bool)
        if not keepdim:
            return False

        return True

    features.check_node_fn = check_reduce_node
    return features
```

## Rationale

The purpose of these changes is to centralize operator definitions so that there is a common source of truth about the capabilities of each operator's Vulkan implementation. This way, the partitioner does not have to implement ad-hoc functions for specific operators (e.g. `is_valid_to_copy`), and graph transforms do not have to maintain their own operator metadata (e.g. `USES_WEIGHTS` in `insert_prepack_nodes`).

Differential Revision: [D64915640](https://our.internmc.facebook.com/intern/diff/D64915640/)

[ghstack-poisoned]
2 parents 86764d1 + e93ad5f commit fdb7392
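The commit message shows the registration API from the caller's side. As a rough illustration of how such a decorator-based registry could be structured, here is a minimal sketch; the field names and the `vulkan_supported_ops` dict are simplified assumptions for illustration, not the actual definitions in `op_registry.py`:

```python
# Minimal sketch of a decorator-based operator feature registry.
# OpFeatures/TextureImplFeatures fields are simplified stand-ins
# for the real definitions in backends/vulkan/op_registry.py.
from typing import Callable, Dict, List, Optional

import torch


class TextureImplFeatures:
    def __init__(self, uses_packed_dim: bool = False):
        self.uses_packed_dim = uses_packed_dim


class OpFeatures:
    def __init__(self):
        self.texture_impl: Optional[TextureImplFeatures] = None
        self.resize_fn: bool = False
        # Optional per-node validity check consulted by the partitioner.
        self.check_node_fn: Callable[[torch.fx.Node], bool] = lambda node: True


# The common source of truth: op identifier -> features of its Vulkan impl.
vulkan_supported_ops: Dict[object, OpFeatures] = {}


def update_features(ops: List[object]):
    """Decorator: run a feature-populating function for each listed op."""

    def decorator(fn: Callable[[OpFeatures], OpFeatures]):
        for op in ops:
            vulkan_supported_ops[op] = fn(OpFeatures())
        return fn

    return decorator
```

With a registry of this shape, the partitioner can answer "is this node supported?" with a dictionary lookup plus a `check_node_fn` call, instead of per-operator special cases.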

File tree: 2 files changed (+22, -5 lines)

backends/cadence/aot/ops_registrations.py

Lines changed: 12 additions & 2 deletions

```diff
@@ -132,7 +132,11 @@ def quantized_conv_meta(
     out_shift: torch.Tensor,
     channel_last: bool = False,
 ) -> torch.Tensor:
-    out_channels, _in_channels, *kernel_size = weight.shape
+    if channel_last:
+        out_channels, *kernel_size, _ = weight.shape
+    else:
+        out_channels, _, *kernel_size = weight.shape
+
     in_size = input.shape
     # Assert that the input tensor has at least 3 dimensions, and at most 6
     assert len(in_size) > 2
@@ -141,7 +145,13 @@ def quantized_conv_meta(
     # Compute the output tensor size
     output_size = (
         get_conv1d_output_size(
-            in_size, out_channels, stride[1], padding[1], dilation[1], kernel_size[0]
+            in_size,
+            out_channels,
+            stride[1],
+            padding[1],
+            dilation[1],
+            kernel_size[0],
+            channel_last,
         )
         if len(in_size) == 3
        else get_conv2d_output_size(
```
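The unpacking change handles the two weight layouts: with `channel_last=False` the conv weight is `(out_channels, in_channels, *kernel_size)`, while with `channel_last=True` it is `(out_channels, *kernel_size, in_channels)`. A quick standalone illustration of what the star-unpacking yields in each case (the shapes below are arbitrary, chosen just for this example):

```python
import torch

# Contiguous (NCHW-style) weight: (out_channels, in_channels, kh, kw)
w = torch.empty(8, 3, 5, 5)
out_channels, _, *kernel_size = w.shape
print(out_channels, kernel_size)  # 8 [5, 5]

# Channel-last weight: (out_channels, kh, kw, in_channels)
w_cl = torch.empty(8, 5, 5, 3)
out_channels, *kernel_size, _ = w_cl.shape
print(out_channels, kernel_size)  # 8 [5, 5]
```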

backends/cadence/aot/utils.py

Lines changed: 10 additions & 3 deletions

```diff
@@ -43,14 +43,20 @@ def get_conv1d_output_size(
     padding: int,
     dilation: int,
     kernel_size: int,
+    channel_last: bool,
 ) -> torch.Size:
     assert len(in_size) == 3
-    N, C, L = in_size
+    if channel_last:
+        N, L, C = in_size
+    else:
+        N, C, L = in_size
 
     # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
     lout = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
 
-    return torch.Size((in_size[0], out_channels, lout))
+    if channel_last:
+        return torch.Size((N, lout, out_channels))
+    return torch.Size((N, out_channels, lout))
 
 
 # Get the output size of a 2D convolution given the input size and parameters
@@ -76,7 +82,8 @@ def get_conv2d_output_size(
     wout = (W + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[
         1
     ] + 1
-
+    if channel_last:
+        return torch.Size((N, hout, wout, out_channels))
     return torch.Size((in_size[0], out_channels, hout, wout))
```
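As a sanity check on the 1D output-length formula used above (the 2D version applies the same formula per spatial dimension), the computed length can be compared against an actual `torch.nn.Conv1d`. This is a standalone check, not part of the commit:

```python
import torch


# Output-length formula from get_conv1d_output_size.
def conv1d_lout(L, stride, padding, dilation, kernel_size):
    return (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


N, C, L = 2, 3, 37
stride, padding, dilation, kernel_size = 2, 1, 2, 3

conv = torch.nn.Conv1d(
    C, 8, kernel_size, stride=stride, padding=padding, dilation=dilation
)
out = conv(torch.randn(N, C, L))

lout = conv1d_lout(L, stride, padding, dilation, kernel_size)
assert out.shape == torch.Size((N, 8, lout))  # contiguous NCL layout
# With channel_last=True, the meta function would instead report (N, lout, 8).
print(out.shape, lout)  # torch.Size([2, 8, 18]) 18
```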