diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl index f46c1f01c7b..81d2a5f0aed 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl @@ -10,17 +10,14 @@ #define PRECISION ${PRECISION} -#define T ${buffer_scalar_type(DTYPE)} -#define VEC4_T ${buffer_gvec_type(DTYPE, 4)} +#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} +#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} #define WGS ${WGS} ${define_required_extensions(DTYPE)} ${define_required_extensions("uint8")} -#extension GL_EXT_control_flow_attributes : require -#extension GL_EXT_debug_printf : require - layout(std430) buffer; #include "indexing_utils.h" @@ -99,7 +96,7 @@ void main() { } // The input tensor will have a shape of [K, 1, 1, 1]; in each iteration, // load 4 elements starting from the tensor index (k, 0, 0, 0). - VEC4_T in_texel = load_input_texel(k4); + VEC4_T in_texel = load_input_texel_1d(k4); // Extract each element of the in_texel into a separate vectorized variable; // these are used to "broadcast" the input values in subsequent fma calls. VEC4_T in_texel_val[4]; @@ -151,9 +148,9 @@ void main() { out_texels[1] = partial_sums[0][1]; uint n4 = DIV_4(n); - write_output_texel(out_texels[0], n4); + write_output_texel_1d(out_texels[0], n4); if (n + 4 < output_sizes.x) { - write_output_texel(out_texels[1], n4 + 1); + write_output_texel_1d(out_texels[1], n4 + 1); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_tiled.glsl index 64d0991e489..97327ea5818 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_tiled.glsl @@ -10,152 +10,121 @@ #define PRECISION ${PRECISION} -#define T ${buffer_scalar_type(DTYPE)} -#define VEC4_T ${buffer_gvec_type(DTYPE, 4)} - -#define TILE_ROWS ${TILE_ROWS} +#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} +#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} ${define_required_extensions(DTYPE)} -$if WEIGHT_STORAGE == "buffer": - ${define_required_extensions("uint8")} - -#extension GL_EXT_control_flow_attributes : require layout(std430) buffer; -${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qmat2", "uint", WEIGHT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "buffer", is_scalar_array=False)} layout(push_constant) uniform restrict Block { - ivec4 out_sizes; - ivec4 mat1_sizes; - ivec4 qmat2_sizes; + ivec4 output_sizes; + ivec4 input_sizes; + ivec4 weight_sizes; }; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; layout(constant_id = 3) const int group_size = 64; -/* - * This shader computes a linear operator between a floating point input matrix - * x and a weights matrix that is quantized to 4 bits. - * - * The (W, H, C) shape of each tensor is: - * - x: (K, M) - * - weights: (N / 2, K) - * - The weights tensor has a data type of `uint8`. Each element in the tensor - * contains 2 4-bit values packed into a uint8. - * - See the pack_int4_linear_weight_transposed_interleave shader to see more - * details on how the weight tensor is stored. - * - qparams: (2, N, number_of_groups) - * - This tensor contains the scales and zeros quantization parameters for the - * weights tensor. The weight tensor is quantized group-wise, which means - * that every `group_size` elements along the K dimension of the weights - * tensor has independent quantization parameters. Along the width dim, the - * first value contains the scale for the group and the second value - * contains the zero point for the group. - * - * Each thread computes a tile of TILE_ROWS * 2 texels of the output tensor. - * - * Note that this shader assumes that all tensors are width packed. - */ +$if IO_STORAGE == "buffer": + #define BUFFER_IO +$if WEIGHT_STORAGE == "buffer": + #define BUFFER_WEIGHT + +#include "qlinear_utils.glslh" + void main() { - const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS; - // Each thread writes out 2 texels along the width axis, equivalent to 8 - // scalar elements. Therefore multiply the thread_idx.x by 8. - const uint out_col = gl_GlobalInvocationID.x << 3; - // Similar reasoning to the above, each thread works on 2 texels along the - // width axis so multiply thread_idx.x by 2. - const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1; - - if (out_col >= out_sizes.x || out_row >= out_sizes.y) { + // Each thread writes out a 8 wide x 4 high tile of output values + const uint n8 = gl_GlobalInvocationID.x; + const uint m4 = gl_GlobalInvocationID.y; + + const uint n = MUL_8(n8); // output col idx + const uint m = MUL_4(m4); // output row idx + const uint n4 = MUL_2(n8); // output col texel idx + + const uint group_num = input_sizes.x / group_size; + const uint group_ntexels = DIV_UP_4(group_size); + + if (n >= output_sizes.x || m >= output_sizes.y) { return; } - const int num_blocks = mat1_sizes.x / group_size; + const uint K4 = DIV_UP_4(input_sizes.x); + const uint N4 = DIV_UP_4(output_sizes.x); // number of texels in each row - VEC4_T mat1[TILE_ROWS]; - VEC4_T qmat2[4][2]; - VEC4_T sums[TILE_ROWS][2]; + VEC4_T out_texels[4][2]; + // Initialize to 0 + $for row_i in range(4): + $for col_i in range(2): + out_texels[${row_i}][${col_i}] = VEC4_T(0.00); - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - sums[r][0] = VEC4_T(0); - sums[r][1] = VEC4_T(0); - } + for (uint group_i = 0; group_i < group_num; ++group_i) { + // Load quantization scales and zeros for the current group + VEC4_T scales[2]; + VEC4_T zeros[2]; + { + uint qparams_bufi = group_i * DIV_2(output_sizes.x) + DIV_2(n); - VEC4_T scales[2]; - VEC4_T zeros[2]; - - $if WEIGHT_STORAGE == "buffer": - const int qmat2_stride = qmat2_sizes.x >> 2; - $if PARAMS_STORAGE == "buffer": - const int qparams_y_stride = out_sizes.x >> 2; - const int qparams_z_stride = qparams_y_stride * 2; - - for (int block_idx = 0; block_idx < num_blocks; ++block_idx) { - $if PARAMS_STORAGE == "buffer": - scales[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx]; - zeros[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + qparams_y_stride]; - - scales[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1]; - zeros[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1 + qparams_y_stride]; - $else: - scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0); - zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0); - - scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0); - zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0); - - for (int g_idx = 0; g_idx < group_size; g_idx += 4) { - const int k = block_idx * group_size + g_idx; - - // Preload B - [[unroll]] for (int r = 0; r < 4; ++r) { - $if WEIGHT_STORAGE == "buffer": - const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x]; - $else: - const uvec4 packed_weight_tex = texelFetch( - t_qmat2, - ivec2(gl_GlobalInvocationID.x, k + r), - 0); - - qmat2[r][0] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0) * scales[0] + zeros[0]; - qmat2[r][1] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0) * scales[1] + zeros[1]; - } - - // Preload A - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - $if IN_STORAGE == "buffer": - mat1[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2]; - $else: - mat1[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0); - } - - // Accumulate output tile - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - sums[r][0] += mat1[r].x * qmat2[0][0] - + mat1[r].y * qmat2[1][0] - + mat1[r].z * qmat2[2][0] - + mat1[r].w * qmat2[3][0]; - - sums[r][1] += mat1[r].x * qmat2[0][1] - + mat1[r].y * qmat2[1][1] - + mat1[r].z * qmat2[2][1] - + mat1[r].w * qmat2[3][1]; - } + VEC4_T scales_zeros_texels[4]; + $for comp in range(4): + scales_zeros_texels[${comp}] = t_qparams[qparams_bufi++]; + + scales[0] = VEC4_T(scales_zeros_texels[0].xz, scales_zeros_texels[1].xz); + zeros[0] = VEC4_T(scales_zeros_texels[0].yw, scales_zeros_texels[1].yw); + + scales[1] = VEC4_T(scales_zeros_texels[2].xz, scales_zeros_texels[3].xz); + zeros[1] = VEC4_T(scales_zeros_texels[2].yw, scales_zeros_texels[3].yw); + } + + for (uint inner_k4 = 0; inner_k4 < group_ntexels; inner_k4++) { + const uint k4 = group_i * group_ntexels + inner_k4; + + // Load 4x4 block of the input tensor, with the top left corner of the + // block at (k, m) + VEC4_T in_texels[4]; + $for comp in range(4): + in_texels[${comp}] = load_input_texel_2d(k4, m + ${comp}, K4); + + uvec4 packed_weight_block = load_transposed_weight_block(k4, n8, K4); + + VEC4_T weight_texels[2]; + $for tile_k in range(4): + // Process weight row k + comp + { + // Weight columns n + 0, 1, 2, 3 + weight_texels[0].x = extract_4bit_from_transposed_block(packed_weight_block, 0, ${tile_k}); + weight_texels[0].y = extract_4bit_from_transposed_block(packed_weight_block, 1, ${tile_k}); + weight_texels[0].z = extract_4bit_from_transposed_block(packed_weight_block, 2, ${tile_k}); + weight_texels[0].w = extract_4bit_from_transposed_block(packed_weight_block, 3, ${tile_k}); + + // Weight colums n + 4, 5, 6, 7 + weight_texels[1].x = extract_4bit_from_transposed_block(packed_weight_block, 4, ${tile_k}); + weight_texels[1].y = extract_4bit_from_transposed_block(packed_weight_block, 5, ${tile_k}); + weight_texels[1].z = extract_4bit_from_transposed_block(packed_weight_block, 6, ${tile_k}); + weight_texels[1].w = extract_4bit_from_transposed_block(packed_weight_block, 7, ${tile_k}); + + weight_texels[0] = fma(weight_texels[0], scales[0], zeros[0]); + weight_texels[1] = fma(weight_texels[1], scales[1], zeros[1]); + + $for tile_m in range(4): + out_texels[${tile_m}][0] = fma(VEC4_T(in_texels[${tile_m}][${tile_k}]), weight_texels[0], out_texels[${tile_m}][0]); + out_texels[${tile_m}][1] = fma(VEC4_T(in_texels[${tile_m}][${tile_k}]), weight_texels[1], out_texels[${tile_m}][1]); + } } } - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - $if OUT_STORAGE == "buffer": - if (out_row + r < out_sizes.y) { - t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = sums[r][0]; - t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = sums[r][1]; - } - $else: - imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), sums[r][0]); - imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), sums[r][1]); + for (uint row_i = 0; row_i < 4 && m + row_i < output_sizes.y; ++row_i) { + write_output_texel_2d(out_texels[row_i][0], n4, m + row_i, N4); + if (n + 4 < output_sizes.x) { + write_output_texel_2d(out_texels[row_i][1], n4 + 1, m + row_i, N4); + } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_tiled.yaml index 8475c7d48a3..94d10dcf978 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_tiled.yaml @@ -7,17 +7,12 @@ linear_qga4w_tiled: parameter_names_with_default_values: DTYPE: float - OUT_STORAGE: texture3d - IN_STORAGE: texture3d + IO_STORAGE: texture3d WEIGHT_STORAGE: texture2d - PARAMS_STORAGE: buffer - TILE_ROWS: 3 shader_variants: - NAME: linear_qga4w_tiled_texture3d_texture3d_texture2d_float - NAME: linear_qga4w_tiled_buffer_buffer_texture2d_float - OUT_STORAGE: buffer - IN_STORAGE: buffer + IO_STORAGE: buffer - NAME: linear_qga4w_tiled_buffer_buffer_buffer_float - OUT_STORAGE: buffer - IN_STORAGE: buffer + IO_STORAGE: buffer WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/qlinear_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/qlinear_utils.glslh index 987ae06773f..80ec44c153a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/qlinear_utils.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/qlinear_utils.glslh @@ -34,8 +34,23 @@ uvec4 load_transposed_weight_block(const uint k4, const uint n8, const uint K4) * Packed weight data extraction functions */ -float extract_4bit_from_transposed_block(const uvec4 block, const uint col, const uint row) { - return float(int((block[row] >> (4 * (7 - col))) & 15) - 8); +/* + * uvec4 block contains a packed 4 high x 8 wide matrix of 4-bit signed integers. This + * function extracts the 4-bit values at the given column and row index. + * + * Each uint in the uvec4 corresponds to one row; thus the desired row can be extracted + * via block[row]. From there, column 0 is packed in bits 28-31, column 1 is packed into + * bits 24-27, column 3 is packed into bits 20-23, and so on. To extract the desired + * value: + * + * 1. First, shift the row uint by 4 * (7 - col) bits + * 2. Apply a mask of 0b1111 = 15 + * + * Finally, convert the masked value to int and subtract it by int to obtain the desired + * signed integer. + */ +T extract_4bit_from_transposed_block(const uvec4 block, const uint col, const uint row) { + return T(int((block[row] >> (4 * (7 - col))) & 15) - 8); } /*********************************** @@ -47,24 +62,55 @@ float extract_4bit_from_transposed_block(const uvec4 block, const uint col, cons #ifdef BUFFER_IO -VEC4_T load_input_texel(const uint k4) { +VEC4_T load_input_texel_1d(const uint k4) { return t_input[k4]; } -void write_output_texel(const VEC4_T out_texel, const uint n4) { +VEC4_T load_input_texel_2d( + const uint k4, + const uint m, + const uint K4) { + return t_input[(m * K4) + k4]; +} + +void write_output_texel_1d(const VEC4_T out_texel, const uint n4) { t_output[n4] = out_texel; } +void write_output_texel_2d( + const VEC4_T out_texel, + const uint n4, + const uint m, + const uint N4) { + t_output[m * N4 + n4] = out_texel; +} + #else // TEXTURE_IO -VEC4_T load_input_texel(const uint k4) { +VEC4_T load_input_texel_1d(const uint k4) { return texelFetch(t_input, ivec3(k4, 0, 0), 0); } -void write_output_texel(const VEC4_T out_texel, const uint n4) { +VEC4_T load_input_texel_2d( + const uint k4, + const uint m, + const uint K4) { + return texelFetch(t_input, ivec3(k4, m, 0), 0); +} + + +void write_output_texel_1d(const VEC4_T out_texel, const uint n4) { imageStore(t_output, ivec3(n4, 0, 0), out_texel); } +void write_output_texel_2d( + const VEC4_T out_texel, + const uint n4, + const uint m, + const uint N4) { + imageStore(t_output, ivec3(n4, m, 0), out_texel); +} + #endif // BUFFER_IO #endif // QLINEAR_UTILS_H diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp index 5e6bb35b029..8c7c6b0cdf9 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp @@ -121,16 +121,26 @@ utils::uvec3 linear_qga4w_global_wg_size( shader.kernel_name.find("_coop") != std::string::npos; if (!use_coop_algorithm) { + // Constructing the global workgroup size for the tiled algorithm utils::uvec3 global_wg_size = graph->logical_limits_of(out); + // Each shader thread computes a 4 high x 8 wide tile of the output matrix, + // which is equivalent to 4 x 2 texels. Since the output tensor must be + // width packed, div-up the "texel-width" of the output by 2 and the height + // of the output tensor by 4 to obtain the number of tiles that need to be + // computed. global_wg_size[0] = utils::div_up(global_wg_size[0], uint32_t(2)); - - global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(3)); + global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(4)); return global_wg_size; } uint32_t output_channels = graph->size_at(-1, out); uint32_t batch_size = graph->size_at(-2, out); + // Constructing the global workgroup size of the co-operative algorithm. The + // local work group size is 64, and each local work group co-operates to + // compute 8 output channels of the output. Therefore, a total of + // (output_channels / 8 x 64) threads should be launched, assuming a batch + // size of 1. return {64, utils::div_up(output_channels, 8u), batch_size}; } @@ -162,18 +172,13 @@ void add_linear_qga4w_node( check_linear_qga4w_args( graph, mat1, mat2_data, group_size, scales_and_zeros_data, out); - bool is_gemv = should_use_coop_algorithm(&graph, mat1); const uint32_t group_size_val = graph.extract_scalar(group_size); - ValueRef mat2 = is_gemv - ? prepack_int4_linear_weight_transposed_block_4x8(graph, mat2_data) - : prepack_int4_linear_weight_transposed_interleaved(graph, mat2_data); + ValueRef mat2 = + prepack_int4_linear_weight_transposed_block_4x8(graph, mat2_data); - ValueRef scales_and_zeros = is_gemv - ? prepack_standard( - graph, scales_and_zeros_data, utils::kBuffer, utils::kWidthPacked) - : prepack_standard_hw_transposed( - graph, scales_and_zeros_data, utils::kBuffer, utils::kWidthPacked); + ValueRef scales_and_zeros = prepack_standard( + graph, scales_and_zeros_data, utils::kBuffer, utils::kWidthPacked); graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph,