diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 0486110ced6..2aa940dcc4b 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -500,7 +500,12 @@ def register_sdpa_with_kv_cache_op(features: OpFeatures):
     return features
 
 
-@update_features(["llama::update_cache", "llama::custom_sdpa"])
+@update_features(
+    [
+        "llama::update_cache",
+        "llama::custom_sdpa",
+    ]
+)
 def register_sdpa_ops(features: OpFeatures):
     features.resize_fn = False
     features.buffer_impl = False
@@ -520,8 +525,17 @@ def register_rotary_emb_op(features: OpFeatures):
     return features
 
 
-@update_features(exir_ops.edge.aten.view_copy.default)
-def register_view_op(features: OpFeatures):
+@update_features(
+    [
+        exir_ops.edge.aten.clone.default,
+        exir_ops.edge.aten.permute.default,
+        exir_ops.edge.aten.permute_copy.default,
+        exir_ops.edge.aten.select_copy.int,
+        exir_ops.edge.aten.slice_copy.Tensor,
+        exir_ops.edge.aten.view_copy.default,
+    ]
+)
+def register_view_ops(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
         valid_packed_dims=all_packed_dims,
     )
@@ -538,10 +552,8 @@ def register_view_op(features: OpFeatures):
         # Indexing and lookup
         exir_ops.edge.aten.flip.default,
         exir_ops.edge.aten.index_select.default,
-        exir_ops.edge.aten.select_copy.int,
         # Tensor creation
         exir_ops.edge.aten.arange.start_step,
-        exir_ops.edge.aten.clone.default,
         exir_ops.edge.aten.constant_pad_nd.default,
         exir_ops.edge.aten.full.default,
         exir_ops.edge.aten.full_like.default,
@@ -564,12 +576,9 @@ def register_ported_op(features: OpFeatures):
 # Ops ported from PyTorch Vulkan backend. These ops are in a separate registry becasue they support all packed dimensions
 @update_features(
     [
-        # Indexing and lookup
-        exir_ops.edge.aten.slice_copy.Tensor,
         # Shape Manipulation
         exir_ops.edge.aten.squeeze_copy.dims,
         exir_ops.edge.aten.unsqueeze_copy.default,
-        exir_ops.edge.aten.permute_copy.default,
         # Tensor combination
         exir_ops.edge.aten.cat.default,
         exir_ops.edge.aten.repeat.default,
diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp
index 8e2c72d7627..fba3f03467b 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp
@@ -25,10 +25,12 @@ using utils::uvec4;
 
 namespace {
 
 void check_args(
-    const api::vTensor& in,
-    const std::vector<int64_t>& permute_dims,
-    const api::vTensor& out) {
-  VK_CHECK_COND(check_same_packed_dim(in, out));
+    ComputeGraph& graph,
+    const ValueRef in,
+    const ValueRef permute_dims,
+    const ValueRef out) {
+  (void)permute_dims;
+  VK_CHECK_COND(check_same_packed_dim(graph, in, out));
   // This implementation doesn't not requires the input tensor to have the same
   // dim size as the argument. The code will work as long as the input tensor's
@@ -38,40 +40,94 @@ void check_args(
 
 } // namespace
 
+void resize_permute_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  const ValueRef out = args[0].refs[0];
+  const ValueRef in = args[1].refs[0];
+
+  const std::vector<int64_t> in_sizes = graph->sizes_of(in);
+  const std::vector<int64_t> out_sizes = graph->sizes_of(out);
+
+  const std::vector<int64_t> permute_dims =
+      graph->extract_int_or_symint_list(resize_args[0]);
+
+  if (in_sizes.size() == out_sizes.size() &&
+      in_sizes.size() == permute_dims.size()) {
+    std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
+    const int64_t out_ndim = std::max(in_sizes.size(), out_sizes.size());
+    for (int i = 0; i < out_ndim; i++) {
+      const int64_t permute_dim = permute_dims.at(i);
+      new_out_sizes.at(i) = in_sizes.at(permute_dim);
+    }
+    graph->virtual_resize(out, new_out_sizes);
+  }
+  // Case where permute is being used to implement squeeze
+  else if (
+      in_sizes.size() > out_sizes.size() &&
+      in_sizes.size() == permute_dims.size()) {
+    std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
+    const size_t offset = in_sizes.size() - out_sizes.size();
+    for (int i = 0; i < out_sizes.size(); i++) {
+      const int64_t permute_dim = permute_dims.at(i + offset);
+      new_out_sizes.at(i) = in_sizes.at(permute_dim);
+    }
+    graph->virtual_resize(out, new_out_sizes);
+  }
+  // Case where Permute is being used to implement unsqueeze
+  else if (
+      in_sizes.size() < out_sizes.size() &&
+      out_sizes.size() == permute_dims.size()) {
+    std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
+    const size_t offset = out_sizes.size() - in_sizes.size();
+    for (int i = 0; i < out_sizes.size(); i++) {
+      int64_t permute_dim = permute_dims.at(i) - offset;
+      if (permute_dim >= 0) {
+        new_out_sizes.at(i) = in_sizes.at(permute_dim);
+      }
+    }
+    graph->virtual_resize(out, new_out_sizes);
+  } else {
+    VK_THROW("Invalid permute dims");
+  }
+}
+
 void add_permute_node(
     ComputeGraph& graph,
-    ValueRef in,
-    const std::vector<int64_t>& permute_dims,
-    ValueRef out) {
-  vTensorPtr t_in = graph.get_tensor(in);
-  vTensorPtr t_out = graph.get_tensor(out);
-
-  check_args(*t_in, permute_dims, *t_out);
+    const ValueRef in,
+    const ValueRef permute_dims,
+    const ValueRef out) {
+  check_args(graph, in, permute_dims, out);
 
   ivec4 out_dims{0, 1, 2, 3};
 
   // Special cases of squeeze/unsqueeze. Because the input dim size can be
-  // different with output dim size. So pick t_in->dim() if squeeze, and
-  // t_out->dim() if unsqueeze to create parameter for permute.
-  int64_t out_ndim = std::max(t_in->dim(), t_out->dim());
+  // different with output dim size. So pick graph.dim_of(in) if squeeze, and
+  // graph.dim_of(out) if unsqueeze to create parameter for permute.
+  const int64_t out_ndim = std::max(graph.dim_of(in), graph.dim_of(out));
   std::vector<bool> seen(out_ndim);
-  for (int i = 0; i < out_ndim; i++) {
-    int64_t permute_dim = permute_dims[i];
-    VK_CHECK_COND(
-        !seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
-    seen[permute_dim] = true;
-
-    out_dims[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim);
+  {
+    IntListPtr permute_dims_ptr = graph.get_int_list(permute_dims);
+    for (int i = 0; i < out_ndim; i++) {
+      int64_t permute_dim = permute_dims_ptr->at(i);
+      VK_CHECK_COND(
+          !seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
+      seen[permute_dim] = true;
+
+      out_dims[(4u - out_ndim) + i] =
+          utils::safe_downcast<int32_t>(permute_dim + (4 - out_ndim));
+    }
   }
 
   std::string kernel_name = "permute";
   kernel_name.reserve(kShaderNameReserve);
-  add_dtype_suffix(kernel_name, *t_out);
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
 
-  int32_t out_channels = dim_at<kChannel4D>(t_out->sizes());
-  int32_t in_channels = dim_at<kChannel4D>(t_in->sizes());
+  const int32_t out_channels = dim_at<kChannel4D>(graph.sizes_of(out));
+  const int32_t in_channels = dim_at<kChannel4D>(graph.sizes_of(in));
 
-  const auto packed_dim = graph.packed_dim_of(in);
+  const int32_t packed_dim = graph.packed_dim_of(in);
   ivec2 channel_info = {out_channels, in_channels};
   if (packed_dim == WHCN::kChannelsDim) {
     channel_info[0] = utils::align_up_4(channel_info[0]);
@@ -95,19 +151,9 @@ void add_permute_node(
       // Specialization Constants
       spec_vars,
       // Resize Args
-      {},
+      {permute_dims},
       // Resizing Logic
-      nullptr));
-}
-
-void add_permute_node(
-    ComputeGraph& graph,
-    ValueRef in,
-    ValueRef permute_dims_ref,
-    ValueRef out) {
-  IntListPtr permute_dims = graph.get_int_list(permute_dims_ref);
-
-  add_permute_node(graph, in, *permute_dims, out);
+      resize_permute_node));
 }
 
 void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {
diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.h b/backends/vulkan/runtime/graph/ops/impl/Permute.h
index 941a8896fe2..0f17a4a26b0 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Permute.h
+++ b/backends/vulkan/runtime/graph/ops/impl/Permute.h
@@ -18,8 +18,8 @@ namespace vkcompute {
 
 void add_permute_node(
     ComputeGraph& graph,
-    ValueRef in,
-    const std::vector<int64_t>& permute_dims,
-    ValueRef out);
+    const ValueRef in,
+    const ValueRef permute_dims,
+    const ValueRef out);
 
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp
index ee40a043ee5..31bab144d8a 100644
--- a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp
@@ -15,14 +15,20 @@ namespace vkcompute {
 void resize_rotary_embedding_node(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
-    const std::vector<ValueRef>& extra_args) {
-  (void)extra_args;
-  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
-  vTensorPtr in = graph->get_tensor(args[1].refs[0]);
-
-  std::vector<int64_t> in_sizes = in->sizes();
-  // UNCOMMENT BELOW IF NEEDED
-  // out->virtual_resize(in_sizes);
+    const std::vector<ValueRef>& resize_args) {
+  (void)resize_args;
+
+  const ValueRef xq_out = args.at(0).refs.at(0);
+  const ValueRef xk_out = args.at(0).refs.at(1);
+
+  const ValueRef xq = args.at(1).refs.at(0);
+  const ValueRef xk = args.at(1).refs.at(1);
+
+  const std::vector<int64_t> xq_sizes = graph->sizes_of(xq);
+  const std::vector<int64_t> xk_sizes = graph->sizes_of(xk);
+
+  graph->virtual_resize(xq_out, xq_sizes);
+  graph->virtual_resize(xk_out, xk_sizes);
 }
 
 void add_rotary_embedding_node(
diff --git a/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp
index b212d24f06b..249f5e7fa6b 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp
@@ -17,28 +17,29 @@ namespace vkcompute {
 void add_squeeze_copy_dims_node(
     ComputeGraph& graph,
-    ValueRef in,
-    ValueRef dims_ref,
-    ValueRef out) {
-  vTensorPtr t_in = graph.get_tensor(in);
-  vTensorPtr t_out = graph.get_tensor(out);
+    const ValueRef in,
+    const ValueRef dims_ref,
+    const ValueRef out) {
+  const int64_t in_dim = graph.dim_of(in);
+  const std::vector<int64_t> in_sizes = graph.sizes_of(in);
+  const std::vector<int64_t> out_sizes = graph.sizes_of(in);
 
-  IntListPtr dims = graph.get_int_list(dims_ref);
+  const std::vector<int64_t> dims = graph.extract_int_or_symint_list(dims_ref);
   std::vector<int64_t> squeeze_dims;
   // Filter out edge cases that we don't need squeeze:
   // 1. The size of squeeze dim is larger than 1.
   // 2. Squeeze outter most dim
   // For these cases, just pass input to output via clone.
-  for (int i = 0; i < dims->size(); ++i) {
-    if (dims->at(i) != 0 && t_in->sizes().at(dims->at(i)) == 1) {
-      squeeze_dims.push_back(dims->at(i));
+  for (int i = 0; i < dims.size(); ++i) {
+    if (dims.at(i) != 0 && in_sizes.at(dims.at(i)) == 1) {
+      squeeze_dims.push_back(dims.at(i));
     }
   }
 
   if (squeeze_dims.size() == 0) {
     add_clone_node(graph, in, out);
   } else {
-    std::vector<int64_t> permute_dims(t_in->dim());
-    for (int i = 0; i < t_in->dim(); ++i) {
+    std::vector<int64_t> permute_dims(in_dim);
+    for (int i = 0; i < in_dim; ++i) {
       permute_dims.at(i) = i;
     }
     for (auto& elem : squeeze_dims) {
@@ -48,7 +49,9 @@ void add_squeeze_copy_dims_node(
       std::rotate(permute_dims.begin(), it, it + 1);
     }
 
-    add_permute_node(graph, in, permute_dims, out);
+    const ValueRef permute_dims_ref =
+        graph.add_scalar_list(std::vector<int64_t>(permute_dims));
+    add_permute_node(graph, in, permute_dims_ref, out);
   }
 }
 
diff --git a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp
index c8ada796e8e..306a79fb8b8 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp
@@ -16,17 +16,16 @@ namespace vkcompute {
 
 void add_unsqueeze_node(
     ComputeGraph& graph,
-    ValueRef in,
-    ValueRef dim_ref,
-    ValueRef out) {
-  vTensorPtr t_in = graph.get_tensor(in);
-  vTensorPtr t_out = graph.get_tensor(out);
+    const ValueRef in,
+    const ValueRef dim_ref,
+    const ValueRef out) {
+  const int64_t in_dim = graph.dim_of(in);
+  const int64_t out_dim = graph.dim_of(out);
 
   VK_CHECK_COND(
-      t_in->dim() < 4, "Cannot unsqueeze a tensor with more than 3 dimensions");
+      in_dim < 4, "Cannot unsqueeze a tensor with more than 3 dimensions");
 
   int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
-  int64_t out_dim = t_out->dim();
 
   std::vector<int64_t> permute_dims(out_dim);
   for (int i = 1; i <= dim; i++) {
@@ -38,7 +37,9 @@ void add_unsqueeze_node(
     permute_dims[i] = i;
   }
 
-  add_permute_node(graph, in, permute_dims, out);
+  const ValueRef permute_dims_ref =
+      graph.add_scalar_list(std::vector<int64_t>(permute_dims));
+  add_permute_node(graph, in, permute_dims_ref, out);
 }
 
 void unsqueeze(ComputeGraph& graph, const std::vector<ValueRef>& args) {
diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp
index 9d010c794ec..2bcf2a3842f 100644
--- a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp
@@ -57,6 +57,13 @@ bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2) {
   return t1.packed_dim() == t2.packed_dim();
 }
 
+bool check_same_packed_dim(
+    ComputeGraph& graph,
+    const ValueRef in,
+    const ValueRef out) {
+  return graph.packed_dim_of(in) == graph.packed_dim_of(out);
+}
+
 bool check_same_packed_dim(
     const api::vTensor& t1,
     const api::vTensor& t2,
diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h
index c9eeb0efe08..3b61083069e 100644
--- a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h
+++ b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h
@@ -9,6 +9,7 @@
 #pragma once
 
 #include <executorch/backends/vulkan/runtime/api/api.h>
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
 
 namespace vkcompute {
 
@@ -38,6 +39,11 @@ bool check_packed_dim_is(const api::vTensor& t, const int32_t packed_dim);
 
 bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2);
 
+bool check_same_packed_dim(
+    ComputeGraph& graph,
+    const ValueRef in,
+    const ValueRef out);
+
 bool check_same_packed_dim(
     const api::vTensor& t1,
     const api::vTensor& t2,
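Illustrative note (not part of the patch): the sketch below mirrors, in plain standalone C++, the shape arithmetic that the new resize_permute_node uses to recompute output sizes from the input sizes and the permute dims, including the rank-offset handling for the cases where a permute node stands in for squeeze or unsqueeze. The helper name permuted_out_sizes and the example shapes are hypothetical; in the backend the equivalent result is applied through graph->virtual_resize(out, new_out_sizes).

// Standalone sketch (assumption: not part of this patch) of the shape
// arithmetic in resize_permute_node. Names here are hypothetical; only the
// offset logic follows the diff above.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> permuted_out_sizes(
    const std::vector<int64_t>& in_sizes,
    const std::vector<int64_t>& permute_dims,
    const size_t out_rank) {
  std::vector<int64_t> out_sizes(out_rank, 1);
  if (in_sizes.size() == out_rank) {
    // Plain permute: out[i] = in[permute_dims[i]].
    for (size_t i = 0; i < out_rank; i++) {
      out_sizes[i] = in_sizes.at(permute_dims.at(i));
    }
  } else if (in_sizes.size() > out_rank) {
    // Permute implementing squeeze: permute_dims has input rank; skip the
    // leading entries, which correspond to the squeezed-away dims.
    const size_t offset = in_sizes.size() - out_rank;
    for (size_t i = 0; i < out_rank; i++) {
      out_sizes[i] = in_sizes.at(permute_dims.at(i + offset));
    }
  } else {
    // Permute implementing unsqueeze: permute_dims has output rank; entries
    // that land below the offset are the inserted size-1 dims.
    const size_t offset = out_rank - in_sizes.size();
    for (size_t i = 0; i < out_rank; i++) {
      const int64_t d = permute_dims.at(i) - static_cast<int64_t>(offset);
      if (d >= 0) {
        out_sizes[i] = in_sizes.at(d);
      }
    }
  }
  return out_sizes;
}

int main() {
  // Plain permute: {2, 3, 4, 5} with dims {0, 2, 1, 3} -> {2, 4, 3, 5}.
  assert(permuted_out_sizes({2, 3, 4, 5}, {0, 2, 1, 3}, 4) ==
         (std::vector<int64_t>{2, 4, 3, 5}));
  // Squeeze of dim 1 expressed as permute {1, 0, 2}: {3, 1, 4} -> {3, 4}.
  assert(permuted_out_sizes({3, 1, 4}, {1, 0, 2}, 2) ==
         (std::vector<int64_t>{3, 4}));
  // Unsqueeze at dim 1 expressed as permute {1, 0, 2}: {3, 4} -> {3, 1, 4}.
  assert(permuted_out_sizes({3, 4}, {1, 0, 2}, 3) ==
         (std::vector<int64_t>{3, 1, 4}));
  return 0;
}

The squeeze and unsqueeze assertions use the same permutation vectors that add_squeeze_copy_dims_node and add_unsqueeze_node construct before delegating to add_permute_node, which is why a single resize function can serve all three ops once they are registered together under register_view_ops.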