5 changes: 4 additions & 1 deletion backends/vulkan/_passes/fuse_quantized_ops.py
@@ -17,6 +17,7 @@
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass

#################
## linear_qcnw ##
@@ -224,6 +225,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
)

graph_module.recompile()
graph_module = super().call(graph_module).graph_module
dead_code_elimination_pass(graph_module)

# Re-trace the graph since new nodes were (potentially) inserted
graph_module = super().call(graph_module).graph_module
return PassResult(graph_module, True)
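
The fusion pass above now re-traces and then runs dead-code elimination, so the quantize/dequantize nodes orphaned by fusion are swept out of the graph. A minimal torch.fx sketch of the same cleanup pattern (illustrative, not ExecuTorch's exact pass):

import torch.fx

def fuse_then_cleanup(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # ... graph rewrites happen here, e.g. replacing a
    # dequantize -> linear -> quantize chain with one fused node ...
    gm.graph.eliminate_dead_code()  # drop nodes left without users
    gm.recompile()                  # regenerate forward() from the graph
    return gm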
4 changes: 2 additions & 2 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import logging
from copy import deepcopy
from typing import Any, Optional, Set

import executorch.backends.vulkan.utils as utils
@@ -22,6 +21,7 @@
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.tensor import TensorSpec

logger: logging.Logger = logging.getLogger("")
logger.setLevel(logging.INFO)
@@ -52,7 +52,7 @@ def insert_transition_node(
(arg,),
)
clone_node.meta["val"] = arg.meta["val"]
clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
clone_node.meta["spec"] = TensorSpec.from_tensor(clone_node.meta["val"])
clone_node.meta["spec"].const = False
set_memory_metadata(clone_node, storage, layout)
arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)
33 changes: 25 additions & 8 deletions backends/vulkan/op_registry.py
@@ -230,6 +230,14 @@ def update_features_impl(op: OpKey):
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
# Symbolic integer ops
torch.ops.aten.sym_size.int,
operator.add,
operator.lt,
operator.gt,
operator.ge,
operator.le,
# Guard and assert ops
torch.ops.aten._assert_scalar.default,
torch.ops.aten.sym_constrain_range_for_size.default,
]
)
def register_ephemeral_op(features: OpFeatures):
@@ -500,7 +508,12 @@ def register_sdpa_with_kv_cache_op(features: OpFeatures):
return features


@update_features(["llama::update_cache", "llama::custom_sdpa"])
@update_features(
[
"llama::update_cache",
"llama::custom_sdpa",
]
)
def register_sdpa_ops(features: OpFeatures):
features.resize_fn = False
features.buffer_impl = False
@@ -520,8 +533,17 @@ def register_rotary_emb_op(features: OpFeatures):
return features


@update_features(exir_ops.edge.aten.view_copy.default)
def register_view_op(features: OpFeatures):
@update_features(
[
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.permute.default,
exir_ops.edge.aten.permute_copy.default,
exir_ops.edge.aten.select_copy.int,
exir_ops.edge.aten.slice_copy.Tensor,
exir_ops.edge.aten.view_copy.default,
]
)
def register_view_ops(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
valid_packed_dims=all_packed_dims,
)
@@ -538,10 +560,8 @@ def register_view_op(features: OpFeatures):
# Indexing and lookup
exir_ops.edge.aten.flip.default,
exir_ops.edge.aten.index_select.default,
exir_ops.edge.aten.select_copy.int,
# Tensor creation
exir_ops.edge.aten.arange.start_step,
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.constant_pad_nd.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.full_like.default,
@@ -564,12 +584,9 @@ def register_ported_op(features: OpFeatures):
# Ops ported from PyTorch Vulkan backend. These ops are in a separate registry because they support all packed dimensions
@update_features(
[
# Indexing and lookup
exir_ops.edge.aten.slice_copy.Tensor,
# Shape Manipulation
exir_ops.edge.aten.squeeze_copy.dims,
exir_ops.edge.aten.unsqueeze_copy.default,
exir_ops.edge.aten.permute_copy.default,
# Tensor combination
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.repeat.default,
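
For context, the @update_features decorator used throughout this file can be sketched as a small registry where each handler fills in an OpFeatures record for every op key it is registered against (a simplified sketch with illustrative names, not the actual op_registry internals):

from typing import Callable, Dict, List, Union

class OpFeatures:
    def __init__(self) -> None:
        self.resize_fn = True
        self.buffer_impl = True

OpKey = Union[str, object]  # edge ops, torch ops, or custom-op names
vulkan_supported_ops: Dict[OpKey, "OpFeatures"] = {}

def update_features(ops: Union[OpKey, List[OpKey]]):
    def decorator(fn: Callable[[OpFeatures], OpFeatures]):
        for op in ops if isinstance(ops, list) else [ops]:
            vulkan_supported_ops[op] = fn(OpFeatures())
        return fn
    return decorator

@update_features(["llama::update_cache", "llama::custom_sdpa"])
def register_sdpa_ops(features: OpFeatures) -> OpFeatures:
    features.resize_fn = False
    return features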
7 changes: 4 additions & 3 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -146,10 +146,11 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
def node_is_compatible(
self, node: torch.fx.Node, features: Optional[OpFeatures] = None
) -> Tuple[bool, str]:
if utils.is_symint_node(node):
return node.target in vulkan_supported_ops, "Op is compatible"
elif utils.is_tensor_node(node):
if utils.is_tensor_node(node):
return self.op_node_is_compatible(node, features=features)
# For non-tensor nodes, just check if the op is registered
elif hasattr(node, "target"):
return node.target in vulkan_supported_ops, "Op is compatible"

return False, f"Unsupported node type: {node.format_node()}"

120 changes: 83 additions & 37 deletions backends/vulkan/runtime/graph/ops/impl/Permute.cpp
@@ -25,10 +25,12 @@ using utils::uvec4;
namespace {

void check_args(
const api::vTensor& in,
const std::vector<int64_t>& permute_dims,
const api::vTensor& out) {
VK_CHECK_COND(check_same_packed_dim(in, out));
ComputeGraph& graph,
const ValueRef in,
const ValueRef permute_dims,
const ValueRef out) {
(void)permute_dims;
VK_CHECK_COND(check_same_packed_dim(graph, in, out));

// This implementation doesn't require the input tensor to have the same
// dim size as the argument. The code will work as long as the input tensor's
@@ -38,40 +40,94 @@ void check_args(

} // namespace

void resize_permute_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& resize_args) {
const ValueRef out = args[0].refs[0];
const ValueRef in = args[1].refs[0];

const std::vector<int64_t> in_sizes = graph->sizes_of(in);
const std::vector<int64_t> out_sizes = graph->sizes_of(out);

const std::vector<int64_t> permute_dims =
graph->extract_int_or_symint_list(resize_args[0]);

if (in_sizes.size() == out_sizes.size() &&
in_sizes.size() == permute_dims.size()) {
std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
const int64_t out_ndim = std::max(in_sizes.size(), out_sizes.size());
for (int i = 0; i < out_ndim; i++) {
const int64_t permute_dim = permute_dims.at(i);
new_out_sizes.at(i) = in_sizes.at(permute_dim);
}
graph->virtual_resize(out, new_out_sizes);
}
// Case where permute is being used to implement squeeze
else if (
in_sizes.size() > out_sizes.size() &&
in_sizes.size() == permute_dims.size()) {
std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
const size_t offset = in_sizes.size() - out_sizes.size();
for (int i = 0; i < out_sizes.size(); i++) {
const int64_t permute_dim = permute_dims.at(i + offset);
new_out_sizes.at(i) = in_sizes.at(permute_dim);
}
graph->virtual_resize(out, new_out_sizes);
}
// Case where Permute is being used to implement unsqueeze
else if (
in_sizes.size() < out_sizes.size() &&
out_sizes.size() == permute_dims.size()) {
std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
const size_t offset = out_sizes.size() - in_sizes.size();
for (int i = 0; i < out_sizes.size(); i++) {
int64_t permute_dim = permute_dims.at(i) - offset;
if (permute_dim >= 0) {
new_out_sizes.at(i) = in_sizes.at(permute_dim);
}
}
graph->virtual_resize(out, new_out_sizes);
} else {
VK_THROW("Invalid permute dims");
}
}

void add_permute_node(
ComputeGraph& graph,
ValueRef in,
const std::vector<int64_t>& permute_dims,
ValueRef out) {
vTensorPtr t_in = graph.get_tensor(in);
vTensorPtr t_out = graph.get_tensor(out);

check_args(*t_in, permute_dims, *t_out);
const ValueRef in,
const ValueRef permute_dims,
const ValueRef out) {
check_args(graph, in, permute_dims, out);

ivec4 out_dims{0, 1, 2, 3};

// Special cases of squeeze/unsqueeze. Because the input dim size can
// differ from the output dim size, pick t_in->dim() if squeeze, and
// t_out->dim() if unsqueeze to create the parameter for permute.
int64_t out_ndim = std::max(t_in->dim(), t_out->dim());
// differ from the output dim size, pick graph.dim_of(in) if squeeze, and
// graph.dim_of(out) if unsqueeze to create the parameter for permute.
const int64_t out_ndim = std::max(graph.dim_of(in), graph.dim_of(out));
std::vector<bool> seen(out_ndim);
for (int i = 0; i < out_ndim; i++) {
int64_t permute_dim = permute_dims[i];
VK_CHECK_COND(
!seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
seen[permute_dim] = true;

out_dims[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim);
{
IntListPtr permute_dims_ptr = graph.get_int_list(permute_dims);
for (int i = 0; i < out_ndim; i++) {
int64_t permute_dim = permute_dims_ptr->at(i);
VK_CHECK_COND(
!seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
seen[permute_dim] = true;

out_dims[(4u - out_ndim) + i] =
utils::safe_downcast<int32_t>(permute_dim + (4 - out_ndim));
}
}

std::string kernel_name = "permute";
kernel_name.reserve(kShaderNameReserve);
add_dtype_suffix(kernel_name, *t_out);
add_dtype_suffix(kernel_name, graph.dtype_of(out));

int32_t out_channels = dim_at<kChannel4D>(t_out->sizes());
int32_t in_channels = dim_at<kChannel4D>(t_in->sizes());
const int32_t out_channels = dim_at<kChannel4D>(graph.sizes_of(out));
const int32_t in_channels = dim_at<kChannel4D>(graph.sizes_of(in));

const auto packed_dim = graph.packed_dim_of(in);
const int32_t packed_dim = graph.packed_dim_of(in);
ivec2 channel_info = {out_channels, in_channels};
if (packed_dim == WHCN::kChannelsDim) {
channel_info[0] = utils::align_up_4(channel_info[0]);
@@ -95,19 +151,9 @@ void add_permute_node(
// Specialization Constants
spec_vars,
// Resize Args
{},
{permute_dims},
// Resizing Logic
nullptr));
}

void add_permute_node(
ComputeGraph& graph,
ValueRef in,
ValueRef permute_dims_ref,
ValueRef out) {
IntListPtr permute_dims = graph.get_int_list(permute_dims_ref);

add_permute_node(graph, in, *permute_dims, out);
resize_permute_node));
}

void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {
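
resize_permute_node handles three shapes of the same op: a plain permute, a permute standing in for a squeeze (input rank greater than output rank, leading dims dropped), and a permute standing in for an unsqueeze (input rank less than output rank, missing dims filled with 1). A worked Python sketch of the size computation, mirroring the C++ above (illustrative only):

def resize_permute(in_sizes, out_rank, permute_dims):
    in_rank = len(in_sizes)
    if in_rank == out_rank == len(permute_dims):
        # Plain permute: out[i] = in[permute_dims[i]]
        return [in_sizes[d] for d in permute_dims]
    if in_rank > out_rank and in_rank == len(permute_dims):
        # Permute implementing squeeze: skip the leading dims
        offset = in_rank - out_rank
        return [in_sizes[permute_dims[i + offset]] for i in range(out_rank)]
    if in_rank < out_rank and out_rank == len(permute_dims):
        # Permute implementing unsqueeze: shifted dims, pad with 1s
        offset = out_rank - in_rank
        return [
            in_sizes[permute_dims[i] - offset] if permute_dims[i] >= offset else 1
            for i in range(out_rank)
        ]
    raise ValueError("Invalid permute dims")

assert resize_permute([2, 3, 4], 3, [2, 0, 1]) == [4, 2, 3]  # plain permute
assert resize_permute([1, 3, 5], 2, [0, 1, 2]) == [3, 5]     # squeeze dim 0
assert resize_permute([3, 5], 3, [0, 1, 2]) == [1, 3, 5]     # unsqueeze dim 0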
6 changes: 3 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/Permute.h
@@ -18,8 +18,8 @@ namespace vkcompute {

void add_permute_node(
ComputeGraph& graph,
ValueRef in,
const std::vector<int64_t>& permute_dims,
ValueRef out);
const ValueRef in,
const ValueRef permute_dims,
const ValueRef out);

} // namespace vkcompute
22 changes: 14 additions & 8 deletions backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp
@@ -15,14 +15,20 @@ namespace vkcompute {
void resize_rotary_embedding_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& extra_args) {
(void)extra_args;
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
vTensorPtr in = graph->get_tensor(args[1].refs[0]);

std::vector<int64_t> in_sizes = in->sizes();
// UNCOMMENT BELOW IF NEEDED
// out->virtual_resize(in_sizes);
const std::vector<ValueRef>& resize_args) {
(void)resize_args;

const ValueRef xq_out = args.at(0).refs.at(0);
const ValueRef xk_out = args.at(0).refs.at(1);

const ValueRef xq = args.at(1).refs.at(0);
const ValueRef xk = args.at(1).refs.at(1);

const std::vector<int64_t> xq_sizes = graph->sizes_of(xq);
const std::vector<int64_t> xk_sizes = graph->sizes_of(xk);

graph->virtual_resize(xq_out, xq_sizes);
graph->virtual_resize(xk_out, xk_sizes);
}

void add_rotary_embedding_node(
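
The rewritten resize function follows the graph convention that args[0].refs lists the node's outputs and args[1].refs its inputs; each output simply tracks the shape of its corresponding input. A Python sketch of that contract (names are illustrative, not the real API):

def resize_rotary_embedding(sizes_of, virtual_resize, outputs, inputs):
    xq_out, xk_out = outputs              # args[0].refs in the C++ code
    xq, xk = inputs                       # args[1].refs in the C++ code
    virtual_resize(xq_out, sizes_of(xq))  # query output mirrors query input
    virtual_resize(xk_out, sizes_of(xk))  # key output mirrors key input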
27 changes: 15 additions & 12 deletions backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp
@@ -17,28 +17,29 @@ namespace vkcompute {

void add_squeeze_copy_dims_node(
ComputeGraph& graph,
ValueRef in,
ValueRef dims_ref,
ValueRef out) {
vTensorPtr t_in = graph.get_tensor(in);
vTensorPtr t_out = graph.get_tensor(out);
const ValueRef in,
const ValueRef dims_ref,
const ValueRef out) {
const int64_t in_dim = graph.dim_of(in);
const std::vector<int64_t> in_sizes = graph.sizes_of(in);
const std::vector<int64_t> out_sizes = graph.sizes_of(out);

IntListPtr dims = graph.get_int_list(dims_ref);
const std::vector<int64_t> dims = graph.extract_int_or_symint_list(dims_ref);
std::vector<int64_t> squeeze_dims;
// Filter out edge cases where squeeze is not needed:
// 1. The size of the squeeze dim is larger than 1.
// 2. The squeeze dim is the outermost dim.
// For these cases, just pass input to output via clone.
for (int i = 0; i < dims->size(); ++i) {
if (dims->at(i) != 0 && t_in->sizes().at(dims->at(i)) == 1) {
squeeze_dims.push_back(dims->at(i));
for (int i = 0; i < dims.size(); ++i) {
if (dims.at(i) != 0 && in_sizes.at(dims.at(i)) == 1) {
squeeze_dims.push_back(dims.at(i));
}
}
if (squeeze_dims.size() == 0) {
add_clone_node(graph, in, out);
} else {
std::vector<int64_t> permute_dims(t_in->dim());
for (int i = 0; i < t_in->dim(); ++i) {
std::vector<int64_t> permute_dims(in_dim);
for (int i = 0; i < in_dim; ++i) {
permute_dims.at(i) = i;
}
for (auto& elem : squeeze_dims) {
@@ -48,7 +49,9 @@ void add_squeeze_copy_dims_node(
std::rotate(permute_dims.begin(), it, it + 1);
}

add_permute_node(graph, in, permute_dims, out);
const ValueRef permute_dims_ref =
graph.add_scalar_list<int64_t>(std::vector<int64_t>(permute_dims));
add_permute_node(graph, in, permute_dims_ref, out);
}
}

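
Here the squeeze is lowered to a permute that rotates each squeezed (size-1) dim to the front; the permute's resize logic then drops those leading dims. A worked Python example of building that permutation (hypothetical sizes, mirroring the std::rotate loop above):

in_sizes = [3, 1, 5, 1]   # hypothetical input; dims 1 and 3 have size 1
squeeze_dims = [1, 3]     # size-1 dims, excluding the outermost dim 0

perm = list(range(len(in_sizes)))               # identity [0, 1, 2, 3]
for d in squeeze_dims:
    i = perm.index(d)
    perm = [perm[i]] + perm[:i] + perm[i + 1:]  # std::rotate: move d to front

assert perm == [3, 1, 0, 2]
permuted = [in_sizes[p] for p in perm]          # [1, 1, 3, 5]
out_sizes = permuted[len(squeeze_dims):]        # leading 1s dropped
assert out_sizes == [3, 5]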