36 changes: 19 additions & 17 deletions backends/vulkan/_passes/remove_redundant_ops.py
@@ -31,35 +31,37 @@ class RemoveRedundantOpsTransform(ExportPass):
exir_ops.edge.aten.lift_fresh_copy.default,
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
exir_ops.edge.dim_order_ops._clone_dim_order.default,
exir_ops.edge.aten.expand_copy.default,
}

def __init__(self) -> None:
super(RemoveRedundantOpsTransform, self).__init__()

def _should_remove(self, node: torch.fx.Node) -> bool:
if node.target in self.redundant_ops:
return True

# Only remove to_copy if dtype does not change. Otherwise, memory format changes
# will be handled internally by the backend.
if (
node.target == exir_ops.edge.aten._to_copy.default
or node.target == torch.ops.aten._to_copy.default
):
src_dtype = node.meta["val"].dtype
# pyre-ignore
dst_dtype = node.args[0].meta["val"].dtype
return src_dtype == dst_dtype

return False
if node.target not in self.redundant_ops:
return False

orig_node = node.args[0]
assert isinstance(orig_node, torch.fx.Node)

src_dtype = orig_node.meta["val"].dtype
dst_dtype = node.meta["val"].dtype

# Do not remove if the op is converting the dtype.
if src_dtype != dst_dtype:
return False

src_shape = orig_node.meta["val"].shape
dst_shape = node.meta["val"].shape

return src_shape == dst_shape

def _remove(self, graph_module: torch.fx.GraphModule) -> None:
for node in graph_module.graph.nodes:
if not self._should_remove(node):
continue

with graph_module.graph.inserting_after(node):
node.replace_all_uses_with(node.args[0])
node.replace_all_uses_with(node.args[0])

graph_module.graph.eliminate_dead_code()

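With this rework, _should_remove folds a copy-like op only when it changes neither dtype nor shape, which is what makes it safe to add aten.expand_copy to the redundant-op set. Below is a minimal standalone Python sketch of that decision; the FakeVal/FakeNode stand-ins and the abbreviated op set are illustrative only, not real FX nodes or the pass itself.

from dataclasses import dataclass

@dataclass
class FakeVal:
    dtype: str
    shape: tuple

@dataclass
class FakeNode:
    target: str
    meta: dict
    args: tuple = ()

# Abbreviated stand-in for RemoveRedundantOpsTransform.redundant_ops.
REDUNDANT_OPS = {"aten.clone", "aten.expand_copy"}

def should_remove(node: FakeNode) -> bool:
    if node.target not in REDUNDANT_OPS:
        return False
    orig = node.args[0]
    same_dtype = orig.meta["val"].dtype == node.meta["val"].dtype
    same_shape = orig.meta["val"].shape == node.meta["val"].shape
    return same_dtype and same_shape

src = FakeNode("placeholder", {"val": FakeVal("float32", (1, 4, 8))})
noop = FakeNode("aten.expand_copy", {"val": FakeVal("float32", (1, 4, 8))}, (src,))
real = FakeNode("aten.expand_copy", {"val": FakeVal("float32", (2, 4, 8))}, (src,))
print(should_remove(noop))  # True  -> no-op copy, safe to fold away
print(should_remove(real))  # False -> actually broadcasts, so it must stay

An expand_copy that actually broadcasts changes the output shape, so the shape comparison keeps it in the graph; only true no-op copies are removed.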
26 changes: 2 additions & 24 deletions backends/vulkan/op_registry.py
@@ -7,17 +7,12 @@
# pyre-unsafe

import operator

from typing import Any, Callable, Dict, List, Optional, Union

import executorch.backends.vulkan.custom_ops_lib # noqa

import executorch.backends.vulkan.utils as utils

import torch

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch._subclasses.fake_tensor import FakeTensor

@@ -129,6 +124,7 @@ def update_features_impl(op: OpKey):
# Symbolic integer ops
torch.ops.aten.sym_size.int,
operator.add,
operator.sub,
operator.lt,
operator.gt,
operator.ge,
@@ -297,27 +293,9 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:

@update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default)
def register_to_copy_dim_order_op():
# Currently there is no "real" implementation for to_dim_order_copy, but it can be
# removed as long as the operator is not changing the dtype, i.e. the operator call
# is modifying the dim order only. Therefore, check that the input and output dtypes
# are the same, if so the operator is safe to remove.
def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
in_arg = node.args[0]
if not isinstance(in_arg, torch.fx.Node):
return False

in_tensor = in_arg.meta.get("val", None)
out_tensor = node.meta.get("val", None)

if in_tensor.dtype != out_tensor.dtype:
return False

return True

return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
inputs_storage=utils.ANY_BUFFER,
supports_resize=True,
are_node_inputs_supported_fn=check_dim_order_copy_node,
)


4 changes: 4 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
@@ -639,6 +639,10 @@ class ComputeGraph final {

bool device_name_contains(const char* substr);

int64_t max_buffer_numel() {
return static_cast<int64_t>(context_->adapter_ptr()->max_buffer_numel());
}

//
// Graph Building
//
21 changes: 13 additions & 8 deletions backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl
@@ -18,6 +18,8 @@ ${layout_declare_ubo(B, "BufferMetadata", "inp")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "all_contiguous", "0")}

/*
 * The insight behind the view operation is that the contiguous index of each
 * tensor element is the same in the input and output tensors.
@@ -28,17 +30,20 @@ void main() {
return;
}

TensorIndex outp_tidx;
linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);
uint inp_bufi = outp_bufi;
if (all_contiguous == 0) {
TensorIndex outp_tidx;
linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);

// To map the output to the input, find the input element that has the same
// contiguous index as the output element.
const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx);
// To map the output to the input, find the input element that has the same
// contiguous index as the output element.
const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx);

TensorIndex inp_tidx;
contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);
TensorIndex inp_tidx;
contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);

const uint inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
}

t_outp[outp_bufi] = t_inp[inp_bufi];
}
54 changes: 54 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl
@@ -0,0 +1,54 @@
#version 450 core

#define PRECISION ${PRECISION}

#define IN_T ${buffer_scalar_type(IN_DTYPE)}
#define OUT_T ${buffer_scalar_type(OUT_DTYPE)}

${define_required_extensions(IN_DTYPE)}
${define_required_extensions(OUT_DTYPE)}

layout(std430) buffer;

#include "indexing.glslh"

${layout_declare_buffer(B, "w", "t_outp", OUT_DTYPE)}
${layout_declare_buffer(B, "r", "t_inp", IN_DTYPE)}

${layout_declare_ubo(B, "BufferMetadata", "outp")}
${layout_declare_ubo(B, "BufferMetadata", "inp")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "all_contiguous", "0")}

/*
 * The insight behind the view_convert operation is that the contiguous index of
 * each tensor element is the same in the input and output tensors, but the data
 * types may be different and need conversion.
*/
void main() {
const uint outp_bufi = gl_GlobalInvocationID.x;
if (outp_bufi >= numel(outp)) {
return;
}

uint inp_bufi = outp_bufi;

if (all_contiguous == 0) {
TensorIndex outp_tidx;
linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);

// To map the output to the input, find the input element that has the same
// contiguous index as the output element.
const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx);

TensorIndex inp_tidx;
contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);

inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
}

// Convert data type from input to output
t_outp[outp_bufi] = OUT_T(t_inp[inp_bufi]);
}
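Both view_buffer.glsl and the new view_convert_buffer.glsl rely on the same mapping: an element keeps its contiguous index across the view, so the output's linear buffer index is translated to the input's through that shared contiguous index, and when both tensors are contiguous the whole chain collapses to the identity, which is what the all_contiguous specialization constant exploits; view_convert_buffer additionally casts each element to the output dtype via OUT_T(...). The following plain-Python sketch of that index chain assumes strides that come from a permutation of a contiguous layout; the helper names are hypothetical and are not the functions defined in indexing.glslh.

def contiguous_strides(sizes):
    strides, acc = [], 1
    for s in reversed(sizes):
        strides.insert(0, acc)
        acc *= s
    return strides

def linear_to_tensor_idx(sizes, strides, bufi):
    # Invert a non-overlapping strided layout (valid when the strides are a
    # permutation of the contiguous strides for `sizes`).
    return [(bufi // st) % sz for sz, st in zip(sizes, strides)]

def tensor_idx_to_linear(strides, tidx):
    return sum(i * st for i, st in zip(tidx, strides))

def map_out_to_in(out_sizes, out_strides, in_sizes, in_strides, out_bufi):
    out_tidx = linear_to_tensor_idx(out_sizes, out_strides, out_bufi)
    # The contiguous index is shared between the viewed tensors.
    contig = tensor_idx_to_linear(contiguous_strides(out_sizes), out_tidx)
    in_tidx = linear_to_tensor_idx(in_sizes, contiguous_strides(in_sizes), contig)
    return tensor_idx_to_linear(in_strides, in_tidx)

# When both sides are contiguous the mapping is the identity, so the shader can
# skip the index math and copy t_outp[i] = t_inp[i] (or OUT_T(t_inp[i])).
in_sizes, out_sizes = [2, 3, 4], [6, 4]
assert all(
    map_out_to_in(out_sizes, contiguous_strides(out_sizes),
                  in_sizes, contiguous_strides(in_sizes), i) == i
    for i in range(24)
)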
22 changes: 22 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

view_convert_buffer:
parameter_names_with_default_values:
IN_DTYPE: float
OUT_DTYPE: float
STORAGE: buffer
generate_variant_forall:
combination:
parameter_names: [IN_DTYPE, OUT_DTYPE]
combos:
- parameter_values: [int32, float]
- parameter_values: [int32, half]
- parameter_values: [uint8, float]
- parameter_values: [uint8, half]
- parameter_values: [uint8, int32]
shader_variants:
- NAME: view_convert_buffer
35 changes: 28 additions & 7 deletions backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
@@ -471,10 +471,31 @@ void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
VK_CHECK_COND(graph.val_is_none(attn_mask));

const int64_t num_q_heads = graph.size_at<int64_t>(-2, q_projected);
const int64_t max_seq_len = graph.size_at<int64_t>(-3, q_projected);

int64_t max_seq_len = graph.size_at<int64_t>(-3, q_projected);
const int64_t max_context_len = graph.size_at<int32_t>(-3, k_cache);

const utils::StorageType attn_weights_storage =
graph.storage_type_of(q_projected);

// If using buffer storage for attn weights, we need to ensure that the buffer
// numel limit is not exceeded. If needed, manually adjust max_seq_len based
// on the buffer numel limit.
if (attn_weights_storage == utils::kBuffer) {
const int64_t max_buffer_numel = graph.max_buffer_numel();
if (num_q_heads * max_seq_len * max_context_len >= max_buffer_numel) {
      // Compute the largest max_seq_len that stays within the buffer numel
      // limit.
max_seq_len = max_buffer_numel / (num_q_heads * max_context_len);
// Adjust down to the nearest multiple of 4 to make sure the limit is
// not hit.
if (max_seq_len % 4 != 0) {
max_seq_len = (max_seq_len / 4) * 4;
} else {
max_seq_len -= 4;
}
}
}

std::vector<int64_t> attn_weight_full_sizes = {
1, // batch
num_q_heads,
@@ -485,14 +506,14 @@ void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
&graph,
attn_weight_full_sizes,
graph.dtype_of(q_projected),
graph.storage_type_of(q_projected),
attn_weights_storage,
utils::kWidthPacked);

TmpTensor attn_weights_softmax(
&graph,
attn_weight_full_sizes,
graph.dtype_of(q_projected),
graph.storage_type_of(q_projected),
attn_weights_storage,
utils::kWidthPacked);

add_sdpa_compute_attn_weights_node(
@@ -528,9 +549,9 @@ void sdpa_with_kv_cache_impl(

utils::StorageType cache_storage = graph.storage_type_of(q_projected);
const ValueRef k_cache =
prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked);
graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked);
const ValueRef v_cache =
prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked);
graph.add_tensor_like(v_cache_data, cache_storage, utils::kWidthPacked);

update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
@@ -573,7 +594,7 @@ void compute_attn_weight_with_kv_cache_impl(

(void)sequence_len;

utils::StorageType cache_storage = graph.storage_type_of(q_projected);
const utils::StorageType cache_storage = graph.storage_type_of(q_projected);
const ValueRef k_cache =
graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked);
const ValueRef v_cache =
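To make the buffer-limit clamp in sdpa_impl concrete, here is a small worked example in Python with made-up numbers (32 heads, a 4096-token context, and a 128M-element buffer limit; none of these values come from the PR):

def clamp_max_seq_len(num_q_heads, max_seq_len, max_context_len, max_buffer_numel):
    # Mirrors the C++ logic above: only clamp when the attention-weight tensor
    # would not fit within the buffer numel limit.
    if num_q_heads * max_seq_len * max_context_len >= max_buffer_numel:
        max_seq_len = max_buffer_numel // (num_q_heads * max_context_len)
        # Round down to a multiple of 4; if the quotient is already one,
        # step down by 4 so the product stays strictly below the limit.
        if max_seq_len % 4 != 0:
            max_seq_len = (max_seq_len // 4) * 4
        else:
            max_seq_len -= 4
    return max_seq_len

print(clamp_max_seq_len(32, 2048, 4096, 128 * 1024 * 1024))  # -> 1020

With these numbers the quotient is exactly 1024, already a multiple of 4, so the extra step down to 1020 is what keeps 32 * 1020 * 4096 below the 134,217,728-element limit (the original check treats equality as hitting it).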
12 changes: 12 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp
@@ -67,6 +67,18 @@ void resize_unsqueeze_node(

std::vector<int64_t> out_sizes = graph->sizes_of(in);

std::vector<int64_t> unsqueezed_dims;

if (graph->val_is_int_list(dims_ref)) {
const IntListPtr dims = graph->get_int_list(dims_ref);
for (int64_t d : *dims) {
unsqueezed_dims.push_back(d);
}
} else {
const int64_t dim = graph->extract_scalar<int64_t>(dims_ref);
unsqueezed_dims.push_back(dim);
}

// Insert singleton dimensions at the specified positions
for (auto dim : dims_vec) {
int64_t d = dim;
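The resize hook now accepts either a single int or an int list for the unsqueeze dims before inserting the singleton dimensions. As a rough illustration of the size computation it feeds, here is a plain-Python sketch (the helper is hypothetical, and resolving negative dims against the output rank is an assumption about the op's semantics):

def unsqueezed_sizes(in_sizes, dims):
    # Accept a single dim or a list of dims, like the resize hook above.
    dims = [dims] if isinstance(dims, int) else list(dims)
    out_rank = len(in_sizes) + len(dims)
    # Resolve negative dims against the output rank, then insert in ascending
    # order so earlier insertions don't shift later positions.
    resolved = sorted(d + out_rank if d < 0 else d for d in dims)
    out = list(in_sizes)
    for d in resolved:
        out.insert(d, 1)
    return out

print(unsqueezed_sizes([4, 8], [0]))  # [1, 4, 8]
print(unsqueezed_sizes([4, 8], -1))   # [4, 8, 1]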