diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl index 8414d811fc8..5378099d03f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl @@ -36,8 +36,10 @@ layout(set = 0, binding = 4) uniform PRECISION restrict Block { layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); + const u16vec3 pos = u16vec3(gl_GlobalInvocationID); if (any(greaterThanEqual(pos, out_limits))) { return; @@ -46,28 +48,34 @@ void main() { const int out_channel_4up = int(ch_info.x); const int in_channel_4up = int(ch_info.y); const int out_batch = int(sizes[3]); - const int max_dst_index = out_batch * out_channel_4up; VEC4_T outval = VEC4_T(0.0); + ivec4 v = ivec4(0); // holds b,c,h,w + + v[out_ndims[2]] = pos.y; + v[out_ndims[3]] = pos.x; + + const int dst_index = pos.z << 2; + int dst_out_index = dst_index / out_channel_4up; + int dst_out_lane = dst_index % out_channel_4up; - for (int j = 0; j < 4; ++j) { - int dst_index = pos.z * 4 + j; - if (dst_index >= max_dst_index) { + for (int j = 0; j < 4; ++j, ++dst_out_lane) { + if (dst_out_index >= out_batch) { // out of range break; } - ivec4 v = ivec4(0); // holds b,c,h,w - v[out_ndims[0]] = dst_index / out_channel_4up; - v[out_ndims[1]] = dst_index % out_channel_4up; - v[out_ndims[2]] = pos.y; - v[out_ndims[3]] = pos.x; + if (dst_out_lane == out_channel_4up) { + dst_out_lane = 0; + dst_out_index++; + } + + v[out_ndims[0]] = dst_out_index; + v[out_ndims[1]] = dst_out_lane; int src_index = v[0] * in_channel_4up + v[1]; - int w = v[3]; - int h = v[2]; - VEC4_T inval = VEC4_T(texelFetch(image_in, ivec3(w, h, src_index / 4), 0)); - outval[j] = inval[src_index % 4]; + VEC4_T inval = VEC4_T(texelFetch(image_in, u16vec3(v[3], v[2], src_index >> 2), 0)); + outval[j] = inval[src_index & 0x3]; } imageStore(image_out, pos, outval);