
 ${define_active_storage_type(STORAGE)}

-${define_required_extensions(DTYPE)}
-${define_required_extensions("int8")}
+${define_required_extensions([DTYPE, "uint8", "uint16"])}
+#extension GL_EXT_control_flow_attributes : require

 layout(std430) buffer;

-${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
-${layout_declare_tensor(1, "r", "t_mat1", DTYPE, STORAGE)}
-${layout_declare_tensor(2, "r", "t_mat2", "int8", "buffer")}
-${layout_declare_tensor(3, "r", "t_scales_and_zeros", DTYPE, STORAGE)}
-
-$if STORAGE == "texture3d":
-  ${layout_declare_ubo(4, "ivec4", "out_sizes")}
-  ${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
-  ${layout_declare_ubo(6, "ivec4", "mat2_strides")}
-  ${layout_declare_ubo(7, "ivec4", "scales_strides")}
-$else:
-  ${layout_declare_ubo(4, "ivec4", "out_sizes")}
-  ${layout_declare_ubo(5, "ivec4", "out_strides")}
-  ${layout_declare_ubo(6, "ivec4", "mat1_sizes")}
-  ${layout_declare_ubo(7, "ivec4", "mat1_strides")}
-  ${layout_declare_ubo(8, "ivec4", "mat2_strides")}
-  ${layout_declare_ubo(9, "ivec4", "scales_strides")}
+${layout_declare_tensor(B, "w", "ret", DTYPE, STORAGE)}
+${layout_declare_tensor(B, "r", "x", DTYPE, STORAGE)}
+${layout_declare_tensor(B, "r", "weights", "uint8", "buffer")}
+${layout_declare_tensor(B, "r", "qparams", DTYPE, STORAGE)}
+${layout_declare_ubo(B, "ivec3", "ret_limits")}
+${layout_declare_ubo(B, "ivec4", "x_sizes")}
+${layout_declare_ubo(B, "ivec4", "weights_strides")}
+${layout_declare_ubo(B, "ivec4", "qparams_strides")}

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

 layout(constant_id = 3) const int group_size = 1;

+/*
+ * This shader computes a linear operator between a floating point input
+ * matrix x and a weights matrix that is quantized to 4 bits.
+ *
+ * The (W, H, C) shape of each tensor is:
+ * - x: (K, M)
+ * - weights: (K / 2, N)
+ *   - The weights tensor has a data type of `uint8`. Each element in the
+ *     tensor contains two 4-bit values packed into a uint8.
+ * - qparams: (2, N, number_of_groups)
+ *   - This tensor contains the scales and zeros quantization parameters for
+ *     the weights tensor. The weights tensor is quantized group-wise, which
+ *     means that every group of `group_size` elements along the K dimension
+ *     has independent quantization parameters. Along the width dim, the
+ *     first value contains the scale for the group and the second value
+ *     contains the zero point for the group.
+ *
+ * Note that this shader assumes that all tensors are width packed.
+ */
 void main() {
-
-  const ivec4 out_pos = ivec4(
-    gl_GlobalInvocationID.x, // n = 0..N-1
-    gl_GlobalInvocationID.y, // m = 0..M-1
-    gl_GlobalInvocationID.z % out_sizes.z,
-    gl_GlobalInvocationID.z / out_sizes.z);
-
-  if (any(greaterThanEqual(out_pos, out_sizes))) {
-    return;
+  // output positions being calculated are (n, m), (n + 1, m), ...
+  // This means multiplying the m-th row of x with the n-th, (n+1)-th, ... rows
+  // of the weights tensor.
+  const u16vec3 ret_pos = u16vec3(gl_GlobalInvocationID);
+  if (any(greaterThanEqual(ret_pos, ret_limits))) {
+    return;
+  }
+
+  // Since ret is width packed, need to multiply by 4 to obtain the index of
+  // the first output element computed by this invocation
+  const uint16_t n = uint16_t(ret_pos.x * 4);

+  // K is guaranteed to be a multiple of group_size
+  const uint16_t num_blocks = uint16_t(x_sizes.x / group_size);
+
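+  // k_texel_i tracks the current texel along the K dimension; sums holds the
+  // partial dot products for the 4 output elements computed by this invocation.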
+  uint16_t k_texel_i = uint16_t(0);
+  vec4 sums = vec4(0.0);
+  for (uint16_t block_idx = uint16_t(0); block_idx < num_blocks; block_idx++) {
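+    // Per-group scale and zero point, one for each of the 4 output elements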
+    vec4 scales;
+    vec4 zeros;
+
+    [[unroll]] for (int comp = 0; comp < 4; ++comp) {
+      const vec4 scale_and_zero = load_texel(
+          qparams, u16vec3(0, n + comp, block_idx));
+      scales[comp] = scale_and_zero.x;
+      zeros[comp] = scale_and_zero.y;
     }

-  const uint K = mat1_sizes.x;
-  const uint n = out_pos.x;
-  const uint m = out_pos.y;
-  const uint mask = uint(0x0f);
-
-  float rc = 0.0;
-  int k = 0;
-  const uint k_block = (K + group_size - 1) / group_size;
-
-  #ifdef USING_BUFFER
-  ivec4 mat1_pos = ivec4(0, m, out_pos.z, out_pos.w);
-  ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w);
-  ivec4 scale_pos = ivec4(0, n, 0, out_pos.w);
-  ivec4 zero_pos = ivec4(0, n, 1, out_pos.w);
-
-  for (int kb = 0; kb < k_block; kb++) {
-    scale_pos.x = kb;
-    const int scale_bufi = tidx_to_bufi(scale_pos, scales_strides);
-    const float scale = float(t_scales_and_zeros[scale_bufi]);
-
-    zero_pos.x = kb;
-    const int zero_bufi = tidx_to_bufi(zero_pos, scales_strides);
-    const float zero = float(t_scales_and_zeros[zero_bufi]) - scale * 8.0;
-
-    for (uint idx = 0; idx < group_size && k < K; idx++, k++) {
-      mat1_pos.x = k;
-      const int mat1_bufi = tidx_to_bufi(mat1_pos, mat1_strides);
-      const float mat1_val = float(t_mat1[mat1_bufi]);
-
-      mat2_pos.x = k / 2;
-      const int mat2_bufi = tidx_to_bufi(mat2_pos, mat2_strides);
-      // Bitwise op treats sign bit from int8 as a value bit instead,
-      // since there is no uint8_t datatype
-      uint mat2_val = (t_mat2[mat2_bufi] & 0xFF);
-      mat2_val = (k & 1) == 0 ? mat2_val & mask : (mat2_val >> 4);
-
-      rc += mat1_val * (scale * float(mat2_val) + zero);
-    }
-  }
-
-  const int out_bufi = tidx_to_bufi(out_pos, out_strides);
-  t_out[out_bufi] = FLOAT_T(rc);
-
-  #else // Using texture
-  ivec3 mat1_pos = ivec3(0, m, out_pos.z);
-  ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w);
-  ivec3 scale_zero_pos = ivec3(0, n, 0);
-  uint K_texel = K / FOUR;
-
-  for (int kb = 0; kb < k_block; kb++) {
-    scale_zero_pos.x = kb;
-    const vec4 scale_zero = load_texel(t_scales_and_zeros, scale_zero_pos);
-    const float scale = scale_zero.x;
-    const float zero = scale_zero.y - scale * 8.0;
-
-    for (uint idx = 0; idx < group_size && k < K_texel; idx += FOUR, k++) {
-      mat1_pos.x = k;
-      const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos);
-
-      mat2_pos.x = k * 2; // k * FOUR / 2
-      const int mat2_id = tidx_to_bufi(mat2_pos, mat2_strides);
-
-      for (int texel_pos = 0; texel_pos < FOUR; texel_pos++) {
-        // Bitwise op treats sign bit from int8 as a value bit instead,
-        // since there is no uint8_t datatype
-        uint mat2_val = (t_mat2[mat2_id + texel_pos / 2] & 0xFF);
-        mat2_val = (texel_pos & 1) == 0 ? mat2_val & mask : (mat2_val >> 4);
-        rc += mat1_tex[texel_pos] * (scale * float(mat2_val) + zero);
-      }
-    }
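+    // Accumulate over one quantization group, reading one texel (4 elements)
+    // of x per iteration.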
+    for (uint16_t i = uint16_t(0); i < group_size; i += uint16_t(4), k_texel_i++) {
+      const VEC4_T x_texel = load_texel(
+          x, u16vec3(k_texel_i, ret_pos.y, ret_pos.z));
+
+      [[unroll]] for (int comp = 0; comp < 4; ++comp) {
+        const int weights_bufi = (n + comp) * weights_strides.y + (k_texel_i * 2);
+        // Need to read 4 unpacked values, which corresponds to 2 packed values
+        const uint8_t weights_val_1 = weights[weights_bufi];
+        const uint8_t weights_val_2 = weights[weights_bufi + 1];
+
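+        // Unpack the two bytes into 4 consecutive values along K; the high
+        // nibble of each byte comes first.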
+        const u8vec4 weights_texel = u8vec4(
+            (weights_val_1 & 0xF0) >> 4,
+            weights_val_1 & 0x0F,
+            (weights_val_2 & 0xF0) >> 4,
+            weights_val_2 & 0x0F);
+
+        // Note that the unpacked 4-bit values are unsigned, therefore they
+        // must first be "centered" around 0 by subtracting 8 before applying
+        // the scale and zero point.
+        sums[comp] += dot(
+            x_texel, (vec4(weights_texel) - 8.0) * scales[comp] + zeros[comp]);
       }
-  write_texel(t_out, out_pos.xyz, vec4(rc, 0, 0, 0));
-
-  #endif
+    }
+  }
+  write_texel(ret, ret_pos, sums);
 }
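
For reference, the packing scheme described in the shader's doc comment can be mirrored on the host side. The following is a minimal NumPy sketch, not taken from this PR; the function names `pack_int4` and `dequantize` and the (N, K) row-major weight layout are assumptions for illustration only.

```python
import numpy as np

# Hypothetical host-side reference for the shader's weight layout: two
# unsigned 4-bit values per uint8 along K, high nibble first.
def pack_int4(w_q: np.ndarray) -> np.ndarray:
    # w_q: (N, K) uint8 values in [0, 15]; returns (N, K / 2) packed bytes.
    assert w_q.shape[1] % 2 == 0
    return ((w_q[:, 0::2] << 4) | w_q[:, 1::2]).astype(np.uint8)

def dequantize(packed, scales, zeros, group_size):
    # Mirrors the shader's inner loop: unpack the nibbles, re-center the
    # unsigned values by subtracting 8, then apply the per-group scale and
    # zero point. scales and zeros have shape (N, number_of_groups).
    n, k_half = packed.shape
    q = np.empty((n, 2 * k_half), dtype=np.float32)
    q[:, 0::2] = packed >> 4
    q[:, 1::2] = packed & 0x0F
    group = np.arange(2 * k_half) // group_size  # group index per K column
    return (q - 8.0) * scales[:, group] + zeros[:, group]
```

With an (M, K) input x, the shader's output then corresponds to `x @ dequantize(...).T`, computed one 4-wide output texel per invocation.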