diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index 0cfd7f2f119..72650bb7040 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -68,6 +68,18 @@ */ #define mod4(x) ((x) & 3) +#define ALIGN_UP_4(x) (((x) + 3) & ~3) + +#define DIV_UP_8(x) (((x) + 7) >> 3) +#define DIV_UP_4(x) (((x) + 3) >> 2) + +#define DIV_4(x) ((x) >> 2) +#define DIV_2(x) ((x) >> 1) + +#define MUL_8(x) ((x) << 3) +#define MUL_4(x) ((x) << 2) +#define MUL_2(x) ((x) << 1) + /* * Get the staging buffer indices that contain the data of the texel that * corresponds to the provided tensor index. Since the texel have 4 elements, diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl index 715f84d3a56..f46c1f01c7b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl @@ -13,187 +13,147 @@ #define T ${buffer_scalar_type(DTYPE)} #define VEC4_T ${buffer_gvec_type(DTYPE, 4)} -#define TILE_ROWS ${TILE_ROWS} - -#define NGROUPS 8 -#define NWORKERS 8 +#define WGS ${WGS} ${define_required_extensions(DTYPE)} -$if WEIGHT_STORAGE == "buffer": - ${define_required_extensions("uint8")} +${define_required_extensions("uint8")} #extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_debug_printf : require layout(std430) buffer; -${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qmat2", "uint", WEIGHT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "buffer", is_scalar_array=False)} layout(push_constant) uniform restrict Block { - ivec4 out_sizes; - ivec4 mat1_sizes; - ivec4 qmat2_sizes; + ivec4 output_sizes; + ivec4 input_sizes; + ivec4 weight_sizes; }; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; layout(constant_id = 3) const int group_size = 64; -shared VEC4_T partial_sums[NGROUPS][NWORKERS][TILE_ROWS][2]; +shared VEC4_T partial_sums[WGS][2]; + +$if IO_STORAGE == "buffer": + #define BUFFER_IO +$if WEIGHT_STORAGE == "buffer": + #define BUFFER_WEIGHT + +#include "qlinear_utils.glslh" -/* - * This shader computes a linear operator between a floating point input matrix - * x and a weights matrix that is quantized to 4 bits. Please refer to the - * q_4w_linear shader for more details. - * - * This shader implements a co-operative algorithm to compute the output. The - * work group size is {NGROUP, 1, NWORKERS}, and each group of NWORKERS threads - * cooperative to compute TILE_ROWS * 2 output texels. Therefore, - * NGROUP * TILE_ROWS * 2 output texels are computed across one work group. - * - * The threads co-operate by each thread computing a partial reduction along the - * K dimension. To illustrate the computation, consider a scalar variant of the - * algorithm that computes the dot product of 2 vectors. Also assume that - * NWORKERS is 8. 
- * - * Thread 1 in each group will compute: - * (mat1[0] * mat2[0]) + (mat1[8] * mat2[8]) + (mat1[16] * mat2[16]) + ... - * - * Thread 2 in each group will compute: - * (mat1[1] * mat2[1]) + (mat2[9] * mat2[9]) + (mat1[17] * mat2[17]) + ... - * - * Thread 3 in each group will compute: - * (mat1[2] * mat2[2]) + (mat2[10] * mat2[10]) + (mat1[18] * mat2[18]) + ... - * - * The partial accumulations is structured such that memory accesses in each - * loop iteration can be coalesced. - * - * Then, at the end first thread in each group will accumulate the partial - * accumulations computed by each thread to obtain the final result. - * - * Note that this shader assumes that all tensors are width packed. - */ void main() { - const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS; - // Each thread writes out 2 texels along the width axis, equivalent to 8 - // scalar elements. Therefore multiply the thread_idx.x by 8. - const uint out_col = gl_GlobalInvocationID.x << 3; - // Similar reasoning to the above, each thread works on 2 texels along the - // width axis so multiply thread_idx.x by 2. - const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1; - - const uint gid = gl_LocalInvocationID.x; // group id - const uint wid = gl_LocalInvocationID.z; // worker id - - if (out_col >= out_sizes.x || out_row >= out_sizes.y) { + const uint lid = gl_LocalInvocationID.x; + const uint n8 = gl_GlobalInvocationID.y; + // The output tensor will have a shape of [n, 1, 1, 1]. Each thread computes + // 8 output elements, so each thread will write to 8 elements starting at the + // tensor index (gid.x * 8, 0, 0, 0). + const uint n = MUL_8(n8); + const uint K4 = DIV_UP_4(input_sizes.x); + + if (n >= output_sizes.x) { return; } - const int num_blocks = mat1_sizes.x / group_size; + VEC4_T out_texels[2]; + out_texels[0] = VEC4_T(0); + out_texels[1] = VEC4_T(0); - VEC4_T mat1[TILE_ROWS]; - VEC4_T qmat2[4][2]; - VEC4_T local_sums[TILE_ROWS][2]; + // initialize the group index to a value larger than the largest possible + uint cur_group_idx = input_sizes.x; - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - local_sums[r][0] = VEC4_T(0); - local_sums[r][1] = VEC4_T(0); - } + // Each thread in the work group accumulates a partial result. 
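+  // A short sketch of the cooperative access pattern, assuming the default
+  // WGS of 64: thread `lid` handles input texels k4 = lid, lid + 64,
+  // lid + 128, ... so that within any one iteration adjacent threads load
+  // adjacent texels and the memory accesses can be coalesced.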
+ for (uint k4 = lid; k4 < DIV_UP_4(input_sizes.x); k4 += WGS) { + const uint k = MUL_4(k4); + const uint group_idx = k / group_size; - VEC4_T scales[2]; - VEC4_T zeros[2]; - - $if WEIGHT_STORAGE == "buffer": - const int qmat2_stride = qmat2_sizes.x >> 2; - $if PARAMS_STORAGE == "buffer": - const int qparams_y_stride = out_sizes.x >> 2; - const int qparams_z_stride = qparams_y_stride * 2; - - for (int block_idx = 0; block_idx < num_blocks; ++block_idx) { - $if PARAMS_STORAGE == "buffer": - scales[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx]; - zeros[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + qparams_y_stride]; - - scales[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1]; - zeros[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1 + qparams_y_stride]; - $else: - scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0); - zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0); - - scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0); - zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0); - - for (uint g_idx = 4 * wid; g_idx < group_size; g_idx += (4 * NWORKERS)) { - const uint k = block_idx * group_size + g_idx; - - // Preload B - [[unroll]] for (int r = 0; r < 4; ++r) { - $if WEIGHT_STORAGE == "buffer": - const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x]; - $else: - const uvec4 packed_weight_tex = texelFetch( - t_qmat2, - ivec2(gl_GlobalInvocationID.x, k + r), - 0); - - qmat2[r][0] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0) * scales[0] + zeros[0]; - qmat2[r][1] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0) * scales[1] + zeros[1]; - } + VEC4_T scales[2]; + VEC4_T zeros[2]; - // Preload A - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - $if IN_STORAGE == "buffer": - mat1[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2]; - $else: - mat1[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0); - } + // Only update the scales/zeros if the current iteration is now working on a + // new quantization group. + if (group_idx != cur_group_idx) { + // The qparams tensor contains the quantization scales and zeros, with + // shape [2, N, K / group_size, 1]. + // Loading a texel from the qparams tensor will return 2 scales and 2 + // zeros for 2 adjacent output channels. + uint qparams_bufi = group_idx * DIV_2(output_sizes.x) + DIV_2(n); + VEC4_T scales_zeros_texels[4]; + $for comp in range(4): + scales_zeros_texels[${comp}] = t_qparams[qparams_bufi++]; - // Accumulate local output tile - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - local_sums[r][0] += mat1[r].x * qmat2[0][0] - + mat1[r].y * qmat2[1][0] - + mat1[r].z * qmat2[2][0] - + mat1[r].w * qmat2[3][0]; - - local_sums[r][1] += mat1[r].x * qmat2[0][1] - + mat1[r].y * qmat2[1][1] - + mat1[r].z * qmat2[2][1] - + mat1[r].w * qmat2[3][1]; - } + scales[0] = VEC4_T(scales_zeros_texels[0].xz, scales_zeros_texels[1].xz); + zeros[0] = VEC4_T(scales_zeros_texels[0].yw, scales_zeros_texels[1].yw); + + scales[1] = VEC4_T(scales_zeros_texels[2].xz, scales_zeros_texels[3].xz); + zeros[1] = VEC4_T(scales_zeros_texels[2].yw, scales_zeros_texels[3].yw); + + cur_group_idx = group_idx; } + // The input tensor will have a shape of [K, 1, 1, 1]; in each iteration, + // load 4 elements starting from the tensor index (k, 0, 0, 0). 
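+    // Per iteration this thread accumulates, for j = 0..7 and c = 0..3,
+    //   out[n + j] += in[k + c] * ((w(k + c, n + j) - 8) * scale[n + j] + zero[n + j])
+    // where w(k, n) is the raw 4-bit weight value and scale/zero come from
+    // the quantization group that contains k.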
+ VEC4_T in_texel = load_input_texel(k4); + // Extract each element of the in_texel into a separate vectorized variable; + // these are used to "broadcast" the input values in subsequent fma calls. + VEC4_T in_texel_val[4]; + $for comp in range(4): + in_texel_val[${comp}] = VEC4_T(in_texel[${comp}]); + + uvec4 packed_weight_block = load_transposed_weight_block(k4, n8, K4); + + VEC4_T weight_texels[2]; + $for comp in range(4): + { + weight_texels[0].x = extract_4bit_from_transposed_block(packed_weight_block, 0, ${comp}); + weight_texels[0].y = extract_4bit_from_transposed_block(packed_weight_block, 1, ${comp}); + weight_texels[0].z = extract_4bit_from_transposed_block(packed_weight_block, 2, ${comp}); + weight_texels[0].w = extract_4bit_from_transposed_block(packed_weight_block, 3, ${comp}); + + weight_texels[1].x = extract_4bit_from_transposed_block(packed_weight_block, 4, ${comp}); + weight_texels[1].y = extract_4bit_from_transposed_block(packed_weight_block, 5, ${comp}); + weight_texels[1].z = extract_4bit_from_transposed_block(packed_weight_block, 6, ${comp}); + weight_texels[1].w = extract_4bit_from_transposed_block(packed_weight_block, 7, ${comp}); + + weight_texels[0] = fma(weight_texels[0], scales[0], zeros[0]); + weight_texels[1] = fma(weight_texels[1], scales[1], zeros[1]); + + out_texels[0] = fma(in_texel_val[${comp}], weight_texels[0], out_texels[0]); + out_texels[1] = fma(in_texel_val[${comp}], weight_texels[1], out_texels[1]); + } } - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - partial_sums[gid][wid][r][0] = local_sums[r][0]; - partial_sums[gid][wid][r][1] = local_sums[r][1]; - } + partial_sums[lid][0] = out_texels[0]; + partial_sums[lid][1] = out_texels[1]; memoryBarrierShared(); barrier(); - if (wid != 0) { - return; + // Tree reduction to compute the overall result. 
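+  // Each step has the first `i` threads add in the partial sums of threads
+  // lid + i, halving the number of active accumulators until partial_sums[0]
+  // holds the full dot product. With WGS = 64 the reduction proceeds
+  // 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1.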
+ for (int i = WGS / 2; i > 0; i /= 2) { + if (lid < i) { + partial_sums[lid][0] = partial_sums[lid][0] + partial_sums[lid + i][0]; + partial_sums[lid][1] = partial_sums[lid][1] + partial_sums[lid + i][1]; + } + memoryBarrierShared(); + barrier(); } - VEC4_T sums[TILE_ROWS][2]; + // Only the first thread will write out result + if (lid == 0) { + out_texels[0] = partial_sums[0][0]; + out_texels[1] = partial_sums[0][1]; - for (int r = 0; r < TILE_ROWS; ++r) { - sums[r][0] = VEC4_T(0); - sums[r][1] = VEC4_T(0); - [[unroll]] for (int worker = 0; worker < NWORKERS; ++ worker) { - sums[r][0] += partial_sums[gid][worker][r][0]; - sums[r][1] += partial_sums[gid][worker][r][1]; + uint n4 = DIV_4(n); + write_output_texel(out_texels[0], n4); + if (n + 4 < output_sizes.x) { + write_output_texel(out_texels[1], n4 + 1); } } - - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - $if OUT_STORAGE == "buffer": - t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = sums[r][0]; - t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = sums[r][1]; - $else: - imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), sums[r][0]); - imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), sums[r][1]); - } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.yaml index 25ffe94f430..04e803a2e94 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.yaml @@ -7,17 +7,13 @@ linear_qga4w_coop: parameter_names_with_default_values: DTYPE: float - OUT_STORAGE: texture3d - IN_STORAGE: texture3d + IO_STORAGE: texture3d WEIGHT_STORAGE: texture2d - PARAMS_STORAGE: buffer - TILE_ROWS: 1 + WGS: 64 shader_variants: - NAME: linear_qga4w_coop_texture3d_texture3d_texture2d_float - NAME: linear_qga4w_coop_buffer_buffer_texture2d_float - OUT_STORAGE: buffer - IN_STORAGE: buffer + IO_STORAGE: buffer - NAME: linear_qga4w_coop_buffer_buffer_buffer_float - OUT_STORAGE: buffer - IN_STORAGE: buffer + IO_STORAGE: buffer WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml index bfeaba2496b..f888e8661d3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml @@ -13,6 +13,7 @@ no_op: - VALUE: half - VALUE: float - VALUE: int32 + - VALUE: uint32 - VALUE: int8 - VALUE: uint8 STORAGE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.glsl new file mode 100644 index 00000000000..e42cf05dd7f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.glsl @@ -0,0 +1,154 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_qmat2", "uint", STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", "uint", "buffer")} + +layout(push_constant) uniform restrict Block { + ivec4 qmat2_sizes; + ivec2 orig_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +$if STORAGE == "buffer": + #define BUFFER_WEIGHT + +#include "qlinear_weight_pack_utils.glslh" + +#define extract_4bit(input_block_data, col, row) \ + (extract_4bit_from_packed_uint_le(input_block_data[row], col)) + +/* + * This shader packs the weight tensor into blocks for efficient consumption. + * + * The input tensor has shape [K/2, N] where each element is a uint8 containing + * 2 packed 4-bit values. The logical tensor shape is [K, N] of 4-bit values. + * + * The transformation partitions the tensor into blocks of size 4x8 (4-bit values) + * and transposes each block to 8x4, then packs the result so that each uvec4 + * contains an entire transposed block. + * + * Original block (4x8 4-bit values, shown as 2x8 uint8 values): + * w00|w10, w20|w30, + * w01|w11, w21|w31, + * w02|w12, w22|w32, + * w03|w13, w23|w33, + * w04|w14, w24|w34, + * w05|w15, w25|w35, + * w06|w16, w26|w36, + * w07|w17, w27|w37, + * + * Transposed block (8x4 4-bit values, packed into uvec4): + * w00|w01, w02|w03, w04|w05, w06|w07 + * w10|w11, w12|w13, w14|w15, w16|w17 + * w20|w21, w22|w23, w24|w25, w26|w27 + * w30|w31, w32|w33, w34|w35, w36|w37 + */ +void main() { + // Each thread writes out 2 adjacent 8 wide x 4 high transposed block. Each + // block is packed as one uvec4. + ivec2 block_pos = ivec2( + MUL_2(gl_GlobalInvocationID.x), + gl_GlobalInvocationID.y); + + // There are K wide x N high 4-bit values in the original weight tensor + const int input_width = orig_sizes.x; // K + const int input_height = orig_sizes.y; // N + + const int input_width_uint = DIV_UP_8(input_width); + + // Original block spans 4 wide x 8 high 4-bit values. Since uint is used to + // read the input tensor, each block spans 0.5 wide x 8 high uint values. + const ivec2 block_start = ivec2( + DIV_2(block_pos.x), + MUL_8(block_pos.y)); + + // Check bounds + if (block_start.x >= input_width_uint || block_start.y >= input_height) { + return; + } + + // Read input block. Note that this block will contain the source data for + // both output blocks, as it contains 1 wide x 8 high uint values, which is + // equivalent to 8 wide x 8 high 4-bit values. + uint input_block_data[8]; + + // Read in 8 rows along the same column of uints, each uint contains 4 4-bit + // values. This will be the source data for the transposed block. 
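+  // Note that each uint read below is 32 bits wide, i.e. 4 packed uint8
+  // bytes or 8 4-bit values, so the 8 uints together cover the full
+  // 8 wide x 8 high block of 4-bit values needed for both output blocks.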
+ for (int i = 0; i < 8; ++i) { + uint input_bufi = (block_start.y + i) * input_width_uint + block_start.x; + input_block_data[i] = t_input[input_bufi]; + } + + for (int col_offset = 0; col_offset <= 4; col_offset+=4) { + uvec4 output_block; + + output_block.x = pack_8x4bit_into_uint( + extract_4bit(input_block_data, col_offset, 0), + extract_4bit(input_block_data, col_offset, 1), + extract_4bit(input_block_data, col_offset, 2), + extract_4bit(input_block_data, col_offset, 3), + extract_4bit(input_block_data, col_offset, 4), + extract_4bit(input_block_data, col_offset, 5), + extract_4bit(input_block_data, col_offset, 6), + extract_4bit(input_block_data, col_offset, 7)); + + output_block.y = pack_8x4bit_into_uint( + extract_4bit(input_block_data, col_offset + 1, 0), + extract_4bit(input_block_data, col_offset + 1, 1), + extract_4bit(input_block_data, col_offset + 1, 2), + extract_4bit(input_block_data, col_offset + 1, 3), + extract_4bit(input_block_data, col_offset + 1, 4), + extract_4bit(input_block_data, col_offset + 1, 5), + extract_4bit(input_block_data, col_offset + 1, 6), + extract_4bit(input_block_data, col_offset + 1, 7)); + + output_block.z = pack_8x4bit_into_uint( + extract_4bit(input_block_data, col_offset + 2, 0), + extract_4bit(input_block_data, col_offset + 2, 1), + extract_4bit(input_block_data, col_offset + 2, 2), + extract_4bit(input_block_data, col_offset + 2, 3), + extract_4bit(input_block_data, col_offset + 2, 4), + extract_4bit(input_block_data, col_offset + 2, 5), + extract_4bit(input_block_data, col_offset + 2, 6), + extract_4bit(input_block_data, col_offset + 2, 7)); + + output_block.w = pack_8x4bit_into_uint( + extract_4bit(input_block_data, col_offset + 3, 0), + extract_4bit(input_block_data, col_offset + 3, 1), + extract_4bit(input_block_data, col_offset + 3, 2), + extract_4bit(input_block_data, col_offset + 3, 3), + extract_4bit(input_block_data, col_offset + 3, 4), + extract_4bit(input_block_data, col_offset + 3, 5), + extract_4bit(input_block_data, col_offset + 3, 6), + extract_4bit(input_block_data, col_offset + 3, 7)); + + const uint qmat2_texel_stride_x = DIV_UP_4(qmat2_sizes.x); + write_transposed_weight_block( + output_block, + block_pos.x, + block_pos.y, + qmat2_texel_stride_x); + + if (MUL_8(block_start.x) + 4 >= input_width) { + return; + } + // Otherwise, implement the block position to write to the next block in the + // following iteration. + block_pos.x += 1; + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.yaml new file mode 100644 index 00000000000..c72a2cc1df6 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.yaml @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +pack_int4_linear_weight_transposed_block_4x8: + parameter_names_with_default_values: + STORAGE: buffer + shader_variants: + - NAME: pack_int4_linear_weight_transposed_block_4x8_buffer + STORAGE: buffer + - NAME: pack_int4_linear_weight_transposed_block_4x8_texture2d + STORAGE: texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/qlinear_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/qlinear_utils.glslh new file mode 100644 index 00000000000..987ae06773f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/qlinear_utils.glslh @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef QLINEAR_UTILS_H +#define QLINEAR_UTILS_H + +/*********************************** + * Packed Weight data read/write functions + * + * These functions assume that t_qmat2 is declared in the shader layout as a storage + * buffer or storage image. + */ + +#ifdef BUFFER_WEIGHT + +uvec4 load_transposed_weight_block(const uint k4, const uint n8, const uint K4) { + return t_qmat2[n8 * K4 + k4]; +} + +#else // TEXTURE_WEIGHT + +uvec4 load_transposed_weight_block(const uint k4, const uint n8, const uint K4) { + return texelFetch(t_qmat2, ivec2(k4, n8), 0); +} + +#endif // BUFFER_WEIGHT + +/*********************************** + * Packed weight data extraction functions + */ + +float extract_4bit_from_transposed_block(const uvec4 block, const uint col, const uint row) { + return float(int((block[row] >> (4 * (7 - col))) & 15) - 8); +} + +/*********************************** + * Input/Output read/write functions + * + * These functions assume that t_input and t_output are declared in the shader layout as + * storage buffers or storage images. + */ + +#ifdef BUFFER_IO + +VEC4_T load_input_texel(const uint k4) { + return t_input[k4]; +} + +void write_output_texel(const VEC4_T out_texel, const uint n4) { + t_output[n4] = out_texel; +} + +#else // TEXTURE_IO + +VEC4_T load_input_texel(const uint k4) { + return texelFetch(t_input, ivec3(k4, 0, 0), 0); +} + +void write_output_texel(const VEC4_T out_texel, const uint n4) { + imageStore(t_output, ivec3(n4, 0, 0), out_texel); +} + +#endif // BUFFER_IO + +#endif // QLINEAR_UTILS_H diff --git a/backends/vulkan/runtime/graph/ops/glsl/qlinear_weight_pack_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/qlinear_weight_pack_utils.glslh new file mode 100644 index 00000000000..1f481f4f859 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/qlinear_weight_pack_utils.glslh @@ -0,0 +1,58 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef QLINEAR_WEIGHT_PACK_UTILS_H +#define QLINEAR_WEIGHT_PACK_UTILS_H + +/*********************************** + * Packed Weight data write functions + * + * These functions assume that t_qmat2 has been defined in the shader layout as either + * a storage buffer or a storage image. 
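+ *
+ * Indexing convention used below: k4 counts groups of 4 elements along the
+ * K dimension, n8 counts groups of 8 output channels along N, and K4 is the
+ * number of 4-element groups along K, i.e. the uvec4 row stride of the
+ * packed weight tensor when it is stored in a buffer.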
+ */ + +#ifdef BUFFER_WEIGHT + +void write_transposed_weight_block(const uvec4 block, const uint k4, const uint n8, const uint K4) { + t_qmat2[n8 * K4 + k4] = block; +} + +#else // TEXTURE_WEIGHT + +void write_transposed_weight_block(const uvec4 block, const uint k4, const uint n8, const uint K4) { + imageStore(t_qmat2, ivec2(k4, n8), block); +} + +#endif // BUFFER_WEIGHT + +/*********************************** + * Utilities for packing weight data + */ + +uint extract_4bit_from_packed_uint_le(const uint packed, const uint i) { + // account for little endian + uint byte = packed >> (8 * (i / 2)) & 255; + return (byte >> (4 - 4 * (i % 2))) & 15; +} + +uint pack_8x4bit_into_uint( + const uint val0, + const uint val1, + const uint val2, + const uint val3, + const uint val4, + const uint val5, + const uint val6, + const uint val7) { + return uint( + (val0 << 28) | (val1 << 24) | (val2 << 20) | (val3 << 16) | (val4 << 12) | + (val5 << 8) | (val6 << 4) | val7 + ); +} + +#endif // QLINEAR_WEIGHT_PACK_UTILS_H diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp index d9425b8b62f..5e6bb35b029 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp @@ -50,25 +50,28 @@ void resize_linear_qga4w_node( const std::vector& extra_args) { (void)extra_args; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); - vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); + ValueRef out = args.at(0).refs.at(0); + ValueRef mat1 = args.at(1).refs.at(0); + ValueRef mat2_data = extra_args.at(0); - const int out_cols = utils::val_at(-2, mat1->sizes()); - const int out_rows = utils::val_at(-1, mat2->sizes()) * 2; + std::vector mat1_sizes = graph->sizes_of(mat1); + std::vector mat2_sizes = graph->sizes_of(mat2_data); + + const int64_t out_cols = utils::val_at(-2, mat1_sizes); + const int64_t out_rows = utils::val_at(-2, mat2_sizes); std::vector new_out_sizes(3); - if (mat1->sizes().size() == 2) { + if (mat1_sizes.size() == 2) { new_out_sizes.resize(2); new_out_sizes.at(0) = out_cols; new_out_sizes.at(1) = out_rows; } else { - new_out_sizes.at(0) = mat1->sizes().at(0); + new_out_sizes.at(0) = mat1_sizes.at(0); new_out_sizes.at(1) = out_cols; new_out_sizes.at(2) = out_rows; } - out->virtual_resize(new_out_sizes); + graph->virtual_resize(out, new_out_sizes); } /** @@ -117,14 +120,18 @@ utils::uvec3 linear_qga4w_global_wg_size( const bool use_coop_algorithm = shader.kernel_name.find("_coop") != std::string::npos; - utils::uvec3 global_wg_size = graph->logical_limits_of(out); - global_wg_size[0] = utils::div_up(global_wg_size[0], uint32_t(2)); - if (!use_coop_algorithm) { + utils::uvec3 global_wg_size = graph->logical_limits_of(out); + global_wg_size[0] = utils::div_up(global_wg_size[0], uint32_t(2)); + global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(3)); + return global_wg_size; } - return global_wg_size; + uint32_t output_channels = graph->size_at(-1, out); + uint32_t batch_size = graph->size_at(-2, out); + + return {64, utils::div_up(output_channels, 8u), batch_size}; } utils::uvec3 linear_qga4w_local_wg_size( @@ -139,7 +146,7 @@ utils::uvec3 linear_qga4w_local_wg_size( shader.kernel_name.find("_coop") != std::string::npos; if (use_coop_algorithm) { - return {8, 1, 8}; + return {64, 1, 1}; } else { return graph->create_local_wg_size(global_workgroup_size); } @@ -155,13 +162,19 
@@ void add_linear_qga4w_node( check_linear_qga4w_args( graph, mat1, mat2_data, group_size, scales_and_zeros_data, out); + bool is_gemv = should_use_coop_algorithm(&graph, mat1); const uint32_t group_size_val = graph.extract_scalar(group_size); - ValueRef mat2 = - prepack_int4_linear_weight_transposed_interleaved(graph, mat2_data); + ValueRef mat2 = is_gemv + ? prepack_int4_linear_weight_transposed_block_4x8(graph, mat2_data) + : prepack_int4_linear_weight_transposed_interleaved(graph, mat2_data); + + ValueRef scales_and_zeros = is_gemv + ? prepack_standard( + graph, scales_and_zeros_data, utils::kBuffer, utils::kWidthPacked) + : prepack_standard_hw_transposed( + graph, scales_and_zeros_data, utils::kBuffer, utils::kWidthPacked); - ValueRef scales_and_zeros = prepack_standard_hw_transposed( - graph, scales_and_zeros_data, utils::kBuffer, utils::kWidthPacked); graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, pick_linear_qga4w_shader, @@ -178,7 +191,7 @@ void add_linear_qga4w_node( // Specialization Constants {SV(group_size_val)}, // Resize Args - {}, + {mat2_data}, // Resizing Logic resize_linear_qga4w_node)); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index f429ab0fc25..bfaad716059 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -322,6 +322,75 @@ ValueRef prepack_int4_linear_weight_transposed_interleaved( return qmat2; } +ValueRef prepack_int4_linear_weight_transposed_block_4x8( + ComputeGraph& graph, + const ValueRef qmat2_data) { + std::vector qmat2_orig_sizes = graph.sizes_of(qmat2_data); + const int64_t ndim = graph.dim_of(qmat2_data); + + const int64_t K_div2 = qmat2_orig_sizes.at(ndim - 1); // Input is [N, K/2] + const int64_t N = qmat2_orig_sizes.at(ndim - 2); + // Logical K dimension. Each value in the tensor is a uint8 that contains 2 + // packed 4-bit values. + const int64_t K = K_div2 * 2; + + // This packing format partitions the weight tensor into 4 wide x 8 high + // blocks. To figure out the size of the output tensor, determine the number + // of blocks along the width and height dims. + const int64_t num_blocks_K = utils::div_up(K, int64_t(4)); + const int64_t num_blocks_N = utils::div_up(N, int64_t(8)); + // Each transposed block is 8 wide x 4 high. In terms of 8-bit values, the + // block is 4 wide x 4 high. To maximize memory loading efficiency, the packed + // weight tensor will use a base data type of uint32_t; in terms of uint32_t, + // each block is 1 wide x 4 high. However, each block is also flattened as it + // is stored, so that the whole block can be loaded at once. As a result, the + // stored block will be 4 wide x 1 high. 
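+  // As a concrete (hypothetical) example, for K = 128 and N = 512:
+  //   num_blocks_K = 128 / 4 = 32 and num_blocks_N = 512 / 8 = 64, giving a
+  //   packed tensor that is 32 * 4 = 128 uints (32 uvec4 texels) wide and
+  //   64 rows high, i.e. one uvec4 texel per 4x8 source block.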
+ const int64_t output_width = num_blocks_K * 4; + const int64_t output_height = num_blocks_N; + + // Store the original sizes of the tensor to pass to the shader + utils::ivec2 orig_sizes{ + utils::safe_downcast(K), utils::safe_downcast(N)}; + + std::vector qmat2_sizes{output_height, output_width}; + + utils::StorageType storage_type = utils::kTexture2D; + uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); + if (output_width > max_extent * 4 || output_height > max_extent) { + storage_type = utils::kBuffer; + } + + ValueRef qmat2 = graph.add_tensor( + qmat2_sizes, vkcompute::vkapi::kUInt, storage_type, utils::kWidthPacked); + + // Global workgroup size: each thread writes out two adjacent blocks + utils::uvec3 global_wg_size{ + utils::div_up(utils::safe_downcast(num_blocks_K), uint32_t(2)), + utils::safe_downcast(num_blocks_N), + 1u}; + + std::string kernel_name = "pack_int4_linear_weight_transposed_block_4x8"; + add_storage_type_suffix(kernel_name, storage_type); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + graph.create_local_wg_size(global_wg_size), + // Inputs and Outputs + qmat2_data, + qmat2, + // UBOs + {}, + // Specialization Constants + {}, + // Push Constants + {graph.sizes_pc_of(qmat2), + PushConstantDataInfo(&orig_sizes, sizeof(utils::ivec2))})); + + return qmat2; +} + void prepack_op(ComputeGraph& graph, const std::vector& args) { return add_prepack_standard_node(graph, args[0], args[1]); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.h b/backends/vulkan/runtime/graph/ops/impl/Staging.h index 090a3718295..0b1568ca139 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.h +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.h @@ -89,9 +89,14 @@ ValueRef prepack_direct_copy_buffer( // // Op specific prepack functions +// ValueRef prepack_int4_linear_weight_transposed_interleaved( ComputeGraph& graph, const ValueRef qmat2_data); +ValueRef prepack_int4_linear_weight_transposed_block_4x8( + ComputeGraph& graph, + const ValueRef qmat2_data); + } // namespace vkcompute