Skip to content
Merged
25 changes: 17 additions & 8 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,12 @@ def register_sdpa_with_kv_cache_op(features: OpFeatures):
return features


@update_features(["llama::update_cache", "llama::custom_sdpa"])
@update_features(
[
"llama::update_cache",
"llama::custom_sdpa",
]
)
def register_sdpa_ops(features: OpFeatures):
features.resize_fn = False
features.buffer_impl = False
Expand All @@ -520,8 +525,17 @@ def register_rotary_emb_op(features: OpFeatures):
return features


@update_features(exir_ops.edge.aten.view_copy.default)
def register_view_op(features: OpFeatures):
@update_features(
[
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.permute.default,
exir_ops.edge.aten.permute_copy.default,
exir_ops.edge.aten.select_copy.int,
exir_ops.edge.aten.slice_copy.Tensor,
exir_ops.edge.aten.view_copy.default,
]
)
def register_view_ops(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
valid_packed_dims=all_packed_dims,
)
Expand All @@ -538,10 +552,8 @@ def register_view_op(features: OpFeatures):
# Indexing and lookup
exir_ops.edge.aten.flip.default,
exir_ops.edge.aten.index_select.default,
exir_ops.edge.aten.select_copy.int,
# Tensor creation
exir_ops.edge.aten.arange.start_step,
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.constant_pad_nd.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.full_like.default,
Expand All @@ -564,12 +576,9 @@ def register_ported_op(features: OpFeatures):
# Ops ported from PyTorch Vulkan backend. These ops are in a separate registry becasue they support all packed dimensions
@update_features(
[
# Indexing and lookup
exir_ops.edge.aten.slice_copy.Tensor,
# Shape Manipulation
exir_ops.edge.aten.squeeze_copy.dims,
exir_ops.edge.aten.unsqueeze_copy.default,
exir_ops.edge.aten.permute_copy.default,
# Tensor combination
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.repeat.default,
Expand Down
118 changes: 81 additions & 37 deletions backends/vulkan/runtime/graph/ops/impl/Permute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ using utils::uvec4;
namespace {

void check_args(
const api::vTensor& in,
const std::vector<int64_t>& permute_dims,
const api::vTensor& out) {
VK_CHECK_COND(check_same_packed_dim(in, out));
ComputeGraph& graph,
const ValueRef in,
const ValueRef permute_dims,
const ValueRef out) {
VK_CHECK_COND(check_same_packed_dim(graph, in, out));

// This implementation doesn't not requires the input tensor to have the same
// dim size as the argument. The code will work as long as the input tensor's
Expand All @@ -38,40 +39,93 @@ void check_args(

} // namespace

void resize_permute_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& resize_args) {
const ValueRef out = args[0].refs[0];
const ValueRef in = args[1].refs[0];

const std::vector<int64_t> in_sizes = graph->sizes_of(in);
const std::vector<int64_t> out_sizes = graph->sizes_of(out);

const std::vector<int64_t> permute_dims =
graph->extract_int_or_symint_list(resize_args[0]);

if (in_sizes.size() == out_sizes.size() &&
in_sizes.size() == permute_dims.size()) {
std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
const int64_t out_ndim = std::max(in_sizes.size(), out_sizes.size());
for (int i = 0; i < out_ndim; i++) {
const int64_t permute_dim = permute_dims.at(i);
new_out_sizes.at(i) = in_sizes.at(permute_dim);
}
graph->virtual_resize(out, new_out_sizes);
}
// Case where permute is being used to implement squeeze
else if (
in_sizes.size() > out_sizes.size() &&
in_sizes.size() == permute_dims.size()) {
std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
const int offset = in_sizes.size() - out_sizes.size();
for (int i = 0; i < out_sizes.size(); i++) {
const int64_t permute_dim = permute_dims.at(i + offset);
new_out_sizes.at(i) = in_sizes.at(permute_dim);
}
graph->virtual_resize(out, new_out_sizes);
}
// Case where Permute is being used to implement unsqueeze
else if (
in_sizes.size() < out_sizes.size() &&
out_sizes.size() == permute_dims.size()) {
std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
const int offset = out_sizes.size() - in_sizes.size();
for (int i = 0; i < out_sizes.size(); i++) {
int64_t permute_dim = permute_dims.at(i) - offset;
if (permute_dim >= 0) {
new_out_sizes.at(i) = in_sizes.at(permute_dim);
}
}
graph->virtual_resize(out, new_out_sizes);
} else {
VK_THROW("Invalid permute dims");
}
}

void add_permute_node(
ComputeGraph& graph,
ValueRef in,
const std::vector<int64_t>& permute_dims,
ValueRef out) {
vTensorPtr t_in = graph.get_tensor(in);
vTensorPtr t_out = graph.get_tensor(out);

check_args(*t_in, permute_dims, *t_out);
const ValueRef in,
const ValueRef permute_dims,
const ValueRef out) {
check_args(graph, in, permute_dims, out);

ivec4 out_dims{0, 1, 2, 3};

// Special cases of squeeze/unsqueeze. Because the input dim size can be
// different with output dim size. So pick t_in->dim() if squeeze, and
// t_out->dim() if unsqueeze to create parameter for permute.
int64_t out_ndim = std::max(t_in->dim(), t_out->dim());
// different with output dim size. So pick graph.dim_of(in) if squeeze, and
// graph.dim_of(out) if unsqueeze to create parameter for permute.
const int64_t out_ndim = std::max(graph.dim_of(in), graph.dim_of(out));
std::vector<bool> seen(out_ndim);
for (int i = 0; i < out_ndim; i++) {
int64_t permute_dim = permute_dims[i];
VK_CHECK_COND(
!seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
seen[permute_dim] = true;

out_dims[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim);
{
IntListPtr permute_dims_ptr = graph.get_int_list(permute_dims);
for (int i = 0; i < out_ndim; i++) {
int64_t permute_dim = permute_dims_ptr->at(i);
VK_CHECK_COND(
!seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
seen[permute_dim] = true;

out_dims[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim);
}
}

std::string kernel_name = "permute";
kernel_name.reserve(kShaderNameReserve);
add_dtype_suffix(kernel_name, *t_out);
add_dtype_suffix(kernel_name, graph.dtype_of(out));

int32_t out_channels = dim_at<kChannel4D>(t_out->sizes());
int32_t in_channels = dim_at<kChannel4D>(t_in->sizes());
const int32_t out_channels = dim_at<kChannel4D>(graph.sizes_of(out));
const int32_t in_channels = dim_at<kChannel4D>(graph.sizes_of(in));

const auto packed_dim = graph.packed_dim_of(in);
const int32_t packed_dim = graph.packed_dim_of(in);
ivec2 channel_info = {out_channels, in_channels};
if (packed_dim == WHCN::kChannelsDim) {
channel_info[0] = utils::align_up_4(channel_info[0]);
Expand All @@ -95,19 +149,9 @@ void add_permute_node(
// Specialization Constants
spec_vars,
// Resize Args
{},
{permute_dims},
// Resizing Logic
nullptr));
}

void add_permute_node(
ComputeGraph& graph,
ValueRef in,
ValueRef permute_dims_ref,
ValueRef out) {
IntListPtr permute_dims = graph.get_int_list(permute_dims_ref);

add_permute_node(graph, in, *permute_dims, out);
resize_permute_node));
}

void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {
Expand Down
6 changes: 3 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/Permute.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ namespace vkcompute {

void add_permute_node(
ComputeGraph& graph,
ValueRef in,
const std::vector<int64_t>& permute_dims,
ValueRef out);
const ValueRef in,
const ValueRef permute_dims,
const ValueRef out);

} // namespace vkcompute
22 changes: 14 additions & 8 deletions backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,20 @@ namespace vkcompute {
void resize_rotary_embedding_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& extra_args) {
(void)extra_args;
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
vTensorPtr in = graph->get_tensor(args[1].refs[0]);

std::vector<int64_t> in_sizes = in->sizes();
// UNCOMMENT BELOW IF NEEDED
// out->virtual_resize(in_sizes);
const std::vector<ValueRef>& resize_args) {
(void)resize_args;

const ValueRef xq_out = args.at(0).refs.at(0);
const ValueRef xk_out = args.at(0).refs.at(1);

const ValueRef xq = args.at(1).refs.at(0);
const ValueRef xk = args.at(1).refs.at(1);

const std::vector<int64_t> xq_sizes = graph->sizes_of(xq);
const std::vector<int64_t> xk_sizes = graph->sizes_of(xk);

graph->virtual_resize(xq_out, xq_sizes);
graph->virtual_resize(xk_out, xk_sizes);
}

void add_rotary_embedding_node(
Expand Down
27 changes: 15 additions & 12 deletions backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,29 @@ namespace vkcompute {

void add_squeeze_copy_dims_node(
ComputeGraph& graph,
ValueRef in,
ValueRef dims_ref,
ValueRef out) {
vTensorPtr t_in = graph.get_tensor(in);
vTensorPtr t_out = graph.get_tensor(out);
const ValueRef in,
const ValueRef dims_ref,
const ValueRef out) {
const int64_t in_dim = graph.dim_of(in);
const std::vector<int64_t> in_sizes = graph.sizes_of(in);
const std::vector<int64_t> out_sizes = graph.sizes_of(in);

IntListPtr dims = graph.get_int_list(dims_ref);
const std::vector<int64_t> dims = graph.extract_int_or_symint_list(dims_ref);
std::vector<int64_t> squeeze_dims;
// Filter out edge cases that we don't need squeeze:
// 1. The size of squeeze dim is larger than 1.
// 2. Squeeze outter most dim
// For these cases, just pass input to output via clone.
for (int i = 0; i < dims->size(); ++i) {
if (dims->at(i) != 0 && t_in->sizes().at(dims->at(i)) == 1) {
squeeze_dims.push_back(dims->at(i));
for (int i = 0; i < dims.size(); ++i) {
if (dims.at(i) != 0 && in_sizes.at(dims.at(i)) == 1) {
squeeze_dims.push_back(dims.at(i));
}
}
if (squeeze_dims.size() == 0) {
add_clone_node(graph, in, out);
} else {
std::vector<int64_t> permute_dims(t_in->dim());
for (int i = 0; i < t_in->dim(); ++i) {
std::vector<int64_t> permute_dims(in_dim);
for (int i = 0; i < in_dim; ++i) {
permute_dims.at(i) = i;
}
for (auto& elem : squeeze_dims) {
Expand All @@ -48,7 +49,9 @@ void add_squeeze_copy_dims_node(
std::rotate(permute_dims.begin(), it, it + 1);
}

add_permute_node(graph, in, permute_dims, out);
const ValueRef permute_dims_ref =
graph.add_scalar_list<int64_t>(std::vector<int64_t>(permute_dims));
add_permute_node(graph, in, permute_dims_ref, out);
}
}

Expand Down
17 changes: 9 additions & 8 deletions backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,16 @@ namespace vkcompute {

void add_unsqueeze_node(
ComputeGraph& graph,
ValueRef in,
ValueRef dim_ref,
ValueRef out) {
vTensorPtr t_in = graph.get_tensor(in);
vTensorPtr t_out = graph.get_tensor(out);
const ValueRef in,
const ValueRef dim_ref,
const ValueRef out) {
const int64_t in_dim = graph.dim_of(in);
const int64_t out_dim = graph.dim_of(out);

VK_CHECK_COND(
t_in->dim() < 4, "Cannot unsqueeze a tensor with more than 3 dimensions");
in_dim < 4, "Cannot unsqueeze a tensor with more than 3 dimensions");

int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
int64_t out_dim = t_out->dim();

std::vector<int64_t> permute_dims(out_dim);
for (int i = 1; i <= dim; i++) {
Expand All @@ -38,7 +37,9 @@ void add_unsqueeze_node(
permute_dims[i] = i;
}

add_permute_node(graph, in, permute_dims, out);
const ValueRef permute_dims_ref =
graph.add_scalar_list<int64_t>(std::vector<int64_t>(permute_dims));
add_permute_node(graph, in, permute_dims_ref, out);
}

void unsqueeze(ComputeGraph& graph, const std::vector<ValueRef>& args) {
Expand Down
7 changes: 7 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2) {
return t1.packed_dim() == t2.packed_dim();
}

bool check_same_packed_dim(
ComputeGraph& graph,
const ValueRef in,
const ValueRef out) {
return graph.packed_dim_of(in) == graph.packed_dim_of(out);
}

bool check_same_packed_dim(
const api::vTensor& t1,
const api::vTensor& t2,
Expand Down
6 changes: 6 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#pragma once

#include <executorch/backends/vulkan/runtime/api/api.h>
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

namespace vkcompute {

Expand Down Expand Up @@ -38,6 +39,11 @@ bool check_packed_dim_is(const api::vTensor& t, const int32_t packed_dim);

bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2);

bool check_same_packed_dim(
ComputeGraph& graph,
const ValueRef in,
const ValueRef out);

bool check_same_packed_dim(
const api::vTensor& t1,
const api::vTensor& t2,
Expand Down
Loading