Skip to content

Commit a155072

Browse files
committed
Update base for Update on "[ET-VK] Migrate ops to use DynamicDispatchNode"
## Changes * Migrate operators that are used in the llama model to use `DynamicDispatchNode` instead of `DispatchNode` ## Motivation `DynamicDispatchNode` is a subclass of `DispatchNode` that allows dynamic selection of compute shaders, global and local work group sizing whenever the command buffer is encoded. This is critical for ensuring optimum performance when input shapes are dynamic, since it allows operators to select the best compute shader for the input conditions and also to adjust global work group sizing to launch the minimum number of work groups necessary. Without this change, performance of llama 3.2 1B with dynamic shapes enabled is terrible (< 1 tok/s) because global work group sizing is determined based on maximum tensor sizes, which is based on the maximum sequence length. In practice, the sequence length dimension of tensors (even during the prefill phase) will not approach the maximum. This results in a lot of inactive threads launched during compute shader dispatches. Differential Revision: [D75878398](https://our.internmc.facebook.com/intern/diff/D75878398/) [ghstack-poisoned]
2 parents c38171b + b5a6362 commit a155072

File tree

27 files changed

+418
-174
lines changed

27 files changed

+418
-174
lines changed

.ci/scripts/utils.sh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,7 @@ build_executorch_runner() {
158158
cmake_install_executorch_lib() {
159159
echo "Installing libexecutorch.a and libportable_kernels.a"
160160
clean_executorch_install_folders
161-
retry cmake -DBUCK2="$BUCK" \
162-
-DCMAKE_INSTALL_PREFIX=cmake-out \
161+
retry cmake -DCMAKE_INSTALL_PREFIX=cmake-out \
163162
-DCMAKE_BUILD_TYPE=Release \
164163
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
165164
-Bcmake-out .
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
14+
#define T ${buffer_scalar_type(DTYPE)}
15+
16+
${define_active_storage_type(STORAGE)}
17+
18+
#include "indexing_utils.h"
19+
20+
${define_required_extensions(DTYPE)}
21+
22+
layout(std430) buffer;
23+
24+
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
25+
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
26+
$if STORAGE == "buffer":
27+
${layout_declare_ubo(2, "int", "numel")}
28+
$else:
29+
${layout_declare_ubo(2, "ivec3", "out_limits")}
30+
31+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
32+
33+
#include "activations.h"
34+
35+
#ifdef USING_BUFFER
36+
37+
void main() {
38+
const int i = int(gl_GlobalInvocationID.x);
39+
if (i >= numel) {
40+
return;
41+
}
42+
43+
float in_val = float(t_in[i]);
44+
t_out[i] = T(tan(in_val));
45+
}
46+
47+
#else
48+
49+
void main() {
50+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
51+
52+
if (any(greaterThanEqual(pos, out_limits))) {
53+
return;
54+
}
55+
56+
VEC4_T in_texel = texelFetch(t_in, pos, 0);
57+
imageStore(t_out, pos, VEC4_T(tan(in_texel)));
58+
}
59+
60+
#endif
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
tan:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
STORAGE: texture3d
5+
generate_variant_forall:
6+
DTYPE:
7+
- VALUE: half
8+
- VALUE: float
9+
STORAGE:
10+
- VALUE: texture3d
11+
- VALUE: buffer
12+
shader_variants:
13+
- NAME: tan
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
13+
14+
namespace vkcompute {
15+
16+
using namespace utils;
17+
18+
void resize_tan_node(
19+
ComputeGraph* graph,
20+
const std::vector<ArgGroup>& args,
21+
const std::vector<ValueRef>& extra_args) {
22+
(void)extra_args;
23+
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
24+
vTensorPtr self = graph->get_tensor(args[1].refs[0]);
25+
26+
out->virtual_resize(self->sizes());
27+
}
28+
29+
void add_tan_node(ComputeGraph& graph, const ValueRef in, const ValueRef out) {
30+
std::string kernel_name = "tan";
31+
add_dtype_suffix(kernel_name, graph.dtype_of(out));
32+
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
33+
34+
vkapi::ParamsBindList ubos({});
35+
ubos.append({graph.logical_limits_ubo(out)});
36+
37+
graph.execute_nodes().emplace_back(new DispatchNode(
38+
graph,
39+
VK_KERNEL_FROM_STR(kernel_name),
40+
graph.create_global_wg_size(out),
41+
graph.create_local_wg_size(out),
42+
// Inputs and Outputs
43+
{{out, vkapi::kWrite}, {in, vkapi::kRead}},
44+
// Shader params buffers
45+
ubos,
46+
// Push Constants
47+
{},
48+
// Specialization Constants
49+
{},
50+
// Resize Args
51+
{},
52+
// Resizing Logic
53+
resize_tan_node));
54+
}
55+
56+
void tan(ComputeGraph& graph, const std::vector<ValueRef>& args) {
57+
return add_tan_node(graph, args[0], args[1]);
58+
}
59+
60+
REGISTER_OPERATORS {
61+
VK_REGISTER_OP(aten.tan.default, tan);
62+
}
63+
64+
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,6 +1171,22 @@ def get_unary_ops_inputs():
11711171
return test_suite
11721172

11731173

1174+
# separate test suite from unary_ops for learning purposes
1175+
@register_test_suite("aten.tan.default")
1176+
def get_tan_inputs():
1177+
test_suite = VkTestSuite(
1178+
[
1179+
(M1,),
1180+
(M1, M2),
1181+
(S1, M1, M2),
1182+
(S1, S2, S2, M2),
1183+
]
1184+
)
1185+
test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]
1186+
test_suite.dtypes = ["at::kFloat", "at::kHalf"]
1187+
return test_suite
1188+
1189+
11741190
@register_test_suite("aten._native_batch_norm_legit_no_training.default")
11751191
def get_native_batch_norm_inputs():
11761192
Test = namedtuple(

devtools/inspector/tests/TARGETS

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
2+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
23

34
oncall("executorch")
45

@@ -13,6 +14,7 @@ python_unittest(
1314
"//executorch/devtools/inspector:inspector",
1415
"//executorch/devtools/inspector:lib",
1516
"//executorch/exir:lib",
17+
"//executorch/devtools/inspector/tests:inspector_test_utils",
1618
],
1719
)
1820

@@ -48,5 +50,16 @@ python_unittest(
4850
"//executorch/devtools/inspector:lib",
4951
"//executorch/devtools/inspector:intermediate_output_capturer",
5052
"//executorch/exir:lib",
53+
"//executorch/devtools/inspector/tests:inspector_test_utils",
54+
],
55+
)
56+
57+
python_library(
58+
name = "inspector_test_utils",
59+
srcs = [
60+
"inspector_test_utils.py",
61+
],
62+
deps = [
63+
"//caffe2:torch",
5164
],
5265
)
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import torch
10+
import torch.nn as nn
11+
import torch.nn.functional as F
12+
13+
14+
class ConvlLinearModel(nn.Module):
15+
"""
16+
A neural network model with a convolutional layer followed by a linear layer.
17+
"""
18+
19+
def __init__(self):
20+
super(ConvlLinearModel, self).__init__()
21+
self.conv_layer = nn.Conv2d(
22+
in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1
23+
)
24+
self.conv_layer.weight = nn.Parameter(
25+
torch.tensor([[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]])
26+
)
27+
self.conv_layer.bias = nn.Parameter(torch.tensor([0.0]))
28+
29+
self.linear_layer = nn.Linear(in_features=4, out_features=2)
30+
self.linear_layer.weight = nn.Parameter(
31+
torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]])
32+
)
33+
self.linear_layer.bias = nn.Parameter(torch.tensor([0.0, 0.0]))
34+
self.additional_bias = nn.Parameter(
35+
torch.tensor([0.5, -0.5]), requires_grad=False
36+
)
37+
self.scale_factor = nn.Parameter(torch.tensor([2.0, 0.5]), requires_grad=False)
38+
39+
def forward(self, x):
40+
x = self.conv_layer(x)
41+
x = x.view(x.size(0), -1)
42+
x = self.linear_layer(x)
43+
x = x + self.additional_bias
44+
x = x - 0.1
45+
x = x * self.scale_factor
46+
x = x / (self.scale_factor + 1.0)
47+
x = F.relu(x)
48+
x = torch.sigmoid(x)
49+
output1, output2 = torch.split(x, 1, dim=1)
50+
return output1, output2
51+
52+
@staticmethod
53+
def get_input():
54+
"""
55+
Returns the pre-defined input tensor for this model.
56+
"""
57+
return torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True)
58+
59+
@staticmethod
60+
def get_expected_intermediate_outputs():
61+
"""
62+
Returns the expected outputs of the debug handles and intermediate output mapping for this model for the given input.
63+
"""
64+
return {
65+
(10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
66+
(11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
67+
(12,): torch.tensor(
68+
[
69+
[0.1000, 0.5000],
70+
[0.2000, 0.6000],
71+
[0.3000, 0.7000],
72+
[0.4000, 0.8000],
73+
]
74+
),
75+
(13,): torch.tensor([[5.0000, 14.1200]]),
76+
(14,): torch.tensor([[5.5000, 13.6200]]),
77+
(15,): torch.tensor([[5.4000, 13.5200]]),
78+
(16,): torch.tensor([[10.8000, 6.7600]]),
79+
(17,): torch.tensor([3.0000, 1.5000]),
80+
(18,): torch.tensor([[3.6000, 4.5067]]),
81+
(19,): torch.tensor([[3.6000, 4.5067]]),
82+
(20,): torch.tensor([[0.9734, 0.9891]]),
83+
(21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
84+
}
85+
86+
87+
# Global model registry
88+
model_registry = {
89+
"ConvLinearModel": ConvlLinearModel,
90+
# Add new models here
91+
}
92+
93+
94+
def check_if_final_outputs_match(model_name, actual_outputs_with_handles):
95+
"""
96+
Checks if the actual outputs match the expected outputs for the specified model.
97+
Returns True if all outputs match, otherwise returns False.
98+
"""
99+
model_instance = model_registry[model_name]
100+
expected_outputs_with_handles = model_instance.get_expected_intermediate_outputs()
101+
if len(actual_outputs_with_handles) != len(expected_outputs_with_handles):
102+
return False
103+
for debug_handle, expected_output in expected_outputs_with_handles.items():
104+
actual_output = actual_outputs_with_handles.get(debug_handle)
105+
if actual_output is None:
106+
return False
107+
if isinstance(expected_output, list):
108+
if not isinstance(actual_output, list):
109+
return False
110+
if len(actual_output) != len(expected_output):
111+
return False
112+
for actual, expected in zip(actual_output, expected_output):
113+
if not torch.allclose(actual, expected, rtol=1e-4, atol=1e-5):
114+
return False
115+
else:
116+
if not torch.allclose(actual_output, expected_output, rtol=1e-4, atol=1e-5):
117+
return False
118+
return True

0 commit comments

Comments
 (0)