Skip to content

Commit 4a1fcb2

Browse files
committed
[ET-VK] Reduced int precision for texture coordinates in conv2d_pw op, to reduce shader register pressure and slightly improve performance.
This diff reduces the precision of texture coordinates in the conv2d_pw op in Executorch Vulkan backend to reduce shader register pressure. Differential Revision: [D64766910](https://our.internmc.facebook.com/intern/diff/D64766910/) [ghstack-poisoned]
1 parent 793f17e commit 4a1fcb2

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,26 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
3232

3333
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3434

35+
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
36+
3537
/*
3638
* Computes a 2D pointwise convolution of an NxN output tile. Calculating an
3739
* output tile for pointwise convolution is more efficient because the kernel
3840
* size is only 1x1, making it easier to re-use loaded texels from t_kernel.
3941
*/
4042
void main() {
41-
const ivec3 gpos = ivec3(gl_GlobalInvocationID);
43+
const u16vec3 gpos = u16vec3(gl_GlobalInvocationID);
4244

4345
// Output position for TILE_SIZE = 2
4446
// +--------+--------+
4547
// | pos[0] | pos[1] |
4648
// +--------+--------+
4749
// | pos[2] | pos[3] |
4850
// +--------+--------+
49-
ivec3 pos[TILE_SIZE * TILE_SIZE];
51+
u16vec3 pos[TILE_SIZE * TILE_SIZE];
5052
for (int y = 0, i = 0; y < TILE_SIZE; ++y) {
5153
for (int x = 0; x < TILE_SIZE; ++x) {
52-
pos[i] = ivec3(
54+
pos[i] = u16vec3(
5355
gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y, gpos.z);
5456
i++;
5557
}
@@ -64,13 +66,13 @@ void main() {
6466
// Compute the index of the input texture that needs to be loaded for each
6567
// output position. Note that negative indices can be produced indicating that
6668
// the top-left element is in a region added by padding.
67-
ivec2 ipos[TILE_SIZE * TILE_SIZE];
69+
u16vec2 ipos[TILE_SIZE * TILE_SIZE];
6870
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
69-
ipos[i] = pos[i].xy * stride - padding;
71+
ipos[i] = pos[i].xy * u16vec2(stride) - u16vec2(padding);
7072
}
7173

7274
vec4 sum[TILE_SIZE * TILE_SIZE];
73-
sum[0] = texelFetch(t_bias, ivec2(gpos.z, 0), 0);
75+
sum[0] = texelFetch(t_bias, u16vec2(gpos.z, 0), 0);
7476
for (int i = 1; i < TILE_SIZE * TILE_SIZE; ++i) {
7577
sum[i] = sum[0];
7678
}
@@ -81,13 +83,13 @@ void main() {
8183
// channel (IC) dim is along the x-axis, and the batch (OC) dim is along
8284
// the z-axis.
8385
vec4 in_tex[TILE_SIZE * TILE_SIZE];
84-
const vec4 ktex_0 = texelFetch(t_kernel, ivec2(z + 0, gpos.z), 0);
85-
const vec4 ktex_1 = texelFetch(t_kernel, ivec2(z + 1, gpos.z), 0);
86-
const vec4 ktex_2 = texelFetch(t_kernel, ivec2(z + 2, gpos.z), 0);
87-
const vec4 ktex_3 = texelFetch(t_kernel, ivec2(z + 3, gpos.z), 0);
86+
const vec4 ktex_0 = texelFetch(t_kernel, u16vec2(z + 0, gpos.z), 0);
87+
const vec4 ktex_1 = texelFetch(t_kernel, u16vec2(z + 1, gpos.z), 0);
88+
const vec4 ktex_2 = texelFetch(t_kernel, u16vec2(z + 2, gpos.z), 0);
89+
const vec4 ktex_3 = texelFetch(t_kernel, u16vec2(z + 3, gpos.z), 0);
8890

8991
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
90-
in_tex[i] = texelFetch(t_in, ivec3(ipos[i], z4), 0);
92+
in_tex[i] = texelFetch(t_in, u16vec3(ipos[i], z4), 0);
9193
}
9294

9395
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {

0 commit comments

Comments
 (0)