diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
index 21760eca0e0..57ae98eb85b 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
@@ -32,12 +32,14 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+
 /*
  * Computes a depthwise convolution. Each shader invocation calculates the
  * output at a single output location.
  */
 void main() {
-  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+  const u16vec3 pos = u16vec3(gl_GlobalInvocationID);
 
   if (any(greaterThanEqual(pos, out_limits))) {
     return;
@@ -45,22 +47,22 @@ void main() {
 
   // Compute the index of the top-left element of the overlay region. Negative
   // indices indicate that the top-left element is in a region added by padding.
-  const ivec2 ipos = pos.xy * stride - padding;
+  const u16vec2 ipos = pos.xy * u16vec2(stride) - u16vec2(padding);
 
   // Compute the start and end of the input indices to load. Padding is assumed
   // to be constant 0 padding, so any reads from the padding region is skipped.
-  const ivec2 start = ipos;
-  const ivec2 end = ipos + overlay_region.xy;
+  const u16vec2 start = ipos;
+  const u16vec2 end = ipos + u16vec2(overlay_region.xy);
 
-  VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
-  int kx = 0;
-  for (int y = start.y, i = 0; i < TILE_SIZE; y += dilation.y, i++) {
-    for (int x = start.x, j = 0; j < TILE_SIZE; x += dilation.x, j++) {
+  VEC4_T sum = texelFetch(t_bias, u16vec2(pos.z, 0), 0);
+  uint16_t kx = uint16_t(0);
+  for (uint16_t y = start.y, i = uint16_t(0); i < uint16_t(TILE_SIZE); y += uint16_t(dilation.y), i++) {
+    for (uint16_t x = start.x, j = uint16_t(0); j < uint16_t(TILE_SIZE); x += uint16_t(dilation.x), j++) {
       // The weight kernel was rearranged such that every NxN filter is
       // flattened to fit in one row. Each filter was then stacked on top of
       // each other vertically.
-      const vec4 in_texel = texelFetch(t_in, ivec3(x, y, pos.z), 0);
-      sum = fma(in_texel, texelFetch(t_kernel, ivec2(kx, pos.z), 0), sum);
+      const vec4 in_texel = texelFetch(t_in, u16vec3(x, y, pos.z), 0);
+      sum = fma(in_texel, texelFetch(t_kernel, u16vec2(kx, pos.z), 0), sum);
       kx++;
     }
   }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
index b1950f970e4..9d1f6c3bd91 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
@@ -40,7 +40,12 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
  * size is only 1x1, making it easier to re-use loaded texels from t_kernel.
  */
 void main() {
-  const u16vec3 gpos = u16vec3(gl_GlobalInvocationID);
+  const uint16_t out_limits_y_scaled = uint16_t((out_limits.y + TILE_SIZE - 1) / TILE_SIZE);
+
+  const u16vec3 gpos = u16vec3(
+    gl_GlobalInvocationID.x / (out_limits_y_scaled * out_limits.z),
+    (gl_GlobalInvocationID.x / out_limits.z) % out_limits_y_scaled,
+    gl_GlobalInvocationID.x % out_limits.z);
 
   // Output position for TILE_SIZE = 2
   // +--------+--------+
diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
index 1cdd7315f16..4f123cb8337 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
@@ -370,11 +370,17 @@ void add_conv2d_node(
       weight_data,
       clamp_out);
 
+  utils::uvec3 wg_size = create_conv2d_global_wg_size(graph, method, out);
+
+  if (method == Conv2dMethod::Pointwise) {
+    wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1};
+  }
+
   graph.execute_nodes().emplace_back(new DispatchNode(
       graph,
       shader,
-      create_conv2d_global_wg_size(graph, method, out),
-      graph.create_local_wg_size(out),
+      wg_size,
+      graph.create_local_wg_size(wg_size),
       // Inputs and Outputs
       {{out, vkapi::MemoryAccessType::WRITE},
        {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
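For reference, a minimal host-side sketch (not part of the diff) of the 1D-to-3D index decode that the new conv2d_pw.glsl dispatch relies on, assuming the flattened global work-group size from Convolution.cpp covers W * ceil(H / TILE_SIZE) * C4 invocations; the extents W, H, C4 and the TILE_SIZE value below are hypothetical example values.

// Sketch only (assumed extents): checks that the decode used for gpos in
// conv2d_pw.glsl touches every (x, y, z) output tile exactly once when the
// dispatch is flattened to {W * ceil(H / TILE_SIZE) * C4, 1, 1}.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const uint32_t TILE_SIZE = 2;
  const uint32_t W = 17, H = 9, C4 = 5; // hypothetical out_limits extents
  const uint32_t y_scaled = (H + TILE_SIZE - 1) / TILE_SIZE;

  std::vector<bool> seen(W * y_scaled * C4, false);
  for (uint32_t id = 0; id < W * y_scaled * C4; ++id) {
    const uint32_t x = id / (y_scaled * C4); // gpos.x in the shader
    const uint32_t y = (id / C4) % y_scaled; // gpos.y
    const uint32_t z = id % C4;              // gpos.z
    // Re-encode and confirm each tile is visited exactly once.
    const uint32_t flat = (x * y_scaled + y) * C4 + z;
    assert(!seen[flat]);
    seen[flat] = true;
  }
  return 0;
}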