File tree Expand file tree Collapse file tree 3 files changed +14
-12
lines changed
backends/vulkan/runtime/graph/ops Expand file tree Collapse file tree 3 files changed +14
-12
lines changed Original file line number Diff line number Diff line change @@ -38,18 +38,21 @@ layout(push_constant) uniform restrict Block {
3838 ivec4 weight_sizes;
3939};
4040
41+ #include "indexing_utils.h"
42+
4143layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
4244
4345shared VEC4_T partial_c[NGROUPS][NWORKERS][TILE_ROWS];
4446
4547void main() {
46- const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
47- const uint out_col = gl_GlobalInvocationID.x << 2 ;
48+ const uint out_width_ntexels = divup4(out_sizes.x);
49+ const uint out_col = (gl_GlobalInvocationID.x % out_width_ntexels) << 2 ;
50+ const uint out_row = (gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS;
4851
4952 const int gid = int (gl_LocalInvocationID.x); // group id
5053 const int wid = int (gl_LocalInvocationID.z); // worker id
5154
52- if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
55+ if (out_row >= out_sizes.y) {
5356 return ;
5457 }
5558
Original file line number Diff line number Diff line change @@ -41,9 +41,9 @@ layout(push_constant) uniform restrict Block {
4141layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
4242
4343void main() {
44- const uint out_size_x_div_4 = divup4(out_sizes.x);
45- const uint out_col = (gl_GlobalInvocationID.x % out_size_x_div_4 ) << 2 ;
46- const uint out_row = (gl_GlobalInvocationID.x / out_size_x_div_4 ) * TILE_ROWS;
44+ const uint out_width_ntexels = divup4(out_sizes.x);
45+ const uint out_col = (gl_GlobalInvocationID.x % out_width_ntexels ) << 2 ;
46+ const uint out_row = (gl_GlobalInvocationID.x / out_width_ntexels ) * TILE_ROWS;
4747
4848 if (out_row >= out_sizes.y) {
4949 return ;
Original file line number Diff line number Diff line change @@ -195,12 +195,11 @@ void add_q_8w_linear_tiled_node(
195195 out_tile_nrows = 4 ;
196196 }
197197
198- utils::uvec3 global_wg_size = graph.logical_limits_of (out);
199- global_wg_size[1 ] = global_wg_size[1 ] / out_tile_nrows;
200- if (!use_coop_algorithm) {
201- global_wg_size[0 ] *= global_wg_size[1 ];
202- global_wg_size[1 ] = 1 ;
203- }
198+ utils::uvec3 out_limits = graph.logical_limits_of (out);
199+ utils::uvec3 global_wg_size = {
200+ out_limits[0 ] * (utils::div_up (out_limits, out_tile_nrows)),
201+ 1 ,
202+ out_limit[2 ]};
204203
205204 utils::uvec3 local_wg_size{64 , 1 , 1 };
206205 if (use_coop_algorithm) {
You can’t perform that action at this time.
0 commit comments