Skip to content

Commit 016eece

Browse files
authored
[ET-VK] Merge changes from #13158 and #13159 (#13173)
## Context As title. #13158 and #13159 landed in Meta internal repo as diffs but there was a problem creating a merge PR. This PR manually adds the changes from those PRs.
1 parent a44e4ac commit 016eece

21 files changed

+479
-150
lines changed

backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,8 @@ DynamicDispatchNode::DynamicDispatchNode(
5757
: DispatchNode(
5858
graph,
5959
shader,
60-
pick_global_wg_fn(&graph, shader, args, resize_args),
61-
pick_local_wg_fn(
62-
&graph,
63-
shader,
64-
pick_global_wg_fn(&graph, shader, args, resize_args),
65-
args,
66-
resize_args),
60+
{1u, 1u, 1u},
61+
{8u, 8u, 1u},
6762
args,
6863
params,
6964
push_constants,
@@ -72,7 +67,12 @@ DynamicDispatchNode::DynamicDispatchNode(
7267
resize_fn),
7368
pick_shader_fn_{nullptr},
7469
pick_global_wg_fn_(pick_global_wg_fn),
75-
pick_local_wg_fn_(pick_local_wg_fn) {}
70+
pick_local_wg_fn_(pick_local_wg_fn) {
71+
global_workgroup_size_ =
72+
pick_global_wg_fn(&graph, shader_, args, resize_args);
73+
local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn(
74+
&graph, shader_, global_workgroup_size_, args, resize_args));
75+
}
7676

7777
void DynamicDispatchNode::encode(ComputeGraph* graph) {
7878
if (pick_shader_fn_) {

backends/vulkan/runtime/graph/ops/impl/Arange.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1212

13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1314
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1415

1516
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
@@ -86,11 +87,11 @@ void add_arange_node(
8687
kernel_name.reserve(kShaderNameReserve);
8788
add_dtype_suffix(kernel_name, graph.dtype_of(out));
8889

89-
graph.execute_nodes().emplace_back(new DispatchNode(
90+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
9091
graph,
9192
VK_KERNEL_FROM_STR(kernel_name),
92-
graph.create_global_wg_size(out),
93-
graph.create_local_wg_size(out),
93+
default_pick_global_wg_size,
94+
default_pick_local_wg_size,
9495
// Inputs and Outputs
9596
{{out, vkapi::kWrite}},
9697
// Shader params buffers

backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1213

1314
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
@@ -83,11 +84,11 @@ void add_native_batch_norm_node(
8384
const int32_t num_texel_per_batch =
8485
utils::div_up_4((dim_at<kChannel4D>(in_sizes)));
8586

86-
graph.execute_nodes().emplace_back(new DispatchNode(
87+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
8788
graph,
8889
VK_KERNEL_FROM_STR(kernel_name),
89-
graph.create_global_wg_size(out_ref),
90-
graph.create_local_wg_size(out_ref),
90+
default_pick_global_wg_size,
91+
default_pick_local_wg_size,
9192
{{out_ref, vkapi::kWrite},
9293
{{in_ref, arg_weight, arg_bias, arg_mean, arg_var}, vkapi::kRead}},
9394
{graph.logical_limits_ubo(out_ref),

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 116 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1213

1314
#include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>
@@ -19,6 +20,13 @@
1920

2021
namespace vkcompute {
2122

23+
enum class Conv2dMethod : uint8_t {
24+
Depthwise,
25+
Pointwise,
26+
SlidingWindow,
27+
Transposed,
28+
};
29+
2230
void resize_conv2d_node(
2331
ComputeGraph* graph,
2432
const std::vector<ArgGroup>& args,
@@ -114,13 +122,6 @@ ValueRef prepack_biases(
114122
return v;
115123
}
116124

117-
enum class Conv2dMethod : uint8_t {
118-
Depthwise,
119-
Pointwise,
120-
SlidingWindow,
121-
Transposed,
122-
};
123-
124125
vkapi::ShaderInfo get_conv2d_shader(
125126
ComputeGraph& graph,
126127
const ValueRef out,
@@ -327,6 +328,108 @@ utils::uvec3 create_conv2d_global_wg_size(
327328
}
328329
}
329330

331+
// Custom global workgroup size function for conv2d
332+
utils::uvec3 conv2d_global_wg_size(
333+
ComputeGraph* graph,
334+
const vkapi::ShaderInfo& shader,
335+
const std::vector<ArgGroup>& args,
336+
const std::vector<ValueRef>& resize_args) {
337+
const ValueRef out = args.at(0).refs.at(0);
338+
const ValueRef weight_data = resize_args.at(0);
339+
340+
// Determine method from shader name
341+
Conv2dMethod method;
342+
if (shader.kernel_name.find("conv2d_dw") != std::string::npos) {
343+
method = Conv2dMethod::Depthwise;
344+
} else if (
345+
shader.kernel_name.find("conv2d_pw") != std::string::npos ||
346+
(shader.kernel_name.find("conv2d") != std::string::npos &&
347+
shader.kernel_name.find("conv_transpose2d") == std::string::npos)) {
348+
// Check if it's pointwise by examining weight sizes
349+
const auto& weight_sizes = graph->get_tref(weight_data)->sizes;
350+
if (weight_sizes.at(2) == 1 && weight_sizes.at(3) == 1) {
351+
method = Conv2dMethod::Pointwise;
352+
} else {
353+
method = Conv2dMethod::SlidingWindow;
354+
}
355+
} else if (shader.kernel_name.find("conv_transpose2d") != std::string::npos) {
356+
method = Conv2dMethod::Transposed;
357+
} else {
358+
method = Conv2dMethod::SlidingWindow;
359+
}
360+
361+
// Determine stride_equals_dilation from shader name
362+
bool stride_equals_dilation =
363+
shader.kernel_name.find("_sned") == std::string::npos;
364+
365+
utils::uvec3 wg_size = create_conv2d_global_wg_size(
366+
*graph, method, out, weight_data, stride_equals_dilation);
367+
368+
if (method == Conv2dMethod::Depthwise || method == Conv2dMethod::Pointwise) {
369+
wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1};
370+
}
371+
372+
return wg_size;
373+
}
374+
375+
// Custom local workgroup size function for conv2d
376+
utils::uvec3 conv2d_local_wg_size(
377+
ComputeGraph* graph,
378+
const vkapi::ShaderInfo& shader,
379+
const utils::uvec3& global_workgroup_size,
380+
const std::vector<ArgGroup>& args,
381+
const std::vector<ValueRef>& resize_args) {
382+
(void)args;
383+
(void)resize_args;
384+
385+
// Determine method from shader name
386+
Conv2dMethod method;
387+
if (shader.kernel_name.find("conv2d_dw") != std::string::npos) {
388+
method = Conv2dMethod::Depthwise;
389+
} else if (
390+
shader.kernel_name.find("conv2d_pw") != std::string::npos ||
391+
(shader.kernel_name.find("conv2d") != std::string::npos &&
392+
shader.kernel_name.find("conv_transpose2d") == std::string::npos)) {
393+
method = Conv2dMethod::Pointwise;
394+
} else {
395+
method = Conv2dMethod::SlidingWindow;
396+
}
397+
398+
if (method == Conv2dMethod::Pointwise) {
399+
uint32_t local_wg_size_y = 1;
400+
if (global_workgroup_size[1] % 8 == 0) {
401+
local_wg_size_y = 8;
402+
} else if (global_workgroup_size[1] % 4 == 0) {
403+
local_wg_size_y = 4;
404+
} else if (global_workgroup_size[1] % 2 == 0) {
405+
local_wg_size_y = 2;
406+
}
407+
return {64 / local_wg_size_y, local_wg_size_y, 1};
408+
} else if (method == Conv2dMethod::Depthwise) {
409+
return {64, 1, 1};
410+
} else {
411+
return graph->create_local_wg_size(global_workgroup_size);
412+
}
413+
}
414+
415+
// Custom global workgroup size function for conv1d
416+
utils::uvec3 conv1d_global_wg_size(
417+
ComputeGraph* graph,
418+
const vkapi::ShaderInfo& shader,
419+
const std::vector<ArgGroup>& args,
420+
const std::vector<ValueRef>& resize_args) {
421+
(void)shader;
422+
(void)resize_args;
423+
const ValueRef out = args.at(0).refs.at(0);
424+
425+
return {// out length
426+
graph->size_at<uint32_t>(-1, out),
427+
// out channels
428+
static_cast<uint32_t>(graph->size_at<int64_t>(-2, out)),
429+
// out batches
430+
utils::div_up_4(graph->size_at<uint32_t>(-3, out))};
431+
}
432+
330433
void add_conv2d_node(
331434
ComputeGraph& graph,
332435
const ValueRef in,
@@ -486,11 +589,11 @@ void add_conv2d_node(
486589
};
487590
}
488591

489-
graph.execute_nodes().emplace_back(new DispatchNode(
592+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
490593
graph,
491594
shader,
492-
wg_size,
493-
local_wg_size,
595+
conv2d_global_wg_size,
596+
conv2d_local_wg_size,
494597
// Inputs and Outputs
495598
{{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}},
496599
// Shader params buffers
@@ -560,15 +663,6 @@ void add_conv1d_node(
560663
const int32_t out_group_size =
561664
static_cast<int64_t>(out_channels / groups_val);
562665

563-
const utils::uvec3 global_size = {
564-
// out length
565-
graph.size_at<uint32_t>(-1, out),
566-
// out channels
567-
static_cast<uint32_t>(out_channels),
568-
// out batches
569-
utils::div_up_4(graph.size_at<uint32_t>(-3, out))};
570-
const utils::uvec3 local_size = graph.create_local_wg_size(global_size);
571-
572666
Kernel1dParams kernel_params = {
573667
kernel_size,
574668
stride_size,
@@ -587,11 +681,11 @@ void add_conv1d_node(
587681

588682
add_dtype_suffix(kernel_name, graph.dtype_of(out));
589683

590-
graph.execute_nodes().emplace_back(new DispatchNode(
684+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
591685
graph,
592686
VK_KERNEL_FROM_STR(kernel_name),
593-
global_size,
594-
local_size,
687+
conv1d_global_wg_size,
688+
default_pick_local_wg_size,
595689
// Inputs and Outputs
596690
{{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}},
597691
// Shader params buffers

backends/vulkan/runtime/graph/ops/impl/Copy.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
1213
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1314
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
@@ -35,11 +36,11 @@ void add_copy_offset_node(
3536

3637
auto shader = VK_KERNEL_FROM_STR(kernel_name);
3738

38-
graph.execute_nodes().emplace_back(new DispatchNode(
39+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
3940
graph,
4041
VK_KERNEL_FROM_STR(kernel_name),
41-
graph.create_global_wg_size(out),
42-
graph.create_local_wg_size(out),
42+
default_pick_global_wg_size,
43+
default_pick_local_wg_size,
4344
// Inputs and Outputs
4445
{
4546
{out, vkapi::kWrite},

backends/vulkan/runtime/graph/ops/impl/Embedding.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1213

1314
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
@@ -46,11 +47,11 @@ void add_embedding_node(
4647
kernel_name.reserve(kShaderNameReserve);
4748
add_dtype_suffix(kernel_name, graph.dtype_of(out));
4849

49-
graph.execute_nodes().emplace_back(new DispatchNode(
50+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
5051
graph,
5152
VK_KERNEL_FROM_STR(kernel_name),
52-
graph.create_global_wg_size(out),
53-
graph.create_local_wg_size(out),
53+
default_pick_global_wg_size,
54+
default_pick_local_wg_size,
5455
{{out, vkapi::kWrite}, {{in, weight}, vkapi::kRead}},
5556
{
5657
graph.sizes_ubo(out),

backends/vulkan/runtime/graph/ops/impl/Flip.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,26 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
1213
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1314
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1415
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1516

1617
namespace vkcompute {
1718

19+
// Custom global workgroup size function for flip
20+
utils::uvec3 flip_global_wg_size(
21+
ComputeGraph* graph,
22+
const vkapi::ShaderInfo& shader,
23+
const std::vector<ArgGroup>& args,
24+
const std::vector<ValueRef>& resize_args) {
25+
(void)shader;
26+
(void)resize_args;
27+
const ValueRef out = args.at(0).refs.at(0);
28+
return graph->create_global_wg_size(out);
29+
}
30+
1831
void check_flip_args(
1932
ComputeGraph& graph,
2033
const ValueRef in,
@@ -59,11 +72,11 @@ void add_flip_node(
5972
kernel_name.reserve(kShaderNameReserve);
6073
add_dtype_suffix(kernel_name, graph.dtype_of(out));
6174

62-
graph.execute_nodes().emplace_back(new DispatchNode(
75+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
6376
graph,
6477
VK_KERNEL_FROM_STR(kernel_name),
65-
graph.create_global_wg_size(out),
66-
graph.create_local_wg_size(out),
78+
flip_global_wg_size,
79+
default_pick_local_wg_size,
6780
// Inputs and Outputs
6881
{
6982
{out, vkapi::kWrite},

backends/vulkan/runtime/graph/ops/impl/Full.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1213
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1314

@@ -42,11 +43,11 @@ void add_full_node(
4243

4344
add_dtype_suffix(kernel_name, graph.dtype_of(out));
4445

45-
graph.execute_nodes().emplace_back(new DispatchNode(
46+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
4647
graph,
4748
VK_KERNEL_FROM_STR(kernel_name),
48-
graph.create_global_wg_size(out),
49-
graph.create_local_wg_size(out),
49+
default_pick_global_wg_size,
50+
default_pick_local_wg_size,
5051
// Inputs and Outputs
5152
{{out, vkapi::kWrite}},
5253
// Shader params buffers

0 commit comments

Comments
 (0)