diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl index e98d2e919b0..56bffaee675 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl @@ -90,9 +90,10 @@ void main() { void main() { const u16vec2 out_pos = u16vec2( - gl_GlobalInvocationID.x / out_limits.y, - gl_GlobalInvocationID.x % out_limits.y); - if (out_pos.x >= out_limits.x) { + gl_GlobalInvocationID.x, + gl_GlobalInvocationID.y); + + if (out_pos.x >= out_limits.x || out_pos.y >= out_limits.y) { return; } diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index ea6601502f1..59684d73bd2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -114,15 +114,37 @@ void add_q_8w_linear_node( graph.sizes_ubo(mat1_W_packed)}); } - // set global work group size to be 1 dimensional - const utils::uvec3 wg_size = { - static_cast(graph.numel_of(out_W_packed)), 1, 1}; + utils::uvec3 global_wg; + if (graph.is_buffer_storage(out)) { + global_wg = {static_cast(graph.numel_of(out_W_packed)), 1, 1}; + } else { + global_wg = graph.logical_limits_of(out_W_packed); + } + + utils::uvec3 local_wg{8, 8, 1}; + int32_t out_W = graph.size_at(-1, out_W_packed); + + if (graph.is_buffer_storage(out_W_packed)) { + local_wg[0] = 64; + local_wg[1] = 1; + local_wg[2] = 1; + } else { + if (out_W % 8 != 0) { + if (out_W % 4 == 0) { + local_wg[0] = 4; + local_wg[1] = 16; + } else { + local_wg[0] = 2; + local_wg[1] = 32; + } + } + } graph.execute_nodes().emplace_back(new DispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - wg_size, - graph.create_local_wg_size(wg_size), + global_wg, + local_wg, // Inputs and Outputs {{out_W_packed, vkapi::MemoryAccessType::WRITE}, {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},