8
8
9
9
#include < executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10
10
11
+ #include < executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
11
12
#include < executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
12
13
13
14
#include < executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>
19
20
20
21
namespace vkcompute {
21
22
23
// Convolution strategy tag for conv2d dispatches. The workgroup size pickers
// in this file derive one of these values from the shader's kernel name (and,
// in conv2d_global_wg_size, from the weight sizes) before choosing a size.
// Stored in a single byte; enumerator values follow declaration order.
enum class Conv2dMethod : uint8_t {
  // Selected for kernels whose name contains "conv2d_dw".
  Depthwise = 0,
  // Selected for conv2d kernels whose weight dims 2 and 3 are both 1.
  Pointwise = 1,
  // Fallback for every kernel not matched by another enumerator.
  SlidingWindow = 2,
  // Selected for kernels whose name contains "conv_transpose2d".
  Transposed = 3,
};
29
+
22
30
void resize_conv2d_node (
23
31
ComputeGraph* graph,
24
32
const std::vector<ArgGroup>& args,
@@ -114,13 +122,6 @@ ValueRef prepack_biases(
114
122
return v;
115
123
}
116
124
117
- enum class Conv2dMethod : uint8_t {
118
- Depthwise,
119
- Pointwise,
120
- SlidingWindow,
121
- Transposed,
122
- };
123
-
124
125
vkapi::ShaderInfo get_conv2d_shader (
125
126
ComputeGraph& graph,
126
127
const ValueRef out,
@@ -327,6 +328,108 @@ utils::uvec3 create_conv2d_global_wg_size(
327
328
}
328
329
}
329
330
331
+ // Custom global workgroup size function for conv2d
332
+ utils::uvec3 conv2d_global_wg_size (
333
+ ComputeGraph* graph,
334
+ const vkapi::ShaderInfo& shader,
335
+ const std::vector<ArgGroup>& args,
336
+ const std::vector<ValueRef>& resize_args) {
337
+ const ValueRef out = args.at (0 ).refs .at (0 );
338
+ const ValueRef weight_data = resize_args.at (0 );
339
+
340
+ // Determine method from shader name
341
+ Conv2dMethod method;
342
+ if (shader.kernel_name .find (" conv2d_dw" ) != std::string::npos) {
343
+ method = Conv2dMethod::Depthwise;
344
+ } else if (
345
+ shader.kernel_name .find (" conv2d_pw" ) != std::string::npos ||
346
+ (shader.kernel_name .find (" conv2d" ) != std::string::npos &&
347
+ shader.kernel_name .find (" conv_transpose2d" ) == std::string::npos)) {
348
+ // Check if it's pointwise by examining weight sizes
349
+ const auto & weight_sizes = graph->get_tref (weight_data)->sizes ;
350
+ if (weight_sizes.at (2 ) == 1 && weight_sizes.at (3 ) == 1 ) {
351
+ method = Conv2dMethod::Pointwise;
352
+ } else {
353
+ method = Conv2dMethod::SlidingWindow;
354
+ }
355
+ } else if (shader.kernel_name .find (" conv_transpose2d" ) != std::string::npos) {
356
+ method = Conv2dMethod::Transposed;
357
+ } else {
358
+ method = Conv2dMethod::SlidingWindow;
359
+ }
360
+
361
+ // Determine stride_equals_dilation from shader name
362
+ bool stride_equals_dilation =
363
+ shader.kernel_name .find (" _sned" ) == std::string::npos;
364
+
365
+ utils::uvec3 wg_size = create_conv2d_global_wg_size (
366
+ *graph, method, out, weight_data, stride_equals_dilation);
367
+
368
+ if (method == Conv2dMethod::Depthwise || method == Conv2dMethod::Pointwise) {
369
+ wg_size = {wg_size[0 ] * wg_size[1 ], wg_size[2 ], 1 };
370
+ }
371
+
372
+ return wg_size;
373
+ }
374
+
375
+ // Custom local workgroup size function for conv2d
376
+ utils::uvec3 conv2d_local_wg_size (
377
+ ComputeGraph* graph,
378
+ const vkapi::ShaderInfo& shader,
379
+ const utils::uvec3& global_workgroup_size,
380
+ const std::vector<ArgGroup>& args,
381
+ const std::vector<ValueRef>& resize_args) {
382
+ (void )args;
383
+ (void )resize_args;
384
+
385
+ // Determine method from shader name
386
+ Conv2dMethod method;
387
+ if (shader.kernel_name .find (" conv2d_dw" ) != std::string::npos) {
388
+ method = Conv2dMethod::Depthwise;
389
+ } else if (
390
+ shader.kernel_name .find (" conv2d_pw" ) != std::string::npos ||
391
+ (shader.kernel_name .find (" conv2d" ) != std::string::npos &&
392
+ shader.kernel_name .find (" conv_transpose2d" ) == std::string::npos)) {
393
+ method = Conv2dMethod::Pointwise;
394
+ } else {
395
+ method = Conv2dMethod::SlidingWindow;
396
+ }
397
+
398
+ if (method == Conv2dMethod::Pointwise) {
399
+ uint32_t local_wg_size_y = 1 ;
400
+ if (global_workgroup_size[1 ] % 8 == 0 ) {
401
+ local_wg_size_y = 8 ;
402
+ } else if (global_workgroup_size[1 ] % 4 == 0 ) {
403
+ local_wg_size_y = 4 ;
404
+ } else if (global_workgroup_size[1 ] % 2 == 0 ) {
405
+ local_wg_size_y = 2 ;
406
+ }
407
+ return {64 / local_wg_size_y, local_wg_size_y, 1 };
408
+ } else if (method == Conv2dMethod::Depthwise) {
409
+ return {64 , 1 , 1 };
410
+ } else {
411
+ return graph->create_local_wg_size (global_workgroup_size);
412
+ }
413
+ }
414
+
415
+ // Custom global workgroup size function for conv1d
416
+ utils::uvec3 conv1d_global_wg_size (
417
+ ComputeGraph* graph,
418
+ const vkapi::ShaderInfo& shader,
419
+ const std::vector<ArgGroup>& args,
420
+ const std::vector<ValueRef>& resize_args) {
421
+ (void )shader;
422
+ (void )resize_args;
423
+ const ValueRef out = args.at (0 ).refs .at (0 );
424
+
425
+ return {// out length
426
+ graph->size_at <uint32_t >(-1 , out),
427
+ // out channels
428
+ static_cast <uint32_t >(graph->size_at <int64_t >(-2 , out)),
429
+ // out batches
430
+ utils::div_up_4 (graph->size_at <uint32_t >(-3 , out))};
431
+ }
432
+
330
433
void add_conv2d_node (
331
434
ComputeGraph& graph,
332
435
const ValueRef in,
@@ -486,11 +589,11 @@ void add_conv2d_node(
486
589
};
487
590
}
488
591
489
- graph.execute_nodes ().emplace_back (new DispatchNode (
592
+ graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
490
593
graph,
491
594
shader,
492
- wg_size ,
493
- local_wg_size ,
595
+ conv2d_global_wg_size ,
596
+ conv2d_local_wg_size ,
494
597
// Inputs and Outputs
495
598
{{out, vkapi::kWrite }, {{in, arg_weight, arg_bias}, vkapi::kRead }},
496
599
// Shader params buffers
@@ -560,15 +663,6 @@ void add_conv1d_node(
560
663
const int32_t out_group_size =
561
664
static_cast <int64_t >(out_channels / groups_val);
562
665
563
- const utils::uvec3 global_size = {
564
- // out length
565
- graph.size_at <uint32_t >(-1 , out),
566
- // out channels
567
- static_cast <uint32_t >(out_channels),
568
- // out batches
569
- utils::div_up_4 (graph.size_at <uint32_t >(-3 , out))};
570
- const utils::uvec3 local_size = graph.create_local_wg_size (global_size);
571
-
572
666
Kernel1dParams kernel_params = {
573
667
kernel_size,
574
668
stride_size,
@@ -587,11 +681,11 @@ void add_conv1d_node(
587
681
588
682
add_dtype_suffix (kernel_name, graph.dtype_of (out));
589
683
590
- graph.execute_nodes ().emplace_back (new DispatchNode (
684
+ graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
591
685
graph,
592
686
VK_KERNEL_FROM_STR (kernel_name),
593
- global_size ,
594
- local_size ,
687
+ conv1d_global_wg_size ,
688
+ default_pick_local_wg_size ,
595
689
// Inputs and Outputs
596
690
{{out, vkapi::kWrite }, {{in, arg_weight, arg_bias}, vkapi::kRead }},
597
691
// Shader params buffers
0 commit comments