From 28a84ae51355acbc12698f0a5a5f28d6c7598ce8 Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 5 Nov 2025 13:16:51 -0800 Subject: [PATCH] [ET-VK] Implementation of to_dim_order_copy Title says it all! Previously, to_dim_order_copy was handled by removing the op. However, this is not possible if the op is modifying the dtype of the original tensor, so these instances of the op would be skipped by the partitioner. This diff adds an implementation dtype conversion, which allows to_dim_order_copy to be lowered. Differential Revision: [D86340341](https://our.internmc.facebook.com/intern/diff/D86340341/) [ghstack-poisoned] --- .../vulkan/_passes/remove_redundant_ops.py | 36 ++++----- backends/vulkan/op_registry.py | 21 +----- .../runtime/graph/ops/glsl/view_buffer.glsl | 21 ++++-- .../graph/ops/glsl/view_convert_buffer.glsl | 54 ++++++++++++++ .../graph/ops/glsl/view_convert_buffer.yaml | 22 ++++++ .../runtime/graph/ops/impl/Unsqueeze.cpp | 12 +++ .../vulkan/runtime/graph/ops/impl/View.cpp | 74 +++++++++++++++++++ backends/vulkan/runtime/graph/ops/impl/View.h | 13 ++++ 8 files changed, 209 insertions(+), 44 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml diff --git a/backends/vulkan/_passes/remove_redundant_ops.py b/backends/vulkan/_passes/remove_redundant_ops.py index 8e602dd17b4..25bdd34de70 100644 --- a/backends/vulkan/_passes/remove_redundant_ops.py +++ b/backends/vulkan/_passes/remove_redundant_ops.py @@ -31,35 +31,37 @@ class RemoveRedundantOpsTransform(ExportPass): exir_ops.edge.aten.lift_fresh_copy.default, exir_ops.edge.dim_order_ops._to_dim_order_copy.default, exir_ops.edge.dim_order_ops._clone_dim_order.default, + exir_ops.edge.aten.expand_copy.default, } def __init__(self) -> None: super(RemoveRedundantOpsTransform, self).__init__() def _should_remove(self, node: torch.fx.Node) -> bool: - if node.target in self.redundant_ops: - return True - - # Only remove to_copy if dtype does not change. Otherwise, memory format changes - # will be handled internally by the backend. - if ( - node.target == exir_ops.edge.aten._to_copy.default - or node.target == torch.ops.aten._to_copy.default - ): - src_dtype = node.meta["val"].dtype - # pyre-ignore - dst_dtype = node.args[0].meta["val"].dtype - return src_dtype == dst_dtype - - return False + if node.target not in self.redundant_ops: + return False + + orig_node = node.args[0] + assert isinstance(orig_node, torch.fx.Node) + + src_dtype = orig_node.meta["val"].dtype + dst_dtype = node.meta["val"].dtype + + # Do not remove if the op is converting the dtype. + if src_dtype != dst_dtype: + return False + + src_shape = orig_node.meta["val"].shape + dst_shape = node.meta["val"].shape + + return src_shape == dst_shape def _remove(self, graph_module: torch.fx.GraphModule) -> None: for node in graph_module.graph.nodes: if not self._should_remove(node): continue - with graph_module.graph.inserting_after(node): - node.replace_all_uses_with(node.args[0]) + node.replace_all_uses_with(node.args[0]) graph_module.graph.eliminate_dead_code() diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index b47a8f383a0..cf032727857 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -129,6 +129,7 @@ def update_features_impl(op: OpKey): # Symbolic integer ops torch.ops.aten.sym_size.int, operator.add, + operator.sub, operator.lt, operator.gt, operator.ge, @@ -297,27 +298,9 @@ def check_to_copy_node(node: torch.fx.Node) -> bool: @update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default) def register_to_copy_dim_order_op(): - # Currently there is no "real" implementation for to_dim_order_copy, but it can be - # removed as long as the operator is not changing the dtype, i.e. the operator call - # is modifying the dim order only. Therefore, check that the input and output dtypes - # are the same, if so the operator is safe to remove. - def check_dim_order_copy_node(node: torch.fx.Node) -> bool: - in_arg = node.args[0] - if not isinstance(in_arg, torch.fx.Node): - return False - - in_tensor = in_arg.meta.get("val", None) - out_tensor = node.meta.get("val", None) - - if in_tensor.dtype != out_tensor.dtype: - return False - - return True - return OpFeatures( - inputs_storage=utils.ANY_STORAGE, + inputs_storage=utils.ANY_BUFFER, supports_resize=True, - are_node_inputs_supported_fn=check_dim_order_copy_node, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl index 2c02803a9b1..96b9aa85a1f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl @@ -18,6 +18,8 @@ ${layout_declare_ubo(B, "BufferMetadata", "inp")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +${layout_declare_spec_const(C, "int", "all_contiguous", "0")} + /* * The insight behind the view operation is that the contiguous index of each * tensor element in the input and output tensors are the same. @@ -28,17 +30,20 @@ void main() { return; } - TensorIndex outp_tidx; - linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx); + uint inp_bufi = outp_bufi; + if (all_contiguous == 0) { + TensorIndex outp_tidx; + linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx); - // To map the output to the input, find the input element that has the same - // contiguous index as the output element. - const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx); + // To map the output to the input, find the input element that has the same + // contiguous index as the output element. + const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx); - TensorIndex inp_tidx; - contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx); + TensorIndex inp_tidx; + contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx); - const uint inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx); + inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx); + } t_outp[outp_bufi] = t_inp[inp_bufi]; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl new file mode 100644 index 00000000000..a926c9fea11 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl @@ -0,0 +1,54 @@ +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} + +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_buffer(B, "w", "t_outp", OUT_DTYPE)} +${layout_declare_buffer(B, "r", "t_inp", IN_DTYPE)} + +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "all_contiguous", "0")} + +/* + * The insight behind the view_convert operation is that the contiguous index of each + * tensor element in the input and output tensors are the same, but the data types + * may be different and need conversion. + */ +void main() { + const uint outp_bufi = gl_GlobalInvocationID.x; + if (outp_bufi >= numel(outp)) { + return; + } + + uint inp_bufi = outp_bufi; + + if (all_contiguous == 0) { + TensorIndex outp_tidx; + linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx); + + // To map the output to the input, find the input element that has the same + // contiguous index as the output element. + const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx); + + TensorIndex inp_tidx; + contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx); + + inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx); + } + + // Convert data type from input to output + t_outp[outp_bufi] = OUT_T(t_inp[inp_bufi]); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml new file mode 100644 index 00000000000..11d56cad4a9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +view_convert_buffer: + parameter_names_with_default_values: + IN_DTYPE: float + OUT_DTYPE: float + STORAGE: buffer + generate_variant_forall: + combination: + parameter_names: [IN_DTYPE, OUT_DTYPE] + combos: + - parameter_values: [int32, float] + - parameter_values: [int32, half] + - parameter_values: [uint8, float] + - parameter_values: [uint8, half] + - parameter_values: [uint8, int32] + shader_variants: + - NAME: view_convert_buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp index 36a8ee4c3b1..602fe1ef129 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp @@ -67,6 +67,18 @@ void resize_unsqueeze_node( std::vector out_sizes = graph->sizes_of(in); + std::vector unsqueezed_dims; + + if (graph->val_is_int_list(dims_ref)) { + const IntListPtr dims = graph->get_int_list(dims_ref); + for (int64_t d : *dims) { + unsqueezed_dims.push_back(d); + } + } else { + const int64_t dim = graph->extract_scalar(dims_ref); + unsqueezed_dims.push_back(dim); + } + // Insert singleton dimensions at the specified positions for (auto dim : dims_vec) { int64_t d = dim; diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index 8701a6246b0..4efb2e94e95 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -60,6 +60,16 @@ void resize_view_node( } } +void resize_to_dim_order_copy_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + const std::vector in_sizes = graph->sizes_of(in); + graph->virtual_resize(out, in_sizes); +} + void add_view_node( ComputeGraph& graph, ValueRef in, @@ -98,6 +108,11 @@ void add_view_copy_buffer_node( std::string kernel_name = "view_buffer"; add_dtype_suffix(kernel_name, graph.dtype_of(out)); + bool all_contiguous = graph.is_contiguous_buffer_tensor(in) && + graph.is_contiguous_buffer_tensor(out); + + int32_t all_contiguous_int = all_contiguous ? 1 : 0; + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), @@ -110,7 +125,41 @@ void add_view_copy_buffer_node( // Push Constants {}, // Specialization Constants + {all_contiguous_int}, + // Resize Args + resize_args, + // Resizing Logic + resize_fn)); +} + +void add_view_copy_convert_buffer_node( + ComputeGraph& graph, + ValueRef in, + ValueRef out, + const std::vector& resize_args, + const ExecuteNode::ResizeFunction& resize_fn) { + std::string kernel_name = "view_convert_buffer"; + add_dtype_suffix(kernel_name, graph.dtype_of(in)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + bool all_contiguous = graph.is_contiguous_buffer_tensor(in) && + graph.is_contiguous_buffer_tensor(out); + + int32_t all_contiguous_int = all_contiguous ? 1 : 0; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, + // Parameter Buffers + {graph.buffer_meta_ubo(out), graph.buffer_meta_ubo(in)}, + // Push Constants {}, + // Specialization Constants + {all_contiguous_int}, // Resize Args resize_args, // Resizing Logic @@ -132,8 +181,33 @@ void view(ComputeGraph& graph, const std::vector& args) { return add_view_node(graph, in, sizes, out); } +void to_dim_order_copy(ComputeGraph& graph, const std::vector& args) { + int args_idx = 0; + const ValueRef in = args.at(args_idx++); + const ValueRef dtype = args.at(args_idx++); + (void)dtype; + const ValueRef layout = args.at(args_idx++); + (void)layout; + const ValueRef device = args.at(args_idx++); + (void)device; + const ValueRef pin_memory = args.at(args_idx++); + (void)pin_memory; + const ValueRef non_blocking = args.at(args_idx++); + (void)non_blocking; + const ValueRef dim_order = args.at(args_idx++); + (void)dim_order; + + const ValueRef out = args.at(args_idx++); + + VK_CHECK_COND(graph.is_buffer_storage(in) && graph.is_buffer_storage(out)); + + return add_view_copy_convert_buffer_node( + graph, in, out, {}, resize_to_dim_order_copy_node); +} + REGISTER_OPERATORS { VK_REGISTER_OP(aten.view_copy.default, view); + VK_REGISTER_OP(dim_order_ops._to_dim_order_copy.default, to_dim_order_copy); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/View.h b/backends/vulkan/runtime/graph/ops/impl/View.h index 7a7a8d57742..c8e52492417 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.h +++ b/backends/vulkan/runtime/graph/ops/impl/View.h @@ -24,6 +24,19 @@ void add_view_copy_buffer_node( const std::vector& resize_args, const ExecuteNode::ResizeFunction& resize_fn); +/* + * Dispatches the view_convert_buffer compute shader. This can be used to + * implement ops that preserve the "contiguous" indexes of elements between the + * input and output while converting between different data types such as + * view_copy with dtype conversion. + */ +void add_view_copy_convert_buffer_node( + ComputeGraph& graph, + ValueRef in, + ValueRef out, + const std::vector& resize_args, + const ExecuteNode::ResizeFunction& resize_fn); + void add_view_node( ComputeGraph& graph, ValueRef in,