1313#define T ${buffer_scalar_type(DTYPE)}
1414#define VEC4_T ${buffer_gvec_type(DTYPE, 4 )}
1515
16+ #define TILE_ROWS ${TILE_ROWS}
17+
1618${define_required_extensions(DTYPE)}
1719$if WEIGHT_STORAGE == "buffer ":
1820 ${define_required_extensions("uint8")}
1921
22+ #extension GL_EXT_control_flow_attributes : require
23+
2024layout (std430) buffer ;
2125
2226${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array= False)}
@@ -53,10 +57,12 @@ layout(constant_id = 3) const int group_size = 64;
5357 * first value contains the scale for the group and the second value
5458 * contains the zero point for the group.
5559 *
60+ * Each thread computes an output tile of TILE_ROWS rows by 2 texels (8 scalar elements per row).
61+ *
5662 * Note that this shader assumes that all tensors are width packed.
5763 */
5864void main() {
59- const uint out_row = gl_GlobalInvocationID.y;
65+ const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS ;
6066 // Each thread writes out 2 texels along the width axis, equivalent to 8
6167 // scalar elements. Therefore multiply the thread_idx.x by 8.
6268 const uint out_col = gl_GlobalInvocationID.x << 3 ;
@@ -70,10 +76,14 @@ void main() {
7076
7177 const int num_blocks = mat1_sizes.x / group_size;
7278
73- VEC4_T sums[2 ];
79+ VEC4_T mat1[TILE_ROWS];
80+ VEC4_T qmat2[4 ][2 ];
81+ VEC4_T sums[TILE_ROWS][2 ];
7482
75- sums[0 ] = VEC4_T(0 );
76- sums[1 ] = VEC4_T(0 );
83+ [[unroll]] for (int r = 0 ; r < TILE_ROWS; ++ r) {
84+ sums[r][0 ] = VEC4_T(0 );
85+ sums[r][1 ] = VEC4_T(0 );
86+ }
7787
7888 VEC4_T scales[2 ];
7989 VEC4_T zeros[2 ];
@@ -101,33 +111,51 @@ void main() {
101111 for (int g_idx = 0 ; g_idx < group_size; g_idx += 4 ) {
102112 const int k = block_idx * group_size + g_idx;
103113
104- $if IN_STORAGE == "buffer ":
105- const VEC4_T mat1_tex = t_mat1[(out_row * mat1_sizes.x + k) >> 2 ];
106- $else :
107- const VEC4_T mat1_tex = texelFetch(t_mat1, ivec3 (k >> 2 , out_row, 0 ), 0 );
108-
109- for (int comp = 0 ; comp < 4 ; ++ comp) {
114+ // Preload B: unpack the 4-bit weights for rows k..k+3 (2 texels each) and apply scales/zeros
115+ [[unroll]] for (int r = 0 ; r < 4 ; ++ r) {
110116 $if WEIGHT_STORAGE == "buffer ":
111- const u8vec4 packed_weight_tex = t_qmat2[(k + comp ) * qmat2_stride + gl_GlobalInvocationID.x];
117+ const u8vec4 packed_weight_tex = t_qmat2[(k + r ) * qmat2_stride + gl_GlobalInvocationID.x];
112118 $else :
113119 const uvec4 packed_weight_tex = texelFetch(
114120 t_qmat2,
115- ivec2 (gl_GlobalInvocationID.x, k + comp ),
121+ ivec2 (gl_GlobalInvocationID.x, k + r ),
116122 0 );
117123
118- const uvec4 weight_tex_1 = (packed_weight_tex & 0xF0) >> 4 ;
119- const uvec4 weight_tex_2 = packed_weight_tex & 0x0F;
124+ qmat2[r][0 ] = (VEC4_T((packed_weight_tex & 0xF0) >> 4 ) - 8.0 ) * scales[0 ] + zeros[0 ];
125+ qmat2[r][1 ] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0 ) * scales[1 ] + zeros[1 ];
126+ }
127+
128+ // Preload A: fetch one input texel (elements k..k+3) for each of the TILE_ROWS rows in the tile
129+ [[unroll]] for (int r = 0 ; r < TILE_ROWS; ++ r) {
130+ $if IN_STORAGE == "buffer ":
131+ mat1[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2 ];
132+ $else :
133+ mat1[r] = texelFetch(t_mat1, ivec3 (k >> 2 , out_row + r, 0 ), 0 );
134+ }
120135
121- sums[0 ] += mat1_tex[comp] * ((vec4 (weight_tex_1) - 8.0 ) * scales[0 ] + zeros[0 ]);
122- sums[1 ] += mat1_tex[comp] * ((vec4 (weight_tex_2) - 8.0 ) * scales[1 ] + zeros[1 ]);
136+ // Accumulate output tile
137+ [[unroll]] for (int r = 0 ; r < TILE_ROWS; ++ r) {
138+ sums[r][0 ] += mat1[r].x * qmat2[0 ][0 ]
139+ + mat1[r].y * qmat2[1 ][0 ]
140+ + mat1[r].z * qmat2[2 ][0 ]
141+ + mat1[r].w * qmat2[3 ][0 ];
142+
143+ sums[r][1 ] += mat1[r].x * qmat2[0 ][1 ]
144+ + mat1[r].y * qmat2[1 ][1 ]
145+ + mat1[r].z * qmat2[2 ][1 ]
146+ + mat1[r].w * qmat2[3 ][1 ];
123147 }
124148 }
125149 }
126150
127- $if OUT_STORAGE == "buffer ":
128- t_out[(out_row * out_sizes.x + out_col) >> 2 ] = sums[0 ];
129- t_out[(out_row * out_sizes.x + out_col + 4 ) >> 2 ] = sums[1 ];
130- $else :
131- imageStore(t_out, ivec3 (out_col_texel_idx, out_row, 0 ), sums[0 ]);
132- imageStore(t_out, ivec3 (out_col_texel_idx + 1 , out_row, 0 ), sums[1 ]);
151+ [[unroll]] for (int r = 0 ; r < TILE_ROWS; ++ r) {
152+ $if OUT_STORAGE == "buffer ":
153+ if (out_row + r < out_sizes.y) {
154+ t_out[((out_row + r) * out_sizes.x + out_col) >> 2 ] = sums[r][0 ];
155+ t_out[((out_row + r) * out_sizes.x + out_col + 4 ) >> 2 ] = sums[r][1 ];
156+ }
157+ $else :
158+ imageStore(t_out, ivec3 (out_col_texel_idx, out_row + r, 0 ), sums[r][0 ]);
159+ imageStore(t_out, ivec3 (out_col_texel_idx + 1 , out_row + r, 0 ), sums[r][1 ]);
160+ }
133161}
0 commit comments