
Commit 1716407

Update base for Update on "[ET-VK] Migrate ops to use DynamicDispatchNode"
## Changes

* Migrate the operators used in the llama model to use `DynamicDispatchNode` instead of `DispatchNode`.

## Motivation

`DynamicDispatchNode` is a subclass of `DispatchNode` that allows the compute shader, global work group size, and local work group size to be re-selected each time the command buffer is encoded. This is critical for achieving good performance when input shapes are dynamic, since it lets operators pick the best compute shader for the actual input conditions and shrink the global work group size so that only the minimum number of work groups is launched.

Without this change, performance of Llama 3.2 1B with dynamic shapes enabled is very poor (< 1 tok/s): global work group sizes are derived from maximum tensor sizes, which are in turn derived from the maximum sequence length. In practice, the sequence length dimension of tensors (even during the prefill phase) stays well below that maximum, so every compute shader dispatch launches a large number of inactive threads.

Differential Revision: [D75878398](https://our.internmc.facebook.com/intern/diff/D75878398/)

[ghstack-poisoned]
1 parent 4348319 commit 1716407
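
To illustrate the mechanism described above: each dispatch node can carry a resize callback that re-derives output sizes from the inputs' current sizes whenever the command buffer is re-encoded. The following is a minimal, editor-supplied C++ sketch in the style of the `resize_permute_node` and `resize_rotary_embedding_node` functions added in this diff; the op name `my_op` is hypothetical, only the `ComputeGraph` calls that appear in this diff (`sizes_of`, `virtual_resize`) are used, and the exact `DynamicDispatchNode` registration signature is not shown here.

```cpp
// Hypothetical resize callback for a unary "my_op" (illustration only).
// When the command buffer is re-encoded, the graph invokes this callback
// so the output tensor's sizes track the input's actual sizes; global
// work group sizing can then launch only the work groups that are needed.
void resize_my_op_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)resize_args;
  const ValueRef out = args.at(0).refs.at(0);
  const ValueRef in = args.at(1).refs.at(0);

  // Propagate the input's current (dynamic) sizes to the output.
  graph->virtual_resize(out, graph->sizes_of(in));
}
```

In the Permute.cpp hunks below, this is the pattern that replaces the `nullptr` resizing logic with `resize_permute_node`, with the permute dims passed as a resize arg so the callback can read them at encode time.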

22 files changed (+281 / −191 lines)

.ci/scripts/test_model.sh

Lines changed: 19 additions & 23 deletions

@@ -49,14 +49,24 @@ prepare_artifacts_upload() {
 }

 build_cmake_executor_runner() {
+  local backend_string_select="${1:-}"
   echo "Building executor_runner"
   rm -rf ${CMAKE_OUTPUT_DIR}
-  cmake -DCMAKE_BUILD_TYPE=Debug \
-      -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
-      -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
-      -B${CMAKE_OUTPUT_DIR} .
-
-  cmake --build ${CMAKE_OUTPUT_DIR} -j4 --config Debug
+  mkdir ${CMAKE_OUTPUT_DIR}
+  if [[ "$backend_string_select" == "XNNPACK" ]]; then
+    echo "Backend $backend_string_select selected"
+    (cd ${CMAKE_OUTPUT_DIR} \
+      && cmake -DCMAKE_BUILD_TYPE=Release \
+        -DEXECUTORCH_BUILD_XNNPACK=ON \
+        -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" ..)
+    cmake --build ${CMAKE_OUTPUT_DIR} -j4
+  else
+    cmake -DCMAKE_BUILD_TYPE=Debug \
+        -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
+        -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
+        -B${CMAKE_OUTPUT_DIR} .
+    cmake --build ${CMAKE_OUTPUT_DIR} -j4 --config Debug
+  fi
 }

 run_portable_executor_runner() {

@@ -111,19 +121,6 @@ test_model() {
   run_portable_executor_runner
 }

-build_cmake_xnn_executor_runner() {
-  echo "Building xnn_executor_runner"
-
-  (rm -rf ${CMAKE_OUTPUT_DIR} \
-    && mkdir ${CMAKE_OUTPUT_DIR} \
-    && cd ${CMAKE_OUTPUT_DIR} \
-    && retry cmake -DCMAKE_BUILD_TYPE=Release \
-      -DEXECUTORCH_BUILD_XNNPACK=ON \
-      -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" ..)
-
-  cmake --build ${CMAKE_OUTPUT_DIR} -j4
-}
-
 test_model_with_xnnpack() {
   WITH_QUANTIZATION=$1
   WITH_DELEGATION=$2

@@ -148,12 +145,11 @@ test_model_with_xnnpack() {

   # Run test model
   if [[ "${BUILD_TOOL}" == "buck2" ]]; then
+    # TODO eventually buck should also use consolidated executor runners
     buck2 run //examples/xnnpack:xnn_executor_runner -- --model_path "${OUTPUT_MODEL_PATH}"
   elif [[ "${BUILD_TOOL}" == "cmake" ]]; then
-    if [[ ! -f ${CMAKE_OUTPUT_DIR}/backends/xnnpack/xnn_executor_runner ]]; then
-      build_cmake_xnn_executor_runner
-    fi
-    ./${CMAKE_OUTPUT_DIR}/backends/xnnpack/xnn_executor_runner --model_path "${OUTPUT_MODEL_PATH}"
+    build_cmake_executor_runner "XNNPACK"
+    ./${CMAKE_OUTPUT_DIR}/executor_runner --model_path "${OUTPUT_MODEL_PATH}"
   else
     echo "Invalid build tool ${BUILD_TOOL}. Only buck2 and cmake are supported atm"
     exit 1

backends/vulkan/_passes/fuse_quantized_ops.py

Lines changed: 4 additions & 1 deletion

@@ -17,6 +17,7 @@
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.passes import dead_code_elimination_pass

 #################
 ## linear_qcnw ##

@@ -224,6 +225,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                 )

         graph_module.recompile()
-        graph_module = super().call(graph_module).graph_module
+        dead_code_elimination_pass(graph_module)

+        # Re-trace the graph since new nodes were (potentially) inserted
+        graph_module = super().call(graph_module).graph_module
         return PassResult(graph_module, True)

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 2 additions & 2 deletions

@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.

 import logging
-from copy import deepcopy
 from typing import Any, Optional, Set

 import executorch.backends.vulkan.utils as utils

@@ -22,6 +21,7 @@
 from executorch.exir.dialects._ops import ops as exir_ops

 from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.tensor import TensorSpec

 logger: logging.Logger = logging.getLogger("")
 logger.setLevel(logging.INFO)

@@ -52,7 +52,7 @@ def insert_transition_node(
         (arg,),
     )
     clone_node.meta["val"] = arg.meta["val"]
-    clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
+    clone_node.meta["spec"] = TensorSpec.from_tensor(clone_node.meta["val"])
     clone_node.meta["spec"].const = False
     set_memory_metadata(clone_node, storage, layout)
     arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)

backends/vulkan/op_registry.py

Lines changed: 25 additions & 8 deletions

@@ -230,6 +230,14 @@ def update_features_impl(op: OpKey):
        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
        # Symbolic integer ops
        torch.ops.aten.sym_size.int,
+        operator.add,
+        operator.lt,
+        operator.gt,
+        operator.ge,
+        operator.le,
+        # Guard and assert ops
+        torch.ops.aten._assert_scalar.default,
+        torch.ops.aten.sym_constrain_range_for_size.default,
     ]
 )
 def register_ephemeral_op(features: OpFeatures):

@@ -500,7 +508,12 @@ def register_sdpa_with_kv_cache_op(features: OpFeatures):
     return features


-@update_features(["llama::update_cache", "llama::custom_sdpa"])
+@update_features(
+    [
+        "llama::update_cache",
+        "llama::custom_sdpa",
+    ]
+)
 def register_sdpa_ops(features: OpFeatures):
     features.resize_fn = False
     features.buffer_impl = False

@@ -520,8 +533,17 @@ def register_rotary_emb_op(features: OpFeatures):
     return features


-@update_features(exir_ops.edge.aten.view_copy.default)
-def register_view_op(features: OpFeatures):
+@update_features(
+    [
+        exir_ops.edge.aten.clone.default,
+        exir_ops.edge.aten.permute.default,
+        exir_ops.edge.aten.permute_copy.default,
+        exir_ops.edge.aten.select_copy.int,
+        exir_ops.edge.aten.slice_copy.Tensor,
+        exir_ops.edge.aten.view_copy.default,
+    ]
+)
+def register_view_ops(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
         valid_packed_dims=all_packed_dims,
     )

@@ -538,10 +560,8 @@ def register_view_op(features: OpFeatures):
        # Indexing and lookup
        exir_ops.edge.aten.flip.default,
        exir_ops.edge.aten.index_select.default,
-        exir_ops.edge.aten.select_copy.int,
        # Tensor creation
        exir_ops.edge.aten.arange.start_step,
-        exir_ops.edge.aten.clone.default,
        exir_ops.edge.aten.constant_pad_nd.default,
        exir_ops.edge.aten.full.default,
        exir_ops.edge.aten.full_like.default,

@@ -564,12 +584,9 @@ def register_ported_op(features: OpFeatures):
 # Ops ported from PyTorch Vulkan backend. These ops are in a separate registry becasue they support all packed dimensions
 @update_features(
     [
-        # Indexing and lookup
-        exir_ops.edge.aten.slice_copy.Tensor,
        # Shape Manipulation
        exir_ops.edge.aten.squeeze_copy.dims,
        exir_ops.edge.aten.unsqueeze_copy.default,
-        exir_ops.edge.aten.permute_copy.default,
        # Tensor combination
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.aten.repeat.default,

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 4 additions & 3 deletions

@@ -146,10 +146,11 @@ def op_node_is_compatible(  # noqa: C901: Function is too complex
     def node_is_compatible(
         self, node: torch.fx.Node, features: Optional[OpFeatures] = None
     ) -> Tuple[bool, str]:
-        if utils.is_symint_node(node):
-            return node.target in vulkan_supported_ops, "Op is compatible"
-        elif utils.is_tensor_node(node):
+        if utils.is_tensor_node(node):
             return self.op_node_is_compatible(node, features=features)
+        # For non-tensor nodes, just check if the op is registered
+        elif hasattr(node, "target"):
+            return node.target in vulkan_supported_ops, "Op is compatible"

         return False, f"Unsupported node type: {node.format_node()}"

backends/vulkan/runtime/graph/ops/impl/Permute.cpp

Lines changed: 83 additions & 37 deletions

@@ -25,10 +25,12 @@ using utils::uvec4;
 namespace {

 void check_args(
-    const api::vTensor& in,
-    const std::vector<int64_t>& permute_dims,
-    const api::vTensor& out) {
-  VK_CHECK_COND(check_same_packed_dim(in, out));
+    ComputeGraph& graph,
+    const ValueRef in,
+    const ValueRef permute_dims,
+    const ValueRef out) {
+  (void)permute_dims;
+  VK_CHECK_COND(check_same_packed_dim(graph, in, out));

   // This implementation doesn't not requires the input tensor to have the same
   // dim size as the argument. The code will work as long as the input tensor's

@@ -38,40 +40,94 @@ void check_args(

 } // namespace

+void resize_permute_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  const ValueRef out = args[0].refs[0];
+  const ValueRef in = args[1].refs[0];
+
+  const std::vector<int64_t> in_sizes = graph->sizes_of(in);
+  const std::vector<int64_t> out_sizes = graph->sizes_of(out);
+
+  const std::vector<int64_t> permute_dims =
+      graph->extract_int_or_symint_list(resize_args[0]);
+
+  if (in_sizes.size() == out_sizes.size() &&
+      in_sizes.size() == permute_dims.size()) {
+    std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
+    const int64_t out_ndim = std::max(in_sizes.size(), out_sizes.size());
+    for (int i = 0; i < out_ndim; i++) {
+      const int64_t permute_dim = permute_dims.at(i);
+      new_out_sizes.at(i) = in_sizes.at(permute_dim);
+    }
+    graph->virtual_resize(out, new_out_sizes);
+  }
+  // Case where permute is being used to implement squeeze
+  else if (
+      in_sizes.size() > out_sizes.size() &&
+      in_sizes.size() == permute_dims.size()) {
+    std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
+    const size_t offset = in_sizes.size() - out_sizes.size();
+    for (int i = 0; i < out_sizes.size(); i++) {
+      const int64_t permute_dim = permute_dims.at(i + offset);
+      new_out_sizes.at(i) = in_sizes.at(permute_dim);
+    }
+    graph->virtual_resize(out, new_out_sizes);
+  }
+  // Case where Permute is being used to implement unsqueeze
+  else if (
+      in_sizes.size() < out_sizes.size() &&
+      out_sizes.size() == permute_dims.size()) {
+    std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
+    const size_t offset = out_sizes.size() - in_sizes.size();
+    for (int i = 0; i < out_sizes.size(); i++) {
+      int64_t permute_dim = permute_dims.at(i) - offset;
+      if (permute_dim >= 0) {
+        new_out_sizes.at(i) = in_sizes.at(permute_dim);
+      }
+    }
+    graph->virtual_resize(out, new_out_sizes);
+  } else {
+    VK_THROW("Invalid permute dims");
+  }
+}
+
 void add_permute_node(
     ComputeGraph& graph,
-    ValueRef in,
-    const std::vector<int64_t>& permute_dims,
-    ValueRef out) {
-  vTensorPtr t_in = graph.get_tensor(in);
-  vTensorPtr t_out = graph.get_tensor(out);
-
-  check_args(*t_in, permute_dims, *t_out);
+    const ValueRef in,
+    const ValueRef permute_dims,
+    const ValueRef out) {
+  check_args(graph, in, permute_dims, out);

   ivec4 out_dims{0, 1, 2, 3};

   // Special cases of squeeze/unsqueeze. Because the input dim size can be
-  // different with output dim size. So pick t_in->dim() if squeeze, and
-  // t_out->dim() if unsqueeze to create parameter for permute.
-  int64_t out_ndim = std::max(t_in->dim(), t_out->dim());
+  // different with output dim size. So pick graph.dim_of(in) if squeeze, and
+  // graph.dim_of(out) if unsqueeze to create parameter for permute.
+  const int64_t out_ndim = std::max(graph.dim_of(in), graph.dim_of(out));
   std::vector<bool> seen(out_ndim);
-  for (int i = 0; i < out_ndim; i++) {
-    int64_t permute_dim = permute_dims[i];
-    VK_CHECK_COND(
-        !seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
-    seen[permute_dim] = true;
-
-    out_dims[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim);
+  {
+    IntListPtr permute_dims_ptr = graph.get_int_list(permute_dims);
+    for (int i = 0; i < out_ndim; i++) {
+      int64_t permute_dim = permute_dims_ptr->at(i);
+      VK_CHECK_COND(
+          !seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
+      seen[permute_dim] = true;
+
+      out_dims[(4u - out_ndim) + i] =
+          utils::safe_downcast<int32_t>(permute_dim + (4 - out_ndim));
+    }
   }

   std::string kernel_name = "permute";
   kernel_name.reserve(kShaderNameReserve);
-  add_dtype_suffix(kernel_name, *t_out);
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));

-  int32_t out_channels = dim_at<kChannel4D>(t_out->sizes());
-  int32_t in_channels = dim_at<kChannel4D>(t_in->sizes());
+  const int32_t out_channels = dim_at<kChannel4D>(graph.sizes_of(out));
+  const int32_t in_channels = dim_at<kChannel4D>(graph.sizes_of(in));

-  const auto packed_dim = graph.packed_dim_of(in);
+  const int32_t packed_dim = graph.packed_dim_of(in);
   ivec2 channel_info = {out_channels, in_channels};
   if (packed_dim == WHCN::kChannelsDim) {
     channel_info[0] = utils::align_up_4(channel_info[0]);

@@ -95,19 +151,9 @@ void add_permute_node(
       // Specialization Constants
       spec_vars,
       // Resize Args
-      {},
+      {permute_dims},
       // Resizing Logic
-      nullptr));
-}
-
-void add_permute_node(
-    ComputeGraph& graph,
-    ValueRef in,
-    ValueRef permute_dims_ref,
-    ValueRef out) {
-  IntListPtr permute_dims = graph.get_int_list(permute_dims_ref);
-
-  add_permute_node(graph, in, *permute_dims, out);
+      resize_permute_node));
 }

 void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {

backends/vulkan/runtime/graph/ops/impl/Permute.h

Lines changed: 3 additions & 3 deletions

@@ -18,8 +18,8 @@ namespace vkcompute {

 void add_permute_node(
     ComputeGraph& graph,
-    ValueRef in,
-    const std::vector<int64_t>& permute_dims,
-    ValueRef out);
+    const ValueRef in,
+    const ValueRef permute_dims,
+    const ValueRef out);

 } // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp

Lines changed: 14 additions & 8 deletions

@@ -15,14 +15,20 @@ namespace vkcompute {
 void resize_rotary_embedding_node(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
-    const std::vector<ValueRef>& extra_args) {
-  (void)extra_args;
-  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
-  vTensorPtr in = graph->get_tensor(args[1].refs[0]);
-
-  std::vector<int64_t> in_sizes = in->sizes();
-  // UNCOMMENT BELOW IF NEEDED
-  // out->virtual_resize(in_sizes);
+    const std::vector<ValueRef>& resize_args) {
+  (void)resize_args;
+
+  const ValueRef xq_out = args.at(0).refs.at(0);
+  const ValueRef xk_out = args.at(0).refs.at(1);
+
+  const ValueRef xq = args.at(1).refs.at(0);
+  const ValueRef xk = args.at(1).refs.at(1);
+
+  const std::vector<int64_t> xq_sizes = graph->sizes_of(xq);
+  const std::vector<int64_t> xk_sizes = graph->sizes_of(xk);
+
+  graph->virtual_resize(xq_out, xq_sizes);
+  graph->virtual_resize(xk_out, xk_sizes);
 }

 void add_rotary_embedding_node(
