From f0f92a26d9e23b7d5c98df66bc5ec630fdf950e3 Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Mon, 26 May 2025 21:40:53 -0700 Subject: [PATCH] [ET-VK] Tuning local workgroup size calculation for conv2d pw to improve performance. This diff adjusts the local workgroup size (`local_wg_size`) based on batch count (stored in `wg_size[1]`), to improve conv2d pw performance. * If `wg_size[1]` is a multiple of 8, `local_wg_size_y` is set to 8. * If `wg_size[1]` is a multiple of 4, `local_wg_size_y` is set to 4. * If `wg_size[1]` is a multiple of 2, `local_wg_size_y` is set to 2. * Otherwise, we default to `local_wg_size_y` = 1. The dispatch size in 2 dimensions is then calculate based on `{64 / local_wg_size_y, local_wg_size_y, 1}`. Differential Revision: [D75420517](https://our.internmc.facebook.com/intern/diff/D75420517/) [ghstack-poisoned] --- .../runtime/graph/ops/impl/Convolution.cpp | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 5250c3baef2..ba1f50a23c1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -404,6 +404,21 @@ void add_conv2d_node( wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1}; } + utils::uvec3 local_wg_size; + if (method == Conv2dMethod::Pointwise) { + uint32_t local_wg_size_y = 1; + if (wg_size[1] % 8 == 0) { + local_wg_size_y = 8; + } else if (wg_size[1] % 4 == 0) { + local_wg_size_y = 4; + } else if (wg_size[1] % 2 == 0) { + local_wg_size_y = 2; + } + local_wg_size = {64 / local_wg_size_y, local_wg_size_y, 1}; + } else { + local_wg_size = graph.create_local_wg_size(wg_size); + } + vkapi::ParamsBindList param_buffers; std::vector push_constants; if (method == Conv2dMethod::Pointwise) { @@ -464,7 +479,7 @@ void add_conv2d_node( graph, shader, wg_size, - graph.create_local_wg_size(wg_size), + local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}, // Shader params buffers