From 70e09cd7250594bffdd59ccc1fa46d4a9a821d7e Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Mon, 26 May 2025 21:41:08 -0700 Subject: [PATCH] [ET-VK] Reducing precision of some int members in conv2d pw to improve performance. Reducing precision of some int members in conv2d pw to improve performance. Differential Revision: [D75423958](https://our.internmc.facebook.com/intern/diff/D75423958/) [ghstack-poisoned] --- .../graph/ops/glsl/conv2d_pw_s1p0.glsl | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl index 9d424683077..9b5707ce073 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl @@ -50,12 +50,13 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { const int out_limits_scaled[2] = {out_limits.x + (TILE_SIZE_X - 1) * TILE_SIZE_X, out_limits.y + (TILE_SIZE_Y - 1) * TILE_SIZE_Y}; - const int div_by_x = int(gl_GlobalInvocationID.x / out_limits_scaled[0]); - const int out_pos[3] = {int(gl_GlobalInvocationID.x % out_limits_scaled[0]), div_by_x, int(gl_GlobalInvocationID.y)}; + const uint16_t div_by_x = uint16_t(gl_GlobalInvocationID.x / out_limits_scaled[0]); + const uint16_t out_pos_xy[2] = {uint16_t(gl_GlobalInvocationID.x % out_limits_scaled[0]), div_by_x}; + const int out_pos_z = int(gl_GlobalInvocationID.y); // If the top left position is out of bounds, then this invocation will have // no work to do.
- if (out_pos[1] >= out_limits_scaled[1] || out_pos[2] >= out_limits.z) { + if (out_pos_xy[1] >= out_limits_scaled[1] || out_pos_z >= out_limits.z) { return; } @@ -68,8 +69,8 @@ void main() { uint16_t pos[TILE_SIZE_X * TILE_SIZE_Y * 2]; for (uint16_t y = 0us, i = 0us; y < TILE_SIZE_Y; ++y) { for (uint16_t x = 0us; x < TILE_SIZE_X; ++x) { - pos[i * 2] = uint16_t(out_pos[0]) * TILE_SIZE_X + x; - pos[i * 2 + 1] = uint16_t(out_pos[1]) * TILE_SIZE_Y + y; + pos[i * 2] = out_pos_xy[0] * TILE_SIZE_X + x; + pos[i * 2 + 1] = out_pos_xy[1] * TILE_SIZE_Y + y; i++; } } @@ -78,7 +79,7 @@ void main() { // Tuple of consecutive 4 elements represents a single output texel. float sum[TILE_SIZE_X * TILE_SIZE_Y * 4]; - const vec4 bias = texelFetch(t_bias, ivec2(out_pos[2], 0), 0); + const vec4 bias = texelFetch(t_bias, ivec2(out_pos_z, 0), 0); // Initialize the output array with the bias value for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i += 4) { @@ -98,7 +99,7 @@ void main() { // Load kernel values from texels to array [[unroll]] for (int i = 0; i < 4; ++i) { - const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, out_pos[2]), 0); + const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, out_pos_z), 0); kernel_values[i * 4 + 0] = k_tex.x; kernel_values[i * 4 + 1] = k_tex.y; kernel_values[i * 4 + 2] = k_tex.z; @@ -157,8 +158,8 @@ void main() { } for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) { - const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos[2]); - if (all(lessThan(pos_l, out_limits.xyz))) { + const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos_z); + if (all(lessThan(pos_l.xy, out_limits.xy))) { imageStore(t_out, pos_l, op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max)); } }