From 4b42e35fd6bf21495751d6aecb526fbe12e84e01 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 17 Jul 2025 10:26:00 -0700 Subject: [PATCH] [ET-VK][qlinear] Faster weight only quantized linear gemv kernel Pull Request resolved: https://github.com/pytorch/executorch/pull/12444 ## Changes * Introduce a new compute shader for int4 linear's gemv cases that performs much better than the existing shader. This shader is inspired from MNN's gemv_1x1_conv_buf.cl shader. With this compute kernel, transformer models' text generation can execute much faster than before. On Samsung Galaxy S24 for Llama 3.2 1B, generating 128 tokens: Before: ~25 tok/s After: ~49 tok/s ## Why this new shader is faster The biggest reason is due to vectorized loading of the uint4 weight buffer. This new shader loads the weight buffer as a buffer/image of `uvec4`, whereas the old shader loads the weight buffer as a buffer/image of `u8vec4`. Using the Adreno Offline Compiler, I found that in the former, only one load instruction was used to load from the weight tensor, whereas in the latter 16 load instructions were used to load from the weight tensor. It appears that the data loading was not being vectorized at the assembly level. This is potentially behaviour that can be approved in the SPIR-V shader compiler. An additional factor is better weight packing layout. The new prepacking routine results in better memory coalescing between threads in a work group. The final major factor is the use of tree based reduction to co-operatively reduce partial results into the final output. Previously, a single thread was responsible for the final reduction. ## Future Work * Introduce faster shader for int4 linear gemm cases * Update QCSNW to also use these updated shaders ghstack-source-id: 296864718 Differential Revision: [D78275584](https://our.internmc.facebook.com/intern/diff/D78275584/) --- .../runtime/graph/ops/glsl/indexing_utils.h | 12 + .../graph/ops/glsl/linear_qga4w_coop.glsl | 250 ++++++++---------- .../graph/ops/glsl/linear_qga4w_coop.yaml | 12 +- .../vulkan/runtime/graph/ops/glsl/no_op.yaml | 1 + ...t4_linear_weight_transposed_block_4x8.glsl | 154 +++++++++++ ...t4_linear_weight_transposed_block_4x8.yaml | 14 + .../graph/ops/glsl/qlinear_utils.glslh | 70 +++++ .../ops/glsl/qlinear_weight_pack_utils.glslh | 58 ++++ .../graph/ops/impl/QuantizedLinearQGANW.cpp | 49 ++-- .../vulkan/runtime/graph/ops/impl/Staging.cpp | 69 +++++ .../vulkan/runtime/graph/ops/impl/Staging.h | 5 + 11 files changed, 523 insertions(+), 171 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/qlinear_utils.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/qlinear_weight_pack_utils.glslh diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index 0cfd7f2f119..72650bb7040 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -68,6 +68,18 @@ */ #define mod4(x) ((x) & 3) +#define ALIGN_UP_4(x) (((x) + 3) & ~3) + +#define DIV_UP_8(x) (((x) + 7) >> 3) +#define DIV_UP_4(x) (((x) + 3) >> 2) + +#define DIV_4(x) ((x) >> 2) +#define DIV_2(x) ((x) >> 1) + +#define MUL_8(x) ((x) << 3) +#define MUL_4(x) ((x) << 2) +#define MUL_2(x) ((x) << 1) + /* * Get the staging buffer indices that contain the data of the texel that * corresponds to the provided tensor index. Since the texel have 4 elements, diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl index 715f84d3a56..f46c1f01c7b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl @@ -13,187 +13,147 @@ #define T ${buffer_scalar_type(DTYPE)} #define VEC4_T ${buffer_gvec_type(DTYPE, 4)} -#define TILE_ROWS ${TILE_ROWS} - -#define NGROUPS 8 -#define NWORKERS 8 +#define WGS ${WGS} ${define_required_extensions(DTYPE)} -$if WEIGHT_STORAGE == "buffer": - ${define_required_extensions("uint8")} +${define_required_extensions("uint8")} #extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_debug_printf : require layout(std430) buffer; -${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qmat2", "uint", WEIGHT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "buffer", is_scalar_array=False)} layout(push_constant) uniform restrict Block { - ivec4 out_sizes; - ivec4 mat1_sizes; - ivec4 qmat2_sizes; + ivec4 output_sizes; + ivec4 input_sizes; + ivec4 weight_sizes; }; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; layout(constant_id = 3) const int group_size = 64; -shared VEC4_T partial_sums[NGROUPS][NWORKERS][TILE_ROWS][2]; +shared VEC4_T partial_sums[WGS][2]; + +$if IO_STORAGE == "buffer": + #define BUFFER_IO +$if WEIGHT_STORAGE == "buffer": + #define BUFFER_WEIGHT + +#include "qlinear_utils.glslh" -/* - * This shader computes a linear operator between a floating point input matrix - * x and a weights matrix that is quantized to 4 bits. Please refer to the - * q_4w_linear shader for more details. - * - * This shader implements a co-operative algorithm to compute the output. The - * work group size is {NGROUP, 1, NWORKERS}, and each group of NWORKERS threads - * cooperative to compute TILE_ROWS * 2 output texels. Therefore, - * NGROUP * TILE_ROWS * 2 output texels are computed across one work group. - * - * The threads co-operate by each thread computing a partial reduction along the - * K dimension. To illustrate the computation, consider a scalar variant of the - * algorithm that computes the dot product of 2 vectors. Also assume that - * NWORKERS is 8. - * - * Thread 1 in each group will compute: - * (mat1[0] * mat2[0]) + (mat1[8] * mat2[8]) + (mat1[16] * mat2[16]) + ... - * - * Thread 2 in each group will compute: - * (mat1[1] * mat2[1]) + (mat2[9] * mat2[9]) + (mat1[17] * mat2[17]) + ... - * - * Thread 3 in each group will compute: - * (mat1[2] * mat2[2]) + (mat2[10] * mat2[10]) + (mat1[18] * mat2[18]) + ... - * - * The partial accumulations is structured such that memory accesses in each - * loop iteration can be coalesced. - * - * Then, at the end first thread in each group will accumulate the partial - * accumulations computed by each thread to obtain the final result. - * - * Note that this shader assumes that all tensors are width packed. - */ void main() { - const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS; - // Each thread writes out 2 texels along the width axis, equivalent to 8 - // scalar elements. Therefore multiply the thread_idx.x by 8. - const uint out_col = gl_GlobalInvocationID.x << 3; - // Similar reasoning to the above, each thread works on 2 texels along the - // width axis so multiply thread_idx.x by 2. - const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1; - - const uint gid = gl_LocalInvocationID.x; // group id - const uint wid = gl_LocalInvocationID.z; // worker id - - if (out_col >= out_sizes.x || out_row >= out_sizes.y) { + const uint lid = gl_LocalInvocationID.x; + const uint n8 = gl_GlobalInvocationID.y; + // The output tensor will have a shape of [n, 1, 1, 1]. Each thread computes + // 8 output elements, so each thread will write to 8 elements starting at the + // tensor index (gid.x * 8, 0, 0, 0). + const uint n = MUL_8(n8); + const uint K4 = DIV_UP_4(input_sizes.x); + + if (n >= output_sizes.x) { return; } - const int num_blocks = mat1_sizes.x / group_size; + VEC4_T out_texels[2]; + out_texels[0] = VEC4_T(0); + out_texels[1] = VEC4_T(0); - VEC4_T mat1[TILE_ROWS]; - VEC4_T qmat2[4][2]; - VEC4_T local_sums[TILE_ROWS][2]; + // initialize the group index to a value larger than the largest possible + uint cur_group_idx = input_sizes.x; - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - local_sums[r][0] = VEC4_T(0); - local_sums[r][1] = VEC4_T(0); - } + // Each thread in the work group accumulates a partial result. + for (uint k4 = lid; k4 < DIV_UP_4(input_sizes.x); k4 += WGS) { + const uint k = MUL_4(k4); + const uint group_idx = k / group_size; - VEC4_T scales[2]; - VEC4_T zeros[2]; - - $if WEIGHT_STORAGE == "buffer": - const int qmat2_stride = qmat2_sizes.x >> 2; - $if PARAMS_STORAGE == "buffer": - const int qparams_y_stride = out_sizes.x >> 2; - const int qparams_z_stride = qparams_y_stride * 2; - - for (int block_idx = 0; block_idx < num_blocks; ++block_idx) { - $if PARAMS_STORAGE == "buffer": - scales[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx]; - zeros[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + qparams_y_stride]; - - scales[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1]; - zeros[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1 + qparams_y_stride]; - $else: - scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0); - zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0); - - scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0); - zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0); - - for (uint g_idx = 4 * wid; g_idx < group_size; g_idx += (4 * NWORKERS)) { - const uint k = block_idx * group_size + g_idx; - - // Preload B - [[unroll]] for (int r = 0; r < 4; ++r) { - $if WEIGHT_STORAGE == "buffer": - const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x]; - $else: - const uvec4 packed_weight_tex = texelFetch( - t_qmat2, - ivec2(gl_GlobalInvocationID.x, k + r), - 0); - - qmat2[r][0] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0) * scales[0] + zeros[0]; - qmat2[r][1] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0) * scales[1] + zeros[1]; - } + VEC4_T scales[2]; + VEC4_T zeros[2]; - // Preload A - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - $if IN_STORAGE == "buffer": - mat1[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2]; - $else: - mat1[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0); - } + // Only update the scales/zeros if the current iteration is now working on a + // new quantization group. + if (group_idx != cur_group_idx) { + // The qparams tensor contains the quantization scales and zeros, with + // shape [2, N, K / group_size, 1]. + // Loading a texel from the qparams tensor will return 2 scales and 2 + // zeros for 2 adjacent output channels. + uint qparams_bufi = group_idx * DIV_2(output_sizes.x) + DIV_2(n); + VEC4_T scales_zeros_texels[4]; + $for comp in range(4): + scales_zeros_texels[${comp}] = t_qparams[qparams_bufi++]; - // Accumulate local output tile - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - local_sums[r][0] += mat1[r].x * qmat2[0][0] - + mat1[r].y * qmat2[1][0] - + mat1[r].z * qmat2[2][0] - + mat1[r].w * qmat2[3][0]; - - local_sums[r][1] += mat1[r].x * qmat2[0][1] - + mat1[r].y * qmat2[1][1] - + mat1[r].z * qmat2[2][1] - + mat1[r].w * qmat2[3][1]; - } + scales[0] = VEC4_T(scales_zeros_texels[0].xz, scales_zeros_texels[1].xz); + zeros[0] = VEC4_T(scales_zeros_texels[0].yw, scales_zeros_texels[1].yw); + + scales[1] = VEC4_T(scales_zeros_texels[2].xz, scales_zeros_texels[3].xz); + zeros[1] = VEC4_T(scales_zeros_texels[2].yw, scales_zeros_texels[3].yw); + + cur_group_idx = group_idx; } + // The input tensor will have a shape of [K, 1, 1, 1]; in each iteration, + // load 4 elements starting from the tensor index (k, 0, 0, 0). + VEC4_T in_texel = load_input_texel(k4); + // Extract each element of the in_texel into a separate vectorized variable; + // these are used to "broadcast" the input values in subsequent fma calls. + VEC4_T in_texel_val[4]; + $for comp in range(4): + in_texel_val[${comp}] = VEC4_T(in_texel[${comp}]); + + uvec4 packed_weight_block = load_transposed_weight_block(k4, n8, K4); + + VEC4_T weight_texels[2]; + $for comp in range(4): + { + weight_texels[0].x = extract_4bit_from_transposed_block(packed_weight_block, 0, ${comp}); + weight_texels[0].y = extract_4bit_from_transposed_block(packed_weight_block, 1, ${comp}); + weight_texels[0].z = extract_4bit_from_transposed_block(packed_weight_block, 2, ${comp}); + weight_texels[0].w = extract_4bit_from_transposed_block(packed_weight_block, 3, ${comp}); + + weight_texels[1].x = extract_4bit_from_transposed_block(packed_weight_block, 4, ${comp}); + weight_texels[1].y = extract_4bit_from_transposed_block(packed_weight_block, 5, ${comp}); + weight_texels[1].z = extract_4bit_from_transposed_block(packed_weight_block, 6, ${comp}); + weight_texels[1].w = extract_4bit_from_transposed_block(packed_weight_block, 7, ${comp}); + + weight_texels[0] = fma(weight_texels[0], scales[0], zeros[0]); + weight_texels[1] = fma(weight_texels[1], scales[1], zeros[1]); + + out_texels[0] = fma(in_texel_val[${comp}], weight_texels[0], out_texels[0]); + out_texels[1] = fma(in_texel_val[${comp}], weight_texels[1], out_texels[1]); + } } - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - partial_sums[gid][wid][r][0] = local_sums[r][0]; - partial_sums[gid][wid][r][1] = local_sums[r][1]; - } + partial_sums[lid][0] = out_texels[0]; + partial_sums[lid][1] = out_texels[1]; memoryBarrierShared(); barrier(); - if (wid != 0) { - return; + // Tree reduction to compute the overall result. + for (int i = WGS / 2; i > 0; i /= 2) { + if (lid < i) { + partial_sums[lid][0] = partial_sums[lid][0] + partial_sums[lid + i][0]; + partial_sums[lid][1] = partial_sums[lid][1] + partial_sums[lid + i][1]; + } + memoryBarrierShared(); + barrier(); } - VEC4_T sums[TILE_ROWS][2]; + // Only the first thread will write out result + if (lid == 0) { + out_texels[0] = partial_sums[0][0]; + out_texels[1] = partial_sums[0][1]; - for (int r = 0; r < TILE_ROWS; ++r) { - sums[r][0] = VEC4_T(0); - sums[r][1] = VEC4_T(0); - [[unroll]] for (int worker = 0; worker < NWORKERS; ++ worker) { - sums[r][0] += partial_sums[gid][worker][r][0]; - sums[r][1] += partial_sums[gid][worker][r][1]; + uint n4 = DIV_4(n); + write_output_texel(out_texels[0], n4); + if (n + 4 < output_sizes.x) { + write_output_texel(out_texels[1], n4 + 1); } } - - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - $if OUT_STORAGE == "buffer": - t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = sums[r][0]; - t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = sums[r][1]; - $else: - imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), sums[r][0]); - imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), sums[r][1]); - } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.yaml index 25ffe94f430..04e803a2e94 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.yaml @@ -7,17 +7,13 @@ linear_qga4w_coop: parameter_names_with_default_values: DTYPE: float - OUT_STORAGE: texture3d - IN_STORAGE: texture3d + IO_STORAGE: texture3d WEIGHT_STORAGE: texture2d - PARAMS_STORAGE: buffer - TILE_ROWS: 1 + WGS: 64 shader_variants: - NAME: linear_qga4w_coop_texture3d_texture3d_texture2d_float - NAME: linear_qga4w_coop_buffer_buffer_texture2d_float - OUT_STORAGE: buffer - IN_STORAGE: buffer + IO_STORAGE: buffer - NAME: linear_qga4w_coop_buffer_buffer_buffer_float - OUT_STORAGE: buffer - IN_STORAGE: buffer + IO_STORAGE: buffer WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml index bfeaba2496b..f888e8661d3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml @@ -13,6 +13,7 @@ no_op: - VALUE: half - VALUE: float - VALUE: int32 + - VALUE: uint32 - VALUE: int8 - VALUE: uint8 STORAGE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.glsl new file mode 100644 index 00000000000..e42cf05dd7f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.glsl @@ -0,0 +1,154 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_qmat2", "uint", STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", "uint", "buffer")} + +layout(push_constant) uniform restrict Block { + ivec4 qmat2_sizes; + ivec2 orig_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +$if STORAGE == "buffer": + #define BUFFER_WEIGHT + +#include "qlinear_weight_pack_utils.glslh" + +#define extract_4bit(input_block_data, col, row) \ + (extract_4bit_from_packed_uint_le(input_block_data[row], col)) + +/* + * This shader packs the weight tensor into blocks for efficient consumption. + * + * The input tensor has shape [K/2, N] where each element is a uint8 containing + * 2 packed 4-bit values. The logical tensor shape is [K, N] of 4-bit values. + * + * The transformation partitions the tensor into blocks of size 4x8 (4-bit values) + * and transposes each block to 8x4, then packs the result so that each uvec4 + * contains an entire transposed block. + * + * Original block (4x8 4-bit values, shown as 2x8 uint8 values): + * w00|w10, w20|w30, + * w01|w11, w21|w31, + * w02|w12, w22|w32, + * w03|w13, w23|w33, + * w04|w14, w24|w34, + * w05|w15, w25|w35, + * w06|w16, w26|w36, + * w07|w17, w27|w37, + * + * Transposed block (8x4 4-bit values, packed into uvec4): + * w00|w01, w02|w03, w04|w05, w06|w07 + * w10|w11, w12|w13, w14|w15, w16|w17 + * w20|w21, w22|w23, w24|w25, w26|w27 + * w30|w31, w32|w33, w34|w35, w36|w37 + */ +void main() { + // Each thread writes out 2 adjacent 8 wide x 4 high transposed block. Each + // block is packed as one uvec4. + ivec2 block_pos = ivec2( + MUL_2(gl_GlobalInvocationID.x), + gl_GlobalInvocationID.y); + + // There are K wide x N high 4-bit values in the original weight tensor + const int input_width = orig_sizes.x; // K + const int input_height = orig_sizes.y; // N + + const int input_width_uint = DIV_UP_8(input_width); + + // Original block spans 4 wide x 8 high 4-bit values. Since uint is used to + // read the input tensor, each block spans 0.5 wide x 8 high uint values. + const ivec2 block_start = ivec2( + DIV_2(block_pos.x), + MUL_8(block_pos.y)); + + // Check bounds + if (block_start.x >= input_width_uint || block_start.y >= input_height) { + return; + } + + // Read input block. Note that this block will contain the source data for + // both output blocks, as it contains 1 wide x 8 high uint values, which is + // equivalent to 8 wide x 8 high 4-bit values. + uint input_block_data[8]; + + // Read in 8 rows along the same column of uints, each uint contains 4 4-bit + // values. This will be the source data for the transposed block. + for (int i = 0; i < 8; ++i) { + uint input_bufi = (block_start.y + i) * input_width_uint + block_start.x; + input_block_data[i] = t_input[input_bufi]; + } + + for (int col_offset = 0; col_offset <= 4; col_offset+=4) { + uvec4 output_block; + + output_block.x = pack_8x4bit_into_uint( + extract_4bit(input_block_data, col_offset, 0), + extract_4bit(input_block_data, col_offset, 1), + extract_4bit(input_block_data, col_offset, 2), + extract_4bit(input_block_data, col_offset, 3), + extract_4bit(input_block_data, col_offset, 4), + extract_4bit(input_block_data, col_offset, 5), + extract_4bit(input_block_data, col_offset, 6), + extract_4bit(input_block_data, col_offset, 7)); + + output_block.y = pack_8x4bit_into_uint( + extract_4bit(input_block_data, col_offset + 1, 0), + extract_4bit(input_block_data, col_offset + 1, 1), + extract_4bit(input_block_data, col_offset + 1, 2), + extract_4bit(input_block_data, col_offset + 1, 3), + extract_4bit(input_block_data, col_offset + 1, 4), + extract_4bit(input_block_data, col_offset + 1, 5), + extract_4bit(input_block_data, col_offset + 1, 6), + extract_4bit(input_block_data, col_offset + 1, 7)); + + output_block.z = pack_8x4bit_into_uint( + extract_4bit(input_block_data, col_offset + 2, 0), + extract_4bit(input_block_data, col_offset + 2, 1), + extract_4bit(input_block_data, col_offset + 2, 2), + extract_4bit(input_block_data, col_offset + 2, 3), + extract_4bit(input_block_data, col_offset + 2, 4), + extract_4bit(input_block_data, col_offset + 2, 5), + extract_4bit(input_block_data, col_offset + 2, 6), + extract_4bit(input_block_data, col_offset + 2, 7)); + + output_block.w = pack_8x4bit_into_uint( + extract_4bit(input_block_data, col_offset + 3, 0), + extract_4bit(input_block_data, col_offset + 3, 1), + extract_4bit(input_block_data, col_offset + 3, 2), + extract_4bit(input_block_data, col_offset + 3, 3), + extract_4bit(input_block_data, col_offset + 3, 4), + extract_4bit(input_block_data, col_offset + 3, 5), + extract_4bit(input_block_data, col_offset + 3, 6), + extract_4bit(input_block_data, col_offset + 3, 7)); + + const uint qmat2_texel_stride_x = DIV_UP_4(qmat2_sizes.x); + write_transposed_weight_block( + output_block, + block_pos.x, + block_pos.y, + qmat2_texel_stride_x); + + if (MUL_8(block_start.x) + 4 >= input_width) { + return; + } + // Otherwise, implement the block position to write to the next block in the + // following iteration. + block_pos.x += 1; + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.yaml new file mode 100644 index 00000000000..c72a2cc1df6 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.yaml @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +pack_int4_linear_weight_transposed_block_4x8: + parameter_names_with_default_values: + STORAGE: buffer + shader_variants: + - NAME: pack_int4_linear_weight_transposed_block_4x8_buffer + STORAGE: buffer + - NAME: pack_int4_linear_weight_transposed_block_4x8_texture2d + STORAGE: texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/qlinear_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/qlinear_utils.glslh new file mode 100644 index 00000000000..987ae06773f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/qlinear_utils.glslh @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef QLINEAR_UTILS_H +#define QLINEAR_UTILS_H + +/*********************************** + * Packed Weight data read/write functions + * + * These functions assume that t_qmat2 is declared in the shader layout as a storage + * buffer or storage image. + */ + +#ifdef BUFFER_WEIGHT + +uvec4 load_transposed_weight_block(const uint k4, const uint n8, const uint K4) { + return t_qmat2[n8 * K4 + k4]; +} + +#else // TEXTURE_WEIGHT + +uvec4 load_transposed_weight_block(const uint k4, const uint n8, const uint K4) { + return texelFetch(t_qmat2, ivec2(k4, n8), 0); +} + +#endif // BUFFER_WEIGHT + +/*********************************** + * Packed weight data extraction functions + */ + +float extract_4bit_from_transposed_block(const uvec4 block, const uint col, const uint row) { + return float(int((block[row] >> (4 * (7 - col))) & 15) - 8); +} + +/*********************************** + * Input/Output read/write functions + * + * These functions assume that t_input and t_output are declared in the shader layout as + * storage buffers or storage images. + */ + +#ifdef BUFFER_IO + +VEC4_T load_input_texel(const uint k4) { + return t_input[k4]; +} + +void write_output_texel(const VEC4_T out_texel, const uint n4) { + t_output[n4] = out_texel; +} + +#else // TEXTURE_IO + +VEC4_T load_input_texel(const uint k4) { + return texelFetch(t_input, ivec3(k4, 0, 0), 0); +} + +void write_output_texel(const VEC4_T out_texel, const uint n4) { + imageStore(t_output, ivec3(n4, 0, 0), out_texel); +} + +#endif // BUFFER_IO + +#endif // QLINEAR_UTILS_H diff --git a/backends/vulkan/runtime/graph/ops/glsl/qlinear_weight_pack_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/qlinear_weight_pack_utils.glslh new file mode 100644 index 00000000000..1f481f4f859 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/qlinear_weight_pack_utils.glslh @@ -0,0 +1,58 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef QLINEAR_WEIGHT_PACK_UTILS_H +#define QLINEAR_WEIGHT_PACK_UTILS_H + +/*********************************** + * Packed Weight data write functions + * + * These functions assume that t_qmat2 has been defined in the shader layout as either + * a storage buffer or a storage image. + */ + +#ifdef BUFFER_WEIGHT + +void write_transposed_weight_block(const uvec4 block, const uint k4, const uint n8, const uint K4) { + t_qmat2[n8 * K4 + k4] = block; +} + +#else // TEXTURE_WEIGHT + +void write_transposed_weight_block(const uvec4 block, const uint k4, const uint n8, const uint K4) { + imageStore(t_qmat2, ivec2(k4, n8), block); +} + +#endif // BUFFER_WEIGHT + +/*********************************** + * Utilities for packing weight data + */ + +uint extract_4bit_from_packed_uint_le(const uint packed, const uint i) { + // account for little endian + uint byte = packed >> (8 * (i / 2)) & 255; + return (byte >> (4 - 4 * (i % 2))) & 15; +} + +uint pack_8x4bit_into_uint( + const uint val0, + const uint val1, + const uint val2, + const uint val3, + const uint val4, + const uint val5, + const uint val6, + const uint val7) { + return uint( + (val0 << 28) | (val1 << 24) | (val2 << 20) | (val3 << 16) | (val4 << 12) | + (val5 << 8) | (val6 << 4) | val7 + ); +} + +#endif // QLINEAR_WEIGHT_PACK_UTILS_H diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp index d9425b8b62f..5e6bb35b029 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp @@ -50,25 +50,28 @@ void resize_linear_qga4w_node( const std::vector& extra_args) { (void)extra_args; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); - vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); + ValueRef out = args.at(0).refs.at(0); + ValueRef mat1 = args.at(1).refs.at(0); + ValueRef mat2_data = extra_args.at(0); - const int out_cols = utils::val_at(-2, mat1->sizes()); - const int out_rows = utils::val_at(-1, mat2->sizes()) * 2; + std::vector mat1_sizes = graph->sizes_of(mat1); + std::vector mat2_sizes = graph->sizes_of(mat2_data); + + const int64_t out_cols = utils::val_at(-2, mat1_sizes); + const int64_t out_rows = utils::val_at(-2, mat2_sizes); std::vector new_out_sizes(3); - if (mat1->sizes().size() == 2) { + if (mat1_sizes.size() == 2) { new_out_sizes.resize(2); new_out_sizes.at(0) = out_cols; new_out_sizes.at(1) = out_rows; } else { - new_out_sizes.at(0) = mat1->sizes().at(0); + new_out_sizes.at(0) = mat1_sizes.at(0); new_out_sizes.at(1) = out_cols; new_out_sizes.at(2) = out_rows; } - out->virtual_resize(new_out_sizes); + graph->virtual_resize(out, new_out_sizes); } /** @@ -117,14 +120,18 @@ utils::uvec3 linear_qga4w_global_wg_size( const bool use_coop_algorithm = shader.kernel_name.find("_coop") != std::string::npos; - utils::uvec3 global_wg_size = graph->logical_limits_of(out); - global_wg_size[0] = utils::div_up(global_wg_size[0], uint32_t(2)); - if (!use_coop_algorithm) { + utils::uvec3 global_wg_size = graph->logical_limits_of(out); + global_wg_size[0] = utils::div_up(global_wg_size[0], uint32_t(2)); + global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(3)); + return global_wg_size; } - return global_wg_size; + uint32_t output_channels = graph->size_at(-1, out); + uint32_t batch_size = graph->size_at(-2, out); + + return {64, utils::div_up(output_channels, 8u), batch_size}; } utils::uvec3 linear_qga4w_local_wg_size( @@ -139,7 +146,7 @@ utils::uvec3 linear_qga4w_local_wg_size( shader.kernel_name.find("_coop") != std::string::npos; if (use_coop_algorithm) { - return {8, 1, 8}; + return {64, 1, 1}; } else { return graph->create_local_wg_size(global_workgroup_size); } @@ -155,13 +162,19 @@ void add_linear_qga4w_node( check_linear_qga4w_args( graph, mat1, mat2_data, group_size, scales_and_zeros_data, out); + bool is_gemv = should_use_coop_algorithm(&graph, mat1); const uint32_t group_size_val = graph.extract_scalar(group_size); - ValueRef mat2 = - prepack_int4_linear_weight_transposed_interleaved(graph, mat2_data); + ValueRef mat2 = is_gemv + ? prepack_int4_linear_weight_transposed_block_4x8(graph, mat2_data) + : prepack_int4_linear_weight_transposed_interleaved(graph, mat2_data); + + ValueRef scales_and_zeros = is_gemv + ? prepack_standard( + graph, scales_and_zeros_data, utils::kBuffer, utils::kWidthPacked) + : prepack_standard_hw_transposed( + graph, scales_and_zeros_data, utils::kBuffer, utils::kWidthPacked); - ValueRef scales_and_zeros = prepack_standard_hw_transposed( - graph, scales_and_zeros_data, utils::kBuffer, utils::kWidthPacked); graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, pick_linear_qga4w_shader, @@ -178,7 +191,7 @@ void add_linear_qga4w_node( // Specialization Constants {SV(group_size_val)}, // Resize Args - {}, + {mat2_data}, // Resizing Logic resize_linear_qga4w_node)); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index f429ab0fc25..bfaad716059 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -322,6 +322,75 @@ ValueRef prepack_int4_linear_weight_transposed_interleaved( return qmat2; } +ValueRef prepack_int4_linear_weight_transposed_block_4x8( + ComputeGraph& graph, + const ValueRef qmat2_data) { + std::vector qmat2_orig_sizes = graph.sizes_of(qmat2_data); + const int64_t ndim = graph.dim_of(qmat2_data); + + const int64_t K_div2 = qmat2_orig_sizes.at(ndim - 1); // Input is [N, K/2] + const int64_t N = qmat2_orig_sizes.at(ndim - 2); + // Logical K dimension. Each value in the tensor is a uint8 that contains 2 + // packed 4-bit values. + const int64_t K = K_div2 * 2; + + // This packing format partitions the weight tensor into 4 wide x 8 high + // blocks. To figure out the size of the output tensor, determine the number + // of blocks along the width and height dims. + const int64_t num_blocks_K = utils::div_up(K, int64_t(4)); + const int64_t num_blocks_N = utils::div_up(N, int64_t(8)); + // Each transposed block is 8 wide x 4 high. In terms of 8-bit values, the + // block is 4 wide x 4 high. To maximize memory loading efficiency, the packed + // weight tensor will use a base data type of uint32_t; in terms of uint32_t, + // each block is 1 wide x 4 high. However, each block is also flattened as it + // is stored, so that the whole block can be loaded at once. As a result, the + // stored block will be 4 wide x 1 high. + const int64_t output_width = num_blocks_K * 4; + const int64_t output_height = num_blocks_N; + + // Store the original sizes of the tensor to pass to the shader + utils::ivec2 orig_sizes{ + utils::safe_downcast(K), utils::safe_downcast(N)}; + + std::vector qmat2_sizes{output_height, output_width}; + + utils::StorageType storage_type = utils::kTexture2D; + uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); + if (output_width > max_extent * 4 || output_height > max_extent) { + storage_type = utils::kBuffer; + } + + ValueRef qmat2 = graph.add_tensor( + qmat2_sizes, vkcompute::vkapi::kUInt, storage_type, utils::kWidthPacked); + + // Global workgroup size: each thread writes out two adjacent blocks + utils::uvec3 global_wg_size{ + utils::div_up(utils::safe_downcast(num_blocks_K), uint32_t(2)), + utils::safe_downcast(num_blocks_N), + 1u}; + + std::string kernel_name = "pack_int4_linear_weight_transposed_block_4x8"; + add_storage_type_suffix(kernel_name, storage_type); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + graph.create_local_wg_size(global_wg_size), + // Inputs and Outputs + qmat2_data, + qmat2, + // UBOs + {}, + // Specialization Constants + {}, + // Push Constants + {graph.sizes_pc_of(qmat2), + PushConstantDataInfo(&orig_sizes, sizeof(utils::ivec2))})); + + return qmat2; +} + void prepack_op(ComputeGraph& graph, const std::vector& args) { return add_prepack_standard_node(graph, args[0], args[1]); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.h b/backends/vulkan/runtime/graph/ops/impl/Staging.h index 090a3718295..0b1568ca139 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.h +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.h @@ -89,9 +89,14 @@ ValueRef prepack_direct_copy_buffer( // // Op specific prepack functions +// ValueRef prepack_int4_linear_weight_transposed_interleaved( ComputeGraph& graph, const ValueRef qmat2_data); +ValueRef prepack_int4_linear_weight_transposed_block_4x8( + ComputeGraph& graph, + const ValueRef qmat2_data); + } // namespace vkcompute