Skip to content

Commit 906ec99

Browse files
committed
[ET-VK][BE] Move all ops to use DynamicDispatchNode
## Changes Update (almost) all operators use `DynamicDispatchNode` instead of `DispatchNode`. ## Context `DynamicDispatchNode` was introduced in order to provide a way for operators to adjust 1. Which compute shader to dispatch 2. What global work group size to use 3. What local work group size to use Based on the current input and output shapes. This is useful for making sure that the most optimal compute shader is used for the current tensor sizes, and minimizing the number of inactive shader invocations. Differential Revision: [D79564595](https://our.internmc.facebook.com/intern/diff/D79564595/) [ghstack-poisoned]
1 parent bbe30ad commit 906ec99

20 files changed

+471
-142
lines changed

backends/vulkan/runtime/graph/ops/impl/Arange.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1212

13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1314
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1415

1516
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
@@ -86,11 +87,11 @@ void add_arange_node(
8687
kernel_name.reserve(kShaderNameReserve);
8788
add_dtype_suffix(kernel_name, graph.dtype_of(out));
8889

89-
graph.execute_nodes().emplace_back(new DispatchNode(
90+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
9091
graph,
9192
VK_KERNEL_FROM_STR(kernel_name),
92-
graph.create_global_wg_size(out),
93-
graph.create_local_wg_size(out),
93+
default_pick_global_wg_size,
94+
default_pick_local_wg_size,
9495
// Inputs and Outputs
9596
{{out, vkapi::kWrite}},
9697
// Shader params buffers

backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1213

1314
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
@@ -83,11 +84,11 @@ void add_native_batch_norm_node(
8384
const int32_t num_texel_per_batch =
8485
utils::div_up_4((dim_at<kChannel4D>(in_sizes)));
8586

86-
graph.execute_nodes().emplace_back(new DispatchNode(
87+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
8788
graph,
8889
VK_KERNEL_FROM_STR(kernel_name),
89-
graph.create_global_wg_size(out_ref),
90-
graph.create_local_wg_size(out_ref),
90+
default_pick_global_wg_size,
91+
default_pick_local_wg_size,
9192
{{out_ref, vkapi::kWrite},
9293
{{in_ref, arg_weight, arg_bias, arg_mean, arg_var}, vkapi::kRead}},
9394
{graph.logical_limits_ubo(out_ref),

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 116 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1213

1314
#include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>
@@ -19,6 +20,13 @@
1920

2021
namespace vkcompute {
2122

23+
enum class Conv2dMethod : uint8_t {
24+
Depthwise,
25+
Pointwise,
26+
SlidingWindow,
27+
Transposed,
28+
};
29+
2230
void resize_conv2d_node(
2331
ComputeGraph* graph,
2432
const std::vector<ArgGroup>& args,
@@ -114,13 +122,6 @@ ValueRef prepack_biases(
114122
return v;
115123
}
116124

117-
enum class Conv2dMethod : uint8_t {
118-
Depthwise,
119-
Pointwise,
120-
SlidingWindow,
121-
Transposed,
122-
};
123-
124125
vkapi::ShaderInfo get_conv2d_shader(
125126
ComputeGraph& graph,
126127
const ValueRef out,
@@ -327,6 +328,108 @@ utils::uvec3 create_conv2d_global_wg_size(
327328
}
328329
}
329330

331+
// Custom global workgroup size function for conv2d
332+
utils::uvec3 conv2d_global_wg_size(
333+
ComputeGraph* graph,
334+
const vkapi::ShaderInfo& shader,
335+
const std::vector<ArgGroup>& args,
336+
const std::vector<ValueRef>& resize_args) {
337+
const ValueRef out = args.at(0).refs.at(0);
338+
const ValueRef weight_data = resize_args.at(0);
339+
340+
// Determine method from shader name
341+
Conv2dMethod method;
342+
if (shader.kernel_name.find("conv2d_dw") != std::string::npos) {
343+
method = Conv2dMethod::Depthwise;
344+
} else if (
345+
shader.kernel_name.find("conv2d_pw") != std::string::npos ||
346+
(shader.kernel_name.find("conv2d") != std::string::npos &&
347+
shader.kernel_name.find("conv_transpose2d") == std::string::npos)) {
348+
// Check if it's pointwise by examining weight sizes
349+
const auto& weight_sizes = graph->get_tref(weight_data)->sizes;
350+
if (weight_sizes.at(2) == 1 && weight_sizes.at(3) == 1) {
351+
method = Conv2dMethod::Pointwise;
352+
} else {
353+
method = Conv2dMethod::SlidingWindow;
354+
}
355+
} else if (shader.kernel_name.find("conv_transpose2d") != std::string::npos) {
356+
method = Conv2dMethod::Transposed;
357+
} else {
358+
method = Conv2dMethod::SlidingWindow;
359+
}
360+
361+
// Determine stride_equals_dilation from shader name
362+
bool stride_equals_dilation =
363+
shader.kernel_name.find("_sned") == std::string::npos;
364+
365+
utils::uvec3 wg_size = create_conv2d_global_wg_size(
366+
*graph, method, out, weight_data, stride_equals_dilation);
367+
368+
if (method == Conv2dMethod::Depthwise || method == Conv2dMethod::Pointwise) {
369+
wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1};
370+
}
371+
372+
return wg_size;
373+
}
374+
375+
// Custom local workgroup size function for conv2d
376+
utils::uvec3 conv2d_local_wg_size(
377+
ComputeGraph* graph,
378+
const vkapi::ShaderInfo& shader,
379+
const utils::uvec3& global_workgroup_size,
380+
const std::vector<ArgGroup>& args,
381+
const std::vector<ValueRef>& resize_args) {
382+
(void)args;
383+
(void)resize_args;
384+
385+
// Determine method from shader name
386+
Conv2dMethod method;
387+
if (shader.kernel_name.find("conv2d_dw") != std::string::npos) {
388+
method = Conv2dMethod::Depthwise;
389+
} else if (
390+
shader.kernel_name.find("conv2d_pw") != std::string::npos ||
391+
(shader.kernel_name.find("conv2d") != std::string::npos &&
392+
shader.kernel_name.find("conv_transpose2d") == std::string::npos)) {
393+
method = Conv2dMethod::Pointwise;
394+
} else {
395+
method = Conv2dMethod::SlidingWindow;
396+
}
397+
398+
if (method == Conv2dMethod::Pointwise) {
399+
uint32_t local_wg_size_y = 1;
400+
if (global_workgroup_size[1] % 8 == 0) {
401+
local_wg_size_y = 8;
402+
} else if (global_workgroup_size[1] % 4 == 0) {
403+
local_wg_size_y = 4;
404+
} else if (global_workgroup_size[1] % 2 == 0) {
405+
local_wg_size_y = 2;
406+
}
407+
return {64 / local_wg_size_y, local_wg_size_y, 1};
408+
} else if (method == Conv2dMethod::Depthwise) {
409+
return {64, 1, 1};
410+
} else {
411+
return graph->create_local_wg_size(global_workgroup_size);
412+
}
413+
}
414+
415+
// Custom global workgroup size function for conv1d
416+
utils::uvec3 conv1d_global_wg_size(
417+
ComputeGraph* graph,
418+
const vkapi::ShaderInfo& shader,
419+
const std::vector<ArgGroup>& args,
420+
const std::vector<ValueRef>& resize_args) {
421+
(void)shader;
422+
(void)resize_args;
423+
const ValueRef out = args.at(0).refs.at(0);
424+
425+
return {// out length
426+
graph->size_at<uint32_t>(-1, out),
427+
// out channels
428+
static_cast<uint32_t>(graph->size_at<int64_t>(-2, out)),
429+
// out batches
430+
utils::div_up_4(graph->size_at<uint32_t>(-3, out))};
431+
}
432+
330433
void add_conv2d_node(
331434
ComputeGraph& graph,
332435
const ValueRef in,
@@ -486,11 +589,11 @@ void add_conv2d_node(
486589
};
487590
}
488591

489-
graph.execute_nodes().emplace_back(new DispatchNode(
592+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
490593
graph,
491594
shader,
492-
wg_size,
493-
local_wg_size,
595+
conv2d_global_wg_size,
596+
conv2d_local_wg_size,
494597
// Inputs and Outputs
495598
{{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}},
496599
// Shader params buffers
@@ -560,15 +663,6 @@ void add_conv1d_node(
560663
const int32_t out_group_size =
561664
static_cast<int64_t>(out_channels / groups_val);
562665

563-
const utils::uvec3 global_size = {
564-
// out length
565-
graph.size_at<uint32_t>(-1, out),
566-
// out channels
567-
static_cast<uint32_t>(out_channels),
568-
// out batches
569-
utils::div_up_4(graph.size_at<uint32_t>(-3, out))};
570-
const utils::uvec3 local_size = graph.create_local_wg_size(global_size);
571-
572666
Kernel1dParams kernel_params = {
573667
kernel_size,
574668
stride_size,
@@ -587,11 +681,11 @@ void add_conv1d_node(
587681

588682
add_dtype_suffix(kernel_name, graph.dtype_of(out));
589683

590-
graph.execute_nodes().emplace_back(new DispatchNode(
684+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
591685
graph,
592686
VK_KERNEL_FROM_STR(kernel_name),
593-
global_size,
594-
local_size,
687+
conv1d_global_wg_size,
688+
default_pick_local_wg_size,
595689
// Inputs and Outputs
596690
{{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}},
597691
// Shader params buffers

backends/vulkan/runtime/graph/ops/impl/Copy.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
1213
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1314
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
@@ -35,11 +36,11 @@ void add_copy_offset_node(
3536

3637
auto shader = VK_KERNEL_FROM_STR(kernel_name);
3738

38-
graph.execute_nodes().emplace_back(new DispatchNode(
39+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
3940
graph,
4041
VK_KERNEL_FROM_STR(kernel_name),
41-
graph.create_global_wg_size(out),
42-
graph.create_local_wg_size(out),
42+
default_pick_global_wg_size,
43+
default_pick_local_wg_size,
4344
// Inputs and Outputs
4445
{
4546
{out, vkapi::kWrite},

backends/vulkan/runtime/graph/ops/impl/Embedding.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1213

1314
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
@@ -46,11 +47,11 @@ void add_embedding_node(
4647
kernel_name.reserve(kShaderNameReserve);
4748
add_dtype_suffix(kernel_name, graph.dtype_of(out));
4849

49-
graph.execute_nodes().emplace_back(new DispatchNode(
50+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
5051
graph,
5152
VK_KERNEL_FROM_STR(kernel_name),
52-
graph.create_global_wg_size(out),
53-
graph.create_local_wg_size(out),
53+
default_pick_global_wg_size,
54+
default_pick_local_wg_size,
5455
{{out, vkapi::kWrite}, {{in, weight}, vkapi::kRead}},
5556
{
5657
graph.sizes_ubo(out),

backends/vulkan/runtime/graph/ops/impl/Flip.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,26 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
1213
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1314
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1415
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1516

1617
namespace vkcompute {
1718

19+
// Custom global workgroup size function for flip
20+
utils::uvec3 flip_global_wg_size(
21+
ComputeGraph* graph,
22+
const vkapi::ShaderInfo& shader,
23+
const std::vector<ArgGroup>& args,
24+
const std::vector<ValueRef>& resize_args) {
25+
(void)shader;
26+
(void)resize_args;
27+
const ValueRef out = args.at(0).refs.at(0);
28+
return graph->create_global_wg_size(out);
29+
}
30+
1831
void check_flip_args(
1932
ComputeGraph& graph,
2033
const ValueRef in,
@@ -59,11 +72,11 @@ void add_flip_node(
5972
kernel_name.reserve(kShaderNameReserve);
6073
add_dtype_suffix(kernel_name, graph.dtype_of(out));
6174

62-
graph.execute_nodes().emplace_back(new DispatchNode(
75+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
6376
graph,
6477
VK_KERNEL_FROM_STR(kernel_name),
65-
graph.create_global_wg_size(out),
66-
graph.create_local_wg_size(out),
78+
flip_global_wg_size,
79+
default_pick_local_wg_size,
6780
// Inputs and Outputs
6881
{
6982
{out, vkapi::kWrite},

backends/vulkan/runtime/graph/ops/impl/Full.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1213
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1314

@@ -42,11 +43,11 @@ void add_full_node(
4243

4344
add_dtype_suffix(kernel_name, graph.dtype_of(out));
4445

45-
graph.execute_nodes().emplace_back(new DispatchNode(
46+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
4647
graph,
4748
VK_KERNEL_FROM_STR(kernel_name),
48-
graph.create_global_wg_size(out),
49-
graph.create_local_wg_size(out),
49+
default_pick_global_wg_size,
50+
default_pick_local_wg_size,
5051
// Inputs and Outputs
5152
{{out, vkapi::kWrite}},
5253
// Shader params buffers

backends/vulkan/runtime/graph/ops/impl/GridPriors.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1213

1314
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
@@ -46,11 +47,11 @@ void add_grid_priors_node(
4647
add_dtype_suffix(kernel_name, graph.dtype_of(out));
4748

4849
const GridPriorsParam param = {stride, offset};
49-
graph.execute_nodes().emplace_back(new DispatchNode(
50+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
5051
graph,
5152
VK_KERNEL_FROM_STR(kernel_name),
52-
graph.create_global_wg_size(out),
53-
graph.create_local_wg_size(out),
53+
default_pick_global_wg_size,
54+
default_pick_local_wg_size,
5455
// Inputs and Outputs
5556
{
5657
{out, vkapi::kWrite},

0 commit comments

Comments
 (0)