diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl index 29e57c4fecf..b2ae4953a78 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl @@ -14,6 +14,8 @@ #define TILE_SIZE ${TILE_SIZE} +#define BATCH_SIZE_Y ${BATCH_SIZE_Y} + #define op(X, A, B) ${OPERATOR} #include "indexing_utils.h" @@ -39,12 +41,20 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; * output at a single output location. */ void main() { - const u16vec3 pos = u16vec3( + // y divided up by batch size is used to determine 3d position + // since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z + const uint out_limits_y_scaled = (out_limits.y + BATCH_SIZE_Y - 1) / BATCH_SIZE_Y; + + u16vec3 pos = u16vec3( gl_GlobalInvocationID.x % out_limits.x, - (gl_GlobalInvocationID.x / out_limits.x) % out_limits.y, - gl_GlobalInvocationID.x / (out_limits.x * out_limits.y)); + ((gl_GlobalInvocationID.x / out_limits.x) % out_limits_y_scaled), + gl_GlobalInvocationID.x / (out_limits.x * out_limits_y_scaled)); - if (any(greaterThanEqual(pos, out_limits))) { + // scale pos.y by batch size, because that's the top pixel to be processed + pos.y *= uint16_t(BATCH_SIZE_Y); + + // do not process if top pixel does not fit within the output range + if (any(greaterThanEqual(u16vec3(pos.x, pos.y, pos.z), out_limits))) { return; } @@ -57,18 +67,47 @@ void main() { const u16vec2 start = ipos; const u16vec2 end = ipos + u16vec2(overlay_region.xy); - VEC4_T sum = texelFetch(t_bias, u16vec2(pos.z, 0), 0); + // sum outputs + VEC4_T sum[BATCH_SIZE_Y]; + + sum[0] = texelFetch(t_bias, u16vec2(pos.z, 0), 0); + for (int i = 1; i < BATCH_SIZE_Y; i++) { + sum[i] = sum[0]; + } + + // array to store input texels + VEC4_T in_texels[TILE_SIZE]; + + // array to store kernel data of previous y + VEC4_T prev_kernel_line[TILE_SIZE]; + uint16_t kx = uint16_t(0); - for (uint16_t y = start.y, i = uint16_t(0); i < uint16_t(TILE_SIZE); y += uint16_t(dilation.y), i++) { + for (uint16_t y = start.y, i = uint16_t(0); i < uint16_t(TILE_SIZE + BATCH_SIZE_Y - 1); y += uint16_t(dilation.y), i++) { for (uint16_t x = start.x, j = uint16_t(0); j < uint16_t(TILE_SIZE); x += uint16_t(dilation.x), j++) { - // The weight kernel was rearranged such that every NxN filter is - // flattened to fit in one row. Each filter was then stacked on top of - // each other vertically. - const vec4 in_texel = texelFetch(t_in, u16vec3(x, y, pos.z), 0); - sum = fma(in_texel, texelFetch(t_kernel, u16vec2(kx, pos.z), 0), sum); - kx++; + in_texels[int(j)] = texelFetch(t_in, u16vec3(x, y, pos.z), 0); + } + + // from 2nd iteration onwards accumulate dot product in 2nd sum + // based on kernel line data fetched in previous iteration and input texel from this iteration + if (i > uint16_t(0)) { + for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++) { + sum[1] = fma(in_texels[int(j)], prev_kernel_line[int(j)], sum[1]); + } + } + + // accumulate dot product in 1st sum only until tile size + if (i < uint16_t(TILE_SIZE)) { + for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++, kx++) { + prev_kernel_line[int(j)] = texelFetch(t_kernel, u16vec2(kx, pos.z), 0); + sum[0] = fma(in_texels[int(j)], prev_kernel_line[int(j)], sum[0]); + } } } - imageStore(t_out, pos, op(sum, out_min, out_max)); + for (int i = 0; i < BATCH_SIZE_Y; i++) { + if (any(greaterThanEqual(u16vec3(pos.x, pos.y + i, pos.z), out_limits))) { + continue; + } + imageStore(t_out, u16vec3(pos.x, pos.y + i, pos.z), op(sum[i], out_min, out_max)); + } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml index a0d11284258..bb197c2c187 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml @@ -10,6 +10,7 @@ conv2d_dw_output_tile: NDIM: 3 DTYPE: float TILE_SIZE: 3 + BATCH_SIZE_Y: 2 generate_variant_forall: DTYPE: - VALUE: half diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 9ad600d27a7..3519635ac7e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -296,6 +296,12 @@ utils::uvec3 create_conv2d_global_wg_size( utils::div_up(image_extents[0u], 2u), utils::div_up(image_extents[1u], 2u), image_extents[2u]}; + } else if (method == Conv2dMethod::Depthwise) { + const utils::uvec3 image_extents = graph.logical_limits_of(out); + return { + utils::div_up(image_extents[0u], 1u), + utils::div_up(image_extents[1u], 2u), + image_extents[2u]}; } else { return graph.create_global_wg_size(out); }