@@ -10,152 +10,121 @@
 
 #define PRECISION ${PRECISION}
 
-#define T ${buffer_scalar_type(DTYPE)}
-#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
-
-#define TILE_ROWS ${TILE_ROWS}
+#define T ${texel_load_component_type(DTYPE, IO_STORAGE)}
+#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)}
 
 ${define_required_extensions(DTYPE)}
-$if WEIGHT_STORAGE == "buffer":
-  ${define_required_extensions("uint8")}
-
-#extension GL_EXT_control_flow_attributes : require
 
 layout(std430) buffer;
 
-${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
-${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)}
-${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
+#include "indexing_utils.h"
+
+${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_qmat2", "uint", WEIGHT_STORAGE, is_scalar_array=False)}
 ${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "buffer", is_scalar_array=False)}
 
 layout(push_constant) uniform restrict Block {
-  ivec4 out_sizes;
-  ivec4 mat1_sizes;
-  ivec4 qmat2_sizes;
+  ivec4 output_sizes;
+  ivec4 input_sizes;
+  ivec4 weight_sizes;
 };
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
 layout(constant_id = 3) const int group_size = 64;
 
-/*
- * This shader computes a linear operator between a floating point input matrix
- * x and a weights matrix that is quantized to 4 bits.
- *
- * The (W, H, C) shape of each tensor is:
- * - x: (K, M)
- * - weights: (N / 2, K)
- *   - The weights tensor has a data type of `uint8`. Each element in the tensor
- *     contains 2 4-bit values packed into a uint8.
- *   - See the pack_int4_linear_weight_transposed_interleave shader to see more
- *     details on how the weight tensor is stored.
- * - qparams: (2, N, number_of_groups)
- *   - This tensor contains the scales and zeros quantization parameters for the
- *     weights tensor. The weight tensor is quantized group-wise, which means
- *     that every `group_size` elements along the K dimension of the weights
- *     tensor has independent quantization parameters. Along the width dim, the
- *     first value contains the scale for the group and the second value
- *     contains the zero point for the group.
- *
- * Each thread computes a tile of TILE_ROWS * 2 texels of the output tensor.
- *
- * Note that this shader assumes that all tensors are width packed.
- */
+$if IO_STORAGE == "buffer":
+  #define BUFFER_IO
+$if WEIGHT_STORAGE == "buffer":
+  #define BUFFER_WEIGHT
+
+#include "qlinear_utils.glslh"
+
 void main() {
-  const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
-  // Each thread writes out 2 texels along the width axis, equivalent to 8
-  // scalar elements. Therefore multiply the thread_idx.x by 8.
-  const uint out_col = gl_GlobalInvocationID.x << 3;
-  // Similar reasoning to the above, each thread works on 2 texels along the
-  // width axis so multiply thread_idx.x by 2.
-  const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1;
-
-  if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
+  // Each thread writes out an 8 wide x 4 high tile of output values
+  const uint n8 = gl_GlobalInvocationID.x;
+  const uint m4 = gl_GlobalInvocationID.y;
+
+  const uint n = MUL_8(n8); // output col idx
+  const uint m = MUL_4(m4); // output row idx
+  const uint n4 = MUL_2(n8); // output col texel idx
+
+  const uint group_num = input_sizes.x / group_size;
+  const uint group_ntexels = DIV_UP_4(group_size);
+
+  if (n >= output_sizes.x || m >= output_sizes.y) {
     return;
   }
 
-  const int num_blocks = mat1_sizes.x / group_size;
+  const uint K4 = DIV_UP_4(input_sizes.x);
+  const uint N4 = DIV_UP_4(output_sizes.x); // number of texels in each row
 
-  VEC4_T mat1[TILE_ROWS];
-  VEC4_T qmat2[4][2];
-  VEC4_T sums[TILE_ROWS][2];
+  VEC4_T out_texels[4][2];
+  // Initialize to 0
+  $for row_i in range(4):
+    $for col_i in range(2):
+      out_texels[${row_i}][${col_i}] = VEC4_T(0.00);
 
-  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-    sums[r][0] = VEC4_T(0);
-    sums[r][1] = VEC4_T(0);
-  }
+  for (uint group_i = 0; group_i < group_num; ++group_i) {
+    // Load quantization scales and zeros for the current group
+    VEC4_T scales[2];
+    VEC4_T zeros[2];
+    {
+      uint qparams_bufi = group_i * DIV_2(output_sizes.x) + DIV_2(n);
 
-  VEC4_T scales[2];
-  VEC4_T zeros[2];
-
-  $if WEIGHT_STORAGE == "buffer":
-    const int qmat2_stride = qmat2_sizes.x >> 2;
-  $if PARAMS_STORAGE == "buffer":
-    const int qparams_y_stride = out_sizes.x >> 2;
-    const int qparams_z_stride = qparams_y_stride * 2;
-
-  for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
-    $if PARAMS_STORAGE == "buffer":
-      scales[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx];
-      zeros[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + qparams_y_stride];
-
-      scales[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1];
-      zeros[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1 + qparams_y_stride];
-    $else:
-      scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0);
-      zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0);
-
-      scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0);
-      zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0);
-
-    for (int g_idx = 0; g_idx < group_size; g_idx += 4) {
-      const int k = block_idx * group_size + g_idx;
-
-      // Preload B
-      [[unroll]] for (int r = 0; r < 4; ++r) {
-        $if WEIGHT_STORAGE == "buffer":
-          const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x];
-        $else:
-          const uvec4 packed_weight_tex = texelFetch(
-              t_qmat2,
-              ivec2(gl_GlobalInvocationID.x, k + r),
-              0);
-
-        qmat2[r][0] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0) * scales[0] + zeros[0];
-        qmat2[r][1] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0) * scales[1] + zeros[1];
-      }
-
-      // Preload A
-      [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-        $if IN_STORAGE == "buffer":
-          mat1[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2];
-        $else:
-          mat1[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0);
-      }
-
-      // Accumulate output tile
-      [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-        sums[r][0] += mat1[r].x * qmat2[0][0]
-                    + mat1[r].y * qmat2[1][0]
-                    + mat1[r].z * qmat2[2][0]
-                    + mat1[r].w * qmat2[3][0];
-
-        sums[r][1] += mat1[r].x * qmat2[0][1]
-                    + mat1[r].y * qmat2[1][1]
-                    + mat1[r].z * qmat2[2][1]
-                    + mat1[r].w * qmat2[3][1];
-      }
+      VEC4_T scales_zeros_texels[4];
+      $for comp in range(4):
+        scales_zeros_texels[${comp}] = t_qparams[qparams_bufi++];
+
+      scales[0] = VEC4_T(scales_zeros_texels[0].xz, scales_zeros_texels[1].xz);
+      zeros[0] = VEC4_T(scales_zeros_texels[0].yw, scales_zeros_texels[1].yw);
+
+      scales[1] = VEC4_T(scales_zeros_texels[2].xz, scales_zeros_texels[3].xz);
+      zeros[1] = VEC4_T(scales_zeros_texels[2].yw, scales_zeros_texels[3].yw);
+    }
+
+    for (uint inner_k4 = 0; inner_k4 < group_ntexels; inner_k4++) {
+      const uint k4 = group_i * group_ntexels + inner_k4;
+
+      // Load a 4x4 block of the input tensor, with the top left corner of the
+      // block at (k, m)
+      VEC4_T in_texels[4];
+      $for comp in range(4):
+        in_texels[${comp}] = load_input_texel_2d(k4, m + ${comp}, K4);
+
+      uvec4 packed_weight_block = load_transposed_weight_block(k4, n8, K4);
+
+      VEC4_T weight_texels[2];
+      $for tile_k in range(4):
+        // Process weight row k + ${tile_k}
+        {
+          // Weight columns n + 0, 1, 2, 3
+          weight_texels[0].x = extract_4bit_from_transposed_block(packed_weight_block, 0, ${tile_k});
+          weight_texels[0].y = extract_4bit_from_transposed_block(packed_weight_block, 1, ${tile_k});
+          weight_texels[0].z = extract_4bit_from_transposed_block(packed_weight_block, 2, ${tile_k});
+          weight_texels[0].w = extract_4bit_from_transposed_block(packed_weight_block, 3, ${tile_k});
+
+          // Weight columns n + 4, 5, 6, 7
+          weight_texels[1].x = extract_4bit_from_transposed_block(packed_weight_block, 4, ${tile_k});
+          weight_texels[1].y = extract_4bit_from_transposed_block(packed_weight_block, 5, ${tile_k});
+          weight_texels[1].z = extract_4bit_from_transposed_block(packed_weight_block, 6, ${tile_k});
+          weight_texels[1].w = extract_4bit_from_transposed_block(packed_weight_block, 7, ${tile_k});
+
+          weight_texels[0] = fma(weight_texels[0], scales[0], zeros[0]);
+          weight_texels[1] = fma(weight_texels[1], scales[1], zeros[1]);
+
+          $for tile_m in range(4):
+            out_texels[${tile_m}][0] = fma(VEC4_T(in_texels[${tile_m}][${tile_k}]), weight_texels[0], out_texels[${tile_m}][0]);
+            out_texels[${tile_m}][1] = fma(VEC4_T(in_texels[${tile_m}][${tile_k}]), weight_texels[1], out_texels[${tile_m}][1]);
+        }
     }
   }
 
-  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-    $if OUT_STORAGE == "buffer":
-      if (out_row + r < out_sizes.y) {
-        t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = sums[r][0];
-        t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = sums[r][1];
-      }
-    $else:
-      imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), sums[r][0]);
-      imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), sums[r][1]);
+  for (uint row_i = 0; row_i < 4 && m + row_i < output_sizes.y; ++row_i) {
+    write_output_texel_2d(out_texels[row_i][0], n4, m + row_i, N4);
+    if (n + 4 < output_sizes.x) {
+      write_output_texel_2d(out_texels[row_i][1], n4 + 1, m + row_i, N4);
+    }
   }
 }
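
For reference, below is a minimal NumPy sketch of the computation this shader family implements: a linear op out = x @ W where W is stored as 4-bit values with group-wise scale/zero parameters along K, dequantized as (nibble - 8) * scale + zero (the arithmetic visible in the removed code path, folded into fma() calls in the new shader). The function and variable names are illustrative only, and the packed, transposed weight block layout handled by load_transposed_weight_block / extract_4bit_from_transposed_block is not reproduced here.

```python
import numpy as np

def dequant_int4_groupwise(q, scales, zeros, group_size):
    # q:      (K, N) integers in [0, 15], i.e. the unpacked 4-bit nibbles
    # scales: (K // group_size, N) per-group scales
    # zeros:  (K // group_size, N) per-group zero points
    # Every `group_size` rows along K share one (scale, zero) pair:
    #   w = (q - 8) * scale + zero
    group_idx = np.arange(q.shape[0]) // group_size
    return (q.astype(np.float32) - 8.0) * scales[group_idx] + zeros[group_idx]

def int4_linear_reference(x, q, scales, zeros, group_size):
    # x: (M, K) floating point input; returns the (M, N) result of x @ W_dequant
    return x @ dequant_int4_groupwise(q, scales, zeros, group_size)
```

In the shader, each invocation produces this result for an 8-column by 4-row output tile: n = 8 * gl_GlobalInvocationID.x and m = 4 * gl_GlobalInvocationID.y select the tile origin, the outer loop walks K one quantization group at a time, and the inner loop consumes four K values (one input texel) per iteration.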