diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.glsl index c8ccbacffc1..3ad9e759910 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.glsl @@ -38,18 +38,21 @@ layout(push_constant) uniform restrict Block { ivec4 weight_sizes; }; +#include "indexing_utils.h" + layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; shared VEC4_T partial_c[NGROUPS][NWORKERS][TILE_ROWS]; void main() { - const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS; - const uint out_col = gl_GlobalInvocationID.x << 2; + const uint out_width_ntexels = divup4(out_sizes.x); + const uint out_col = (gl_GlobalInvocationID.x % out_width_ntexels) << 2; + const uint out_row = (gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS; const int gid = int(gl_LocalInvocationID.x); // group id const int wid = int(gl_LocalInvocationID.z); // worker id - if (out_col >= out_sizes.x || out_row >= out_sizes.y) { + if (out_row >= out_sizes.y) { return; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl index 8a8670b4bb3..6d7995a77f0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl @@ -36,13 +36,16 @@ layout(push_constant) uniform restrict Block { ivec4 weight_sizes; }; +#include "indexing_utils.h" + layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { - const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS; - const uint out_col = gl_GlobalInvocationID.x << 2; + const uint out_width_ntexels = divup4(out_sizes.x); + const uint out_col = (gl_GlobalInvocationID.x % out_width_ntexels) << 2; + const uint out_row = (gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS; - if (out_col >= out_sizes.x || out_row >= out_sizes.y) { + if (out_row >= out_sizes.y) { return; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml index 1e8a5e1fe7d..941836b48c4 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml @@ -16,10 +16,10 @@ q_8w_linear_tiled: TILE_ROWS: - VALUE: 1 SUFFIX: o4x1 + - VALUE: 2 + SUFFIX: o4x2 - VALUE: 4 SUFFIX: o4x4 - - VALUE: 6 - SUFFIX: o4x6 shader_variants: - NAME: q_8w_linear_tiled_texture3d_texture3d_texture2d_texture2d_float - NAME: q_8w_linear_tiled_buffer_buffer_texture2d_texture2d_float diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp index 4a10f469be0..d7156ebef90 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp @@ -180,10 +180,10 @@ void add_q_8w_linear_tiled_node( std::vector mat1_sizes = graph.sizes_of(mat1); const int64_t M = utils::val_at(-2, mat1_sizes); - int out_tile_nrows = 4; + uint32_t out_tile_nrows = 4; if (M % 6 == 0) { - kernel_name += "_o4x6"; - out_tile_nrows = 6; + kernel_name += "_o4x2"; + out_tile_nrows = 2; } else if (M % 4 == 0) { kernel_name += "_o4x4"; out_tile_nrows = 4; @@ -195,8 +195,11 @@ void add_q_8w_linear_tiled_node( out_tile_nrows = 4; } - utils::uvec3 global_wg_size = graph.logical_limits_of(out); - global_wg_size[1] = global_wg_size[1] / out_tile_nrows; + utils::uvec3 out_limits = graph.logical_limits_of(out); + utils::uvec3 global_wg_size = { + out_limits[0] * (utils::div_up(out_limits[1], out_tile_nrows)), + 1, + out_limits[2]}; utils::uvec3 local_wg_size{64, 1, 1}; if (use_coop_algorithm) {