
#version 450 core

-#include "indexing_utils.h"
-
#define PRECISION ${PRECISION}

-#define FOUR 4
-
-#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
-#define FLOAT_T ${buffer_scalar_type(DTYPE)}
+#define T ${buffer_scalar_type(DTYPE)}
+#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

-${define_active_storage_type(STORAGE)}
-
-${define_required_extensions([DTYPE, "uint8", "uint16"])}
-#extension GL_EXT_control_flow_attributes : require
+${define_required_extensions(DTYPE)}
+${define_required_extensions("int8")}

layout(std430) buffer;

-${layout_declare_tensor(B, "w", "ret", DTYPE, STORAGE)}
-${layout_declare_tensor(B, "r", "x", DTYPE, STORAGE)}
-${layout_declare_tensor(B, "r", "weights", "uint8", "buffer")}
-${layout_declare_tensor(B, "r", "qparams", DTYPE, STORAGE)}
-${layout_declare_ubo(B, "ivec3", "ret_limits")}
-${layout_declare_ubo(B, "ivec4", "x_sizes")}
-${layout_declare_ubo(B, "ivec4", "weights_strides")}
-${layout_declare_ubo(B, "ivec4", "qparams_strides")}
+${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "texture3D")}
+
+layout(push_constant) uniform restrict Block {
+  ivec4 out_sizes;
+  ivec4 mat1_sizes;
+  ivec4 qmat2_sizes;
+};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

-layout(constant_id = 3) const int group_size = 1;
+layout(constant_id = 3) const int group_size = 64;

/*
 * This shader computes a linear operator between a floating point input matrix
 * x and a weights matrix that is quantized to 4 bits.
 *
 * The (W, H, C) shape of each tensor is:
 * - x: (K, M)
- * - weights: (K / 2, N)
+ * - weights: (N / 2, K)
 * - The weights tensor has a data type of `uint8`. Each element in the tensor
 *   contains 2 4-bit values packed into a uint8.
+ * - See the pack_int4_linear_weight_transposed_interleave shader for more
+ *   details on how the weight tensor is stored.
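+ *   (For example, a uint8 with value 0xB3 packs the two 4-bit values
+ *   0xB == 11 and 0x3 == 3.)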
 * - qparams: (2, N, number_of_groups)
 *   - This tensor contains the scales and zeros quantization parameters for the
 *     weights tensor. The weight tensor is quantized group-wise, which means
@@ -57,56 +55,68 @@ layout(constant_id = 3) const int group_size = 1;
 * Note that this shader assumes that all tensors are width packed.
 */
void main() {
-  // output positions being calculated are (n, m), (n + 1, m), ...
-  // This means multiplying the m-th row of x with the n-th, (n+1)-th, ... rows
-  // of the weights tensor.
-  const u16vec3 ret_pos = u16vec3(gl_GlobalInvocationID);
-  if (any(greaterThanEqual(ret_pos, ret_limits))) {
+  const uint out_row = gl_GlobalInvocationID.y;
+  // Each thread writes out 2 texels along the width axis, equivalent to 8
+  // scalar elements. Therefore multiply the thread_idx.x by 8.
+  const uint out_col = gl_GlobalInvocationID.x << 3;
+  // Similar reasoning to the above, each thread works on 2 texels along the
+  // width axis so multiply thread_idx.x by 2.
+  const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1;
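+  // (For example, the thread with gl_GlobalInvocationID.x == 3 covers output
+  // columns 24..31, i.e. output texels 6 and 7.)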
+
+  if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
    return;
  }

-  // Since ret is width packed, need to multiply by 4
-  const uint16_t n = uint16_t(ret_pos.x * 4);
+  const int num_blocks = mat1_sizes.x / group_size;
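+  // K (mat1_sizes.x) is guaranteed to be a multiple of group_size, so the
+  // division is exact.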

-  // K is guaranteed to be a multiple of group size
-  const uint16_t num_blocks = uint16_t(x_sizes.x / group_size);
+  VEC4_T sums[2];

-  uint16_t k_texel_i = uint16_t(0);
-  vec4 sums = vec4(0.0);
-  for (uint16_t block_idx = uint16_t(0); block_idx < num_blocks; block_idx++) {
-    vec4 scales;
-    vec4 zeros;
+  sums[0] = VEC4_T(0);
+  sums[1] = VEC4_T(0);

-    [[unroll]] for (int comp = 0; comp < 4; comp++) {
-      const vec4 scale_and_zero = load_texel(
-          qparams, u16vec3(0, n + comp, block_idx));
-      scales[comp] = scale_and_zero.x;
-      zeros[comp] = scale_and_zero.y;
-    }
+  VEC4_T scales[2];
+  VEC4_T zeros[2];
+
+  $if WEIGHT_STORAGE == "buffer":
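+    // qmat2_sizes.x counts individual uint8 elements, but the buffer is read
+    // as u8vec4, so divide by 4 to obtain the row stride in u8vec4 units.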
+    const int qmat2_stride = qmat2_sizes.x >> 2;
+
+  for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
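+    // Preload the scales and zero points of the current group for the 8
+    // output columns handled by this thread: qparams texels at y == 0 hold
+    // scales, texels at y == 1 hold zero points.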
+    scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0);
+    zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0);

-    for (uint16_t i = uint16_t(0); i < group_size; i += uint16_t(4), k_texel_i++) {
-      const VEC4_T x_texel = load_texel(
-          x, u16vec3(k_texel_i, ret_pos.y, ret_pos.z));
-
-      [[unroll]] for (int comp = 0; comp < 4; comp++) {
-        const int weights_bufi = (n + comp) * weights_strides.y + (k_texel_i * 2);
-        // Need to read 4 unpacked values, which corresponds to 2 packed values
-        const uint8_t weights_val_1 = weights[weights_bufi];
-        const uint8_t weights_val_2 = weights[weights_bufi + 1];
-
-        const u8vec4 weights_texel = u8vec4(
-            (weights_val_1 & 0xF0) >> 4,
-            weights_val_1 & 0x0F,
-            (weights_val_2 & 0xF0) >> 4,
-            weights_val_2 & 0x0F);
-
-        // Note that the unpacked 4-bit values are unsigned, therefore they must
-        // first be "centered" around 0 by subtracting 8 before applying the
-        // scale and zero point.
-        sums[comp] += dot(
-            x_texel, (vec4(weights_texel) - 8.0) * scales[comp] + zeros[comp]);
+    scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0);
+    zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0);
+
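+    // Each iteration of the inner loop consumes one width-packed input texel,
+    // i.e. 4 scalar elements along the K dimension.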
+    for (int g_idx = 0; g_idx < group_size; g_idx += 4) {
+      const int k = block_idx * group_size + g_idx;
+
+      $if IN_STORAGE == "buffer":
+        const VEC4_T mat1_tex = t_mat1[(out_row * mat1_sizes.x + k) >> 2];
+      $else:
+        const VEC4_T mat1_tex = texelFetch(t_mat1, ivec3(k >> 2, out_row, 0), 0);
+
+      for (int comp = 0; comp < 4; ++comp) {
+        $if WEIGHT_STORAGE == "buffer":
+          const u8vec4 packed_weight_tex = t_qmat2[(k + comp) * qmat2_stride + gl_GlobalInvocationID.x];
+        $else:
+          const uvec4 packed_weight_tex = texelFetch(
+              t_qmat2,
+              ivec3(gl_GlobalInvocationID.x, k + comp, 0),
+              0);
+
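+        // Each of the 4 packed bytes holds two 4-bit weights; the high
+        // nibbles belong to the first output texel and the low nibbles to
+        // the second. The unpacked values are unsigned, so they are
+        // "centered" around 0 by subtracting 8 before the scale and zero
+        // point are applied (e.g. the byte 0xB3 unpacks to 11 and 3, which
+        // become 3 and -5 after centering).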
+        const uvec4 weight_tex_1 = (packed_weight_tex & 0xF0) >> 4;
+        const uvec4 weight_tex_2 = packed_weight_tex & 0x0F;
+
+        sums[0] += mat1_tex[comp] * ((vec4(weight_tex_1) - 8.0) * scales[0] + zeros[0]);
+        sums[1] += mat1_tex[comp] * ((vec4(weight_tex_2) - 8.0) * scales[1] + zeros[1]);
      }
    }
  }
-  write_texel(ret, ret_pos, sums);
+
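+  // Write the two accumulated texels to adjacent locations along the width
+  // axis of the output. For buffer storage, the scalar index is converted to
+  // a texel index by dividing by 4.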
+  $if OUT_STORAGE == "buffer":
+    t_out[(out_row * out_sizes.x + out_col) >> 2] = sums[0];
+    t_out[(out_row * out_sizes.x + out_col + 4) >> 2] = sums[1];
+  $else:
+    imageStore(t_out, ivec3(out_col_texel_idx, out_row, 0), sums[0]);
+    imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row, 0), sums[1]);
}