From 596ee331992d94bde83a35acd0a9c2bed6e93c0a Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Mon, 14 Oct 2024 14:08:46 -0700 Subject: [PATCH 1/2] [ET-VK] Fix implementation of int4 quantized linear ## Context Fix the existing implementation of int4 weight quantized linear to conform with how the `_weight_int4packed_mm` op works in the ATen library. For some additional context, the current op implementation does not actually match the behaviour of `_weight_int4packed_mm`. The ATen op expects that the weights have already been packed into a specific format, with `inner_k_tiles` as a packing parameter. The packing is accomplished via calling the `_convert_weight_to_int4pack` operator. Thus the current implementation in vulkan is equivalent to calling `_convert_weight_to_int4pack` + `_weight_int4packed_mm`. To address this discrepancy, the operator implementation is registered under the `linear_weight_int4` custom op as of this diff. The problems with the existing implementation were as follows: * The expected sizes of the scales and zeros tensor was incorrect. Previously, the sizes were assumed to be `(2, N, num_groups)` but the correct size is `(num_groups, N, 2)` * Previously, when unpacking a uint8_t into 2 unpacked int4 values, it was assumed that the LSB was the first value and the MSB was the second value. However, this ordering should be flipped * The original implementation expected the output tensor to be channels packed, but in practice we want the output tensor to be width packed This diff addresses the above issues, and introduces a dedicated test binary to test against an equivalent reference implementation expressed with ATen functions. Differential Revision: [D64354773](https://our.internmc.facebook.com/intern/diff/D64354773/) [ghstack-poisoned] --- .../vulkan/runtime/api/containers/Tensor.h | 10 + backends/vulkan/runtime/gen_vulkan_spv.py | 33 +-- backends/vulkan/runtime/graph/ComputeGraph.h | 4 + .../graph/ops/glsl/buffer_to_buffer.yaml | 1 + .../vulkan/runtime/graph/ops/glsl/no_op.yaml | 1 + .../runtime/graph/ops/glsl/q_4w_linear.glsl | 179 +++++++------- .../runtime/graph/ops/glsl/q_4w_linear.yaml | 7 +- .../graph/ops/impl/QuantizedLinear.cpp | 141 ++++++++++- .../graph/ops/impl/QuantizedMatMul.cpp | 184 --------------- .../vulkan/runtime/graph/ops/impl/Staging.cpp | 5 +- .../test/op_tests/linear_weight_int4_test.cpp | 218 ++++++++++++++++++ backends/vulkan/test/op_tests/targets.bzl | 38 +++ .../vulkan/test/vulkan_compute_api_test.cpp | 136 ----------- 13 files changed, 512 insertions(+), 445 deletions(-) delete mode 100644 backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp create mode 100644 backends/vulkan/test/op_tests/linear_weight_int4_test.cpp diff --git a/backends/vulkan/runtime/api/containers/Tensor.h b/backends/vulkan/runtime/api/containers/Tensor.h index 7a113c939f2..3873aeaace7 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.h +++ b/backends/vulkan/runtime/api/containers/Tensor.h @@ -430,6 +430,16 @@ class vTensor final { return axis_map_; } + /* + * Return true if the tensor's axis map is {0, 1, 2, concat_dim}. This means + * that the width dim is mapped to the width axis of the texture, the height + * dim is mapped to the height axis of the texture, the channels dim is mapped + * to the depth axis of the texture. + */ + inline bool is_standard_axis_map() const { + return axis_map_.at(0) == 0 && axis_map_.at(1) == 1 && axis_map_.at(2) == 2; + } + inline const std::vector& strides() const { return strides_; } diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 6ee29d45f18..c133094dbfb 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -319,21 +319,26 @@ def define_active_storage_type(storage_type: str): raise AssertionError(f"Invalid storage type: {storage_type}") -def define_required_extensions(dtype: str): +def define_required_extensions(dtypes: Union[str, List[str]]): out_str = "\n" - nbit = None - glsl_type = None - - if dtype == "half": - nbit = "16bit" - glsl_type = "float16" - if dtype == "int8": - nbit = "8bit" - glsl_type = "int8" - - if nbit is not None and glsl_type is not None: - out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n" - out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n" + dtype_list = dtypes if isinstance(dtypes, list) else [dtypes] + + for dtype in dtype_list: + nbit = None + glsl_type = None + if dtype == "half": + nbit = "16bit" + glsl_type = "float16" + elif dtype == "int16" or dtype == "uint16": + nbit = "16bit" + glsl_type = "int16" + elif dtype == "int8" or dtype == "uint8": + nbit = "8bit" + glsl_type = "int8" + + if nbit is not None and glsl_type is not None: + out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n" + out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n" return out_str diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index a4bb714e38e..1c44b9f2f33 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -342,6 +342,10 @@ class ComputeGraph final { return values_.at(idx).toTensor().axis_map_ubo(); } + inline bool is_standard_axis_map(const ValueRef idx) { + return values_.at(idx).toTensor().is_standard_axis_map(); + } + inline vkapi::BufferBindInfo logical_limits_ubo(const ValueRef idx) { return values_.at(idx).toTensor().logical_limits_ubo(); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml index 8ea4cbe561e..9abd9c1deac 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml @@ -14,5 +14,6 @@ buffer_to_buffer: - VALUE: float - VALUE: int - VALUE: int8 + - VALUE: uint8 shader_variants: - NAME: buffer_to_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml index 825da11b24e..e64e1bd260a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml @@ -14,6 +14,7 @@ no_op: - VALUE: float - VALUE: int - VALUE: int8 + - VALUE: uint8 STORAGE: - VALUE: texture3d - VALUE: texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl index de42f9ed996..38a8bfbdec6 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl @@ -19,117 +19,94 @@ ${define_active_storage_type(STORAGE)} -${define_required_extensions(DTYPE)} -${define_required_extensions("int8")} +${define_required_extensions([DTYPE, "uint8", "uint16"])} +#extension GL_EXT_control_flow_attributes : require layout(std430) buffer; -${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)} -${layout_declare_tensor(1, "r", "t_mat1", DTYPE, STORAGE)} -${layout_declare_tensor(2, "r", "t_mat2", "int8", "buffer")} -${layout_declare_tensor(3, "r", "t_scales_and_zeros", DTYPE, STORAGE)} - -$if STORAGE == "texture3d": - ${layout_declare_ubo(4, "ivec4", "out_sizes")} - ${layout_declare_ubo(5, "ivec4", "mat1_sizes")} - ${layout_declare_ubo(6, "ivec4", "mat2_strides")} - ${layout_declare_ubo(7, "ivec4", "scales_strides")} -$else: - ${layout_declare_ubo(4, "ivec4", "out_sizes")} - ${layout_declare_ubo(5, "ivec4", "out_strides")} - ${layout_declare_ubo(6, "ivec4", "mat1_sizes")} - ${layout_declare_ubo(7, "ivec4", "mat1_strides")} - ${layout_declare_ubo(8, "ivec4", "mat2_strides")} - ${layout_declare_ubo(9, "ivec4", "scales_strides")} +${layout_declare_tensor(B, "w", "ret", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "x", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "weights", "uint8", "buffer")} +${layout_declare_tensor(B, "r", "qparams", DTYPE, STORAGE)} +${layout_declare_ubo(B, "ivec3", "ret_limits")} +${layout_declare_ubo(B, "ivec4", "x_sizes")} +${layout_declare_ubo(B, "ivec4", "weights_strides")} +${layout_declare_ubo(B, "ivec4", "qparams_strides")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; layout(constant_id = 3) const int group_size = 1; +/* + * This shader computes a linear operator between a floating point input matrix + * x and a weights matrix that is quantized to 4 bits. + * + * The (W, H, C) shape of each tensor is: + * - x: (K, M) + * - weights: (K / 2, N) + * - The weights tensor has a data type of `uint8`. Each element in the tensor + * contains 2 4-bit values packed into a uint8. + * - qparams: (2, N, number_of_groups) + * - This tensor contains the scales and zeros quantization parameters for the + * weights tensor. The weight tensor is quantized group-wise, which means + * that every `group_size` elements along the K dimension of the weights + * tensor has independent quantization parameters. Along the width dim, the + * first value contains the scale for the group and the second value + * contains the zero point for the group. + * + * Note that this shader assumes that all tensors are width packed. + */ void main() { - - const ivec4 out_pos = ivec4( - gl_GlobalInvocationID.x, // n = 0..N-1 - gl_GlobalInvocationID.y, // m = 0..M-1 - gl_GlobalInvocationID.z % out_sizes.z, - gl_GlobalInvocationID.z / out_sizes.z); - - if (any(greaterThanEqual(out_pos, out_sizes))) { - return; + // output positions being calculated are (n, m), (n + 1, m), ... + // This means multiplying the m-th row of x with the n-th, (n+1)-th, ... rows + // of the weights tensor. + const u16vec3 ret_pos = u16vec3(gl_GlobalInvocationID); + if (any(greaterThanEqual(ret_pos, ret_limits))) { + return; + } + + // Since ret is width packed, need to multiply by 4 + const uint16_t n = uint16_t(ret_pos.x * 4); + + // K is guaranteed to be a multiple of group size + const uint16_t num_blocks = uint16_t(x_sizes.x / group_size); + + uint16_t k_texel_i = uint16_t(0); + vec4 sums = vec4(0.0); + for (uint16_t block_idx = uint16_t(0); block_idx < num_blocks; block_idx++) { + vec4 scales; + vec4 zeros; + + [[unroll]] for (int comp = 0; comp < 4; ++comp) { + const vec4 scale_and_zero = load_texel( + qparams, u16vec3(0, n + comp, block_idx)); + scales[comp] = scale_and_zero.x; + zeros[comp] = scale_and_zero.y; } - const uint K = mat1_sizes.x; - const uint n = out_pos.x; - const uint m = out_pos.y; - const uint mask = uint(0x0f); - - float rc = 0.0; - int k = 0; - const uint k_block = (K + group_size - 1) / group_size; - - #ifdef USING_BUFFER - ivec4 mat1_pos = ivec4(0, m, out_pos.z, out_pos.w); - ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w); - ivec4 scale_pos = ivec4(0, n, 0, out_pos.w); - ivec4 zero_pos = ivec4(0, n, 1, out_pos.w); - - for (int kb = 0; kb < k_block; kb++) { - scale_pos.x = kb; - const int scale_bufi = tidx_to_bufi(scale_pos, scales_strides); - const float scale = float(t_scales_and_zeros[scale_bufi]); - - zero_pos.x = kb; - const int zero_bufi = tidx_to_bufi(zero_pos, scales_strides); - const float zero = float(t_scales_and_zeros[zero_bufi]) - scale * 8.0; - - for(uint idx = 0; idx < group_size && k < K; idx++, k++) { - mat1_pos.x = k; - const int mat1_bufi = tidx_to_bufi(mat1_pos, mat1_strides); - const float mat1_val = float(t_mat1[mat1_bufi]); - - mat2_pos.x = k / 2; - const int mat2_bufi = tidx_to_bufi(mat2_pos, mat2_strides); - // Bitwise op treats sign bit from int8 as a value bit instead, - // since there is no uint8_t datatype - uint mat2_val = (t_mat2[mat2_bufi] & 0xFF); - mat2_val = (k & 1) == 0 ? mat2_val & mask : (mat2_val >> 4); - - rc += mat1_val * (scale * float(mat2_val) + zero); - } - } - - const int out_bufi = tidx_to_bufi(out_pos, out_strides); - t_out[out_bufi] = FLOAT_T(rc); - - #else // Using texture - ivec3 mat1_pos = ivec3(0, m, out_pos.z); - ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w); - ivec3 scale_zero_pos = ivec3(0, n, 0); - uint K_texel = K / FOUR; - - for (int kb = 0; kb < k_block; kb++) { - scale_zero_pos.x = kb; - const vec4 scale_zero = load_texel(t_scales_and_zeros, scale_zero_pos); - const float scale = scale_zero.x; - const float zero = scale_zero.y - scale * 8.0; - - for(uint idx = 0; idx < group_size && k < K_texel; idx += FOUR, k++) { - mat1_pos.x = k; - const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos); - - mat2_pos.x = k * 2; // k * FOUR / 2 - const int mat2_id = tidx_to_bufi(mat2_pos, mat2_strides); - - for (int texel_pos = 0; texel_pos < FOUR; texel_pos++) { - // Bitwise op treats sign bit from int8 as a value bit instead, - // since there is no uint8_t datatype - uint mat2_val = (t_mat2[mat2_id + texel_pos / 2] & 0xFF); - mat2_val = (texel_pos & 1) == 0 ? mat2_val & mask : (mat2_val >> 4); - rc += mat1_tex[texel_pos] * (scale * float(mat2_val) + zero); - } - } + for (uint16_t i = uint16_t(0); i < group_size; i += uint16_t(4), k_texel_i++) { + const VEC4_T x_texel = load_texel( + x, u16vec3(k_texel_i, ret_pos.y, ret_pos.z)); + + [[unroll]] for (int comp = 0; comp < 4; ++comp) { + const int weights_bufi = (n + comp) * weights_strides.y + (k_texel_i * 2); + // Need to read 4 unpacked values, which corresponds to 2 packed values + const uint8_t weights_val_1 = weights[weights_bufi]; + const uint8_t weights_val_2 = weights[weights_bufi + 1]; + + const u8vec4 weights_texel = u8vec4( + (weights_val_1 & 0xF0) >> 4, + weights_val_1 & 0x0F, + (weights_val_2 & 0xF0) >> 4, + weights_val_2 & 0x0F); + + // Note that the unpacked 4-bit values are unsigned, therefore they must + // first be "centered" around 0 by subtracting 8 before applying the + // scale and zero point. + sums[comp] += dot( + x_texel, (vec4(weights_texel) - 8.0) * scales[comp] + zeros[comp]); } - write_texel(t_out, out_pos.xyz, vec4(rc, 0, 0, 0)); - - #endif + } + } + write_texel(ret, ret_pos, sums); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml index fd65068080a..40d95d4a05f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml @@ -7,13 +7,10 @@ q_4w_linear: parameter_names_with_default_values: DTYPE: float - STORAGE: buffer + STORAGE: texture3d generate_variant_forall: DTYPE: - VALUE: float - VALUE: half - STORAGE: - - VALUE: buffer - - VALUE: texture3d shader_variants: - - NAME: q_4w_linear + - NAME: q_4w_linear_texture3d diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index 206a4eafa36..ba58bc1ef60 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -16,7 +16,7 @@ namespace vkcompute { -void check_qlinear_args( +void check_q_8w_linear_args( const ComputeGraph& graph, const ValueRef mat1, const ValueRef qmat2_data, @@ -38,7 +38,7 @@ void check_qlinear_args( utils::val_at(-1, scales_sizes) == utils::val_at(-2, qmat2_sizes)); } -void resize_qlinear_node( +void resize_q_8w_linear_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { @@ -110,18 +110,151 @@ void add_q_8w_linear_node( // Specialization Constants {}, // Resizing Logic - resize_qlinear_node)); + resize_q_8w_linear_node)); } void weight_int8pack_mm( ComputeGraph& graph, const std::vector& args) { - check_qlinear_args(graph, args[0], args[1], args[2], args[3]); + check_q_8w_linear_args(graph, args[0], args[1], args[2], args[3]); return add_q_8w_linear_node(graph, args[0], args[1], args[2], args[3]); } +void check_q_4w_linear_args( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat2_data, + const ValueRef group_size, + const ValueRef scales_and_zeros, + const ValueRef out) { + VK_CHECK_COND(graph.val_is_tensor(mat1)); + VK_CHECK_COND(graph.val_is_tref(mat2_data)); + VK_CHECK_COND(graph.val_is_tref(scales_and_zeros)); + + VK_CHECK_COND(graph.dim_of(mat1) <= 3); + VK_CHECK_COND(graph.dim_of(mat2_data) == 2); + VK_CHECK_COND(graph.dim_of(scales_and_zeros) == 3); + + VK_CHECK_COND(graph.size_at(-3, mat1) == 1); + const int K = graph.size_at(-1, mat1); + VK_CHECK_COND(graph.size_at(-1, mat2_data) * 2 == K); + + const int group_size_val = graph.extract_scalar(group_size); + VK_CHECK_COND(K % group_size_val == 0); + + VK_CHECK_COND(graph.packed_dim_of(mat1) == WHCN::kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim); + + VK_CHECK_COND(graph.is_standard_axis_map(mat1)); + VK_CHECK_COND(graph.is_standard_axis_map(out)); +} + +void resize_q_4w_linear_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); + vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); + + const int out_cols = utils::val_at(-2, mat1->sizes()); + const int out_rows = utils::val_at(-2, mat2->sizes()); + + std::vector new_out_sizes(3); + if (mat1->sizes().size() == 2) { + new_out_sizes.resize(2); + new_out_sizes.at(0) = out_cols; + new_out_sizes.at(1) = out_rows; + } else { + new_out_sizes.at(0) = mat1->sizes().at(0); + new_out_sizes.at(1) = out_cols; + new_out_sizes.at(2) = out_rows; + } + + out->virtual_resize(new_out_sizes); +} + +void add_q_4w_linear_node( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat2_data, + const ValueRef group_size, + const ValueRef scales_and_zeros_data, + const ValueRef out) { + check_q_4w_linear_args( + graph, mat1, mat2_data, group_size, scales_and_zeros_data, out); + + utils::StorageType storage_type = graph.storage_type_of(out); + + ValueRef mat2 = + prepack_buffer_if_tensor_ref(graph, mat2_data, utils::kWidthPacked); + + ValueRef scales_and_zeros = + prepack_if_tensor_ref(graph, scales_and_zeros_data, utils::kWidthPacked); + + std::string kernel_name = "q_4w_linear"; + add_storage_type_suffix(kernel_name, storage_type); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + const uint32_t group_size_val = graph.extract_scalar(group_size); + + vkapi::ParamsBindList ubos({}); + if (storage_type == utils::kBuffer) { + ubos.append(graph.sizes_ubo(out)); + ubos.append(graph.strides_ubo(out)); + ubos.append(graph.sizes_ubo(mat1)); + ubos.append(graph.strides_ubo(mat1)); + ubos.append(graph.strides_ubo(mat2)); + ubos.append(graph.strides_ubo(scales_and_zeros)); + } else { + ubos.append(graph.logical_limits_ubo(out)); + ubos.append(graph.sizes_ubo(mat1)); + ubos.append(graph.strides_ubo(mat2)); + ubos.append(graph.strides_ubo(scales_and_zeros)); + } + + auto out_sizes = graph.sizes_of(out); + uint32_t N = utils::val_at(-1, out_sizes); + uint32_t M = utils::val_at(-2, out_sizes); + + utils::uvec3 global_wg_size = {N, M, 1}; + utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); + + graph.execute_nodes().emplace_back(new DispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + local_wg_size, + // Inputs and Outputs + {{out, vkapi::MemoryAccessType::WRITE}, + {{mat1, mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}}, + // Shader params buffers + ubos, + // Specialization Constants + {SV(group_size_val)}, + // Resizing Logic + resize_q_4w_linear_node, + {})); +} + +void linear_weight_int4( + ComputeGraph& graph, + const std::vector& args) { + return add_q_4w_linear_node( + graph, + args[0], // mat1 + args[1], // mat2 + args[2], // group_size + args[3], // scales_and_zeros + args[4] // out + ); +} + REGISTER_OPERATORS { VK_REGISTER_OP(aten._weight_int8pack_mm.default, weight_int8pack_mm); + VK_REGISTER_OP(et_vk.linear_weight_int4.default, linear_weight_int4); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp deleted file mode 100644 index 17291d292a3..00000000000 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp +++ /dev/null @@ -1,184 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include -#include - -namespace vkcompute { - -void check_q_matmul_args( - ComputeGraph& graph, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef group_size_data, - const ValueRef scales_and_zeros, - const ValueRef out) { - const std::vector mat1_sizes = graph.sizes_of(mat1); - const std::vector mat2_sizes = graph.sizes_of(mat2_data); - const std::vector scales_and_zeros_sizes = - graph.sizes_of(scales_and_zeros); - - const uint32_t group_size = graph.extract_scalar(group_size_data); - - VK_CHECK_COND(mat1_sizes.size() == 2); - VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size()); - - using namespace WHCN; - VK_CHECK_COND(graph.packed_dim_of(mat1) == kWidthDim); - VK_CHECK_COND(graph.packed_dim_of(mat2_data) == kWidthDim); - // VK_CHECK_COND(graph.packed_dim_of(scales_and_zeros) == kWidthDim); - - if (graph.storage_type_of(scales_and_zeros) == utils::kBuffer) { - VK_CHECK_COND(graph.packed_dim_of(scales_and_zeros) == kWidthDim); - } else { - VK_CHECK_COND(graph.packed_dim_of(scales_and_zeros) == kChannelsDim); - } - - if (graph.storage_type_of(out) == utils::kBuffer) { - VK_CHECK_COND(graph.packed_dim_of(out) == kWidthDim); - } else { - VK_CHECK_COND(graph.packed_dim_of(out) == kChannelsDim); - } - - const int mat1_K = utils::val_at(-1, mat1_sizes); - const int mat2_K = utils::val_at(-1, mat2_sizes) * 2; - const int N = utils::val_at(-2, mat2_sizes); - - VK_CHECK_COND(mat1_K == mat2_K); - - VK_CHECK_COND(mat2_K % group_size == 0); - - const uint32_t k_groups = mat2_K / group_size; - - VK_CHECK_COND(scales_and_zeros_sizes.size() == 3); - VK_CHECK_COND(utils::val_at(-1, scales_and_zeros_sizes) == k_groups); - VK_CHECK_COND(utils::val_at(-2, scales_and_zeros_sizes) == N); - VK_CHECK_COND(utils::val_at(-3, scales_and_zeros_sizes) == 2); - - // Match https://fburl.com/code/6ostkknm - std::vector valid_group_sizes = {32, 64, 128, 256}; - - bool is_valid_group_size = false; - for (auto valid_group_size : valid_group_sizes) { - if (group_size == valid_group_size) { - is_valid_group_size = true; - break; - } - } - - VK_CHECK_COND(is_valid_group_size); -} - -void resize_q_matmul_node( - ComputeGraph* graph, - const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; - - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); - vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); - - const int out_cols = utils::val_at(-2, mat1->sizes()); - const int out_rows = utils::val_at(-2, mat2->sizes()); - - std::vector new_out_sizes(3); - if (mat1->sizes().size() == 2) { - new_out_sizes.resize(2); - new_out_sizes.at(0) = out_cols; - new_out_sizes.at(1) = out_rows; - } else { - new_out_sizes.at(0) = mat1->sizes().at(0); - new_out_sizes.at(1) = out_cols; - new_out_sizes.at(2) = out_rows; - } - - out->virtual_resize(new_out_sizes); -} - -void add_q_matmul_node( - ComputeGraph& graph, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef group_size, - const ValueRef scales_and_zeros_data, - const ValueRef out) { - auto storage_type = graph.storage_type_of(out); - - ValueRef mat2 = - prepack_buffer_if_tensor_ref(graph, mat2_data, utils::kWidthPacked); - - ValueRef scales_and_zeros = - prepack_if_tensor_ref(graph, scales_and_zeros_data, utils::kWidthPacked); - - std::string kernel_name = "q_4w_linear"; - - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - add_storage_type_suffix(kernel_name, storage_type); - - const uint32_t group_size_val = graph.extract_scalar(group_size); - - vkapi::ParamsBindList ubos({}); - if (storage_type == utils::kBuffer) { - ubos.append(graph.sizes_ubo(out)); - ubos.append(graph.strides_ubo(out)); - ubos.append(graph.sizes_ubo(mat1)); - ubos.append(graph.strides_ubo(mat1)); - ubos.append(graph.strides_ubo(mat2)); - ubos.append(graph.strides_ubo(scales_and_zeros)); - } else { - ubos.append(graph.sizes_ubo(out)); - ubos.append(graph.sizes_ubo(mat1)); - ubos.append(graph.strides_ubo(mat2)); - ubos.append(graph.strides_ubo(scales_and_zeros)); - } - - auto out_sizes = graph.sizes_of(out); - uint32_t N = utils::val_at(-1, out_sizes); - uint32_t M = utils::val_at(-2, out_sizes); - - utils::uvec3 global_wg_size = {N, M, 1}; - - utils::uvec3 local_wg_size = adaptive_work_group_size(global_wg_size); - - graph.execute_nodes().emplace_back(new DispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - local_wg_size, - // Inputs and Outputs - {{out, vkapi::MemoryAccessType::WRITE}, - {{mat1, mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}}, - // Shader params buffers - ubos, - // Specialization Constants - {SV(group_size_val)}, - // Resizing Logic - resize_q_matmul_node, - {})); -} - -void int4pack_mm(ComputeGraph& graph, const std::vector& args) { - check_q_matmul_args(graph, args[0], args[1], args[2], args[3], args[4]); - return add_q_matmul_node( - graph, - args[0], // mat1 - args[1], // mat2 - args[2], // group_size - args[3], // scales_and_zeros - args[4] // out - ); -} - -REGISTER_OPERATORS { - VK_REGISTER_OP(aten._weight_int4pack_mm.default, int4pack_mm); -} - -} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index d634947a510..4a709fce994 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -147,7 +148,9 @@ ValueRef prepack_buffer( const utils::GPUMemoryLayout layout) { ValueRef v = graph.add_tensor_like(vref, utils::kBuffer, layout); - vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR("buffer_to_buffer"); + std::string kernel_name = "buffer_to_buffer"; + add_dtype_suffix(kernel_name, graph.dtype_of(vref)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); vkapi::ParamsBindList ubos; ubos.append({graph.numel_ubo(v)}); diff --git a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp new file mode 100644 index 00000000000..d9444b50ed4 --- /dev/null +++ b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp @@ -0,0 +1,218 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include + +#include + +// +// Reference Implementations +// + +at::Tensor linear_weight_int4_reference_impl( + const at::Tensor& x, + const at::Tensor& weights_4x2, + const int64_t groupsize, + const at::Tensor& scales_and_zeros, + const int64_t inner_k_tiles) { + const std::vector original_x_size(x.sizes().vec()); + const size_t ndim = original_x_size.size(); + const int64_t out_features = weights_4x2.size(0); + const at::Tensor x_flattened = x.reshape({-1, original_x_size[ndim - 1]}); + const at::Tensor packed_weights = + at::_convert_weight_to_int4pack(weights_4x2, inner_k_tiles); + at::Tensor out = at::_weight_int4pack_mm( + x_flattened, packed_weights, groupsize, scales_and_zeros); + std::vector out_shape( + original_x_size.begin(), original_x_size.end()); + out_shape.at(ndim - 1) = out_features; + return out.reshape(out_shape); +} + +at::Tensor dequantize_and_linear( + const at::Tensor& x, + const at::Tensor& weights_4x2, + const int64_t groupsize, + const at::Tensor& scales_and_zeros, + const int64_t inner_k_tiles) { + std::vector weights_shape(weights_4x2.sizes().vec()); + weights_shape[1] *= 2; + + at::Tensor weights_dequantized = + at::empty(weights_shape, at::device(at::kCPU).dtype(at::kFloat)); + + const int64_t N = weights_dequantized.size(0); + const int64_t K = weights_dequantized.size(1); + + const int k_groups = K / groupsize; + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k += 2) { + const int group_idx = k / groupsize; + // const int scale_idx = k_groups * n + group_idx; + const uint8_t packed_val = weights_4x2[n][k / 2].item().to(); + const uint8_t second_val = packed_val & 0x0F; + const uint8_t first_val = (packed_val & 0xF0) >> 4; + + const float scale = scales_and_zeros[group_idx][n][0].item().to(); + const float zero = scales_and_zeros[group_idx][n][1].item().to(); + + weights_dequantized[n][k] = (float(first_val) - 8.0) * scale + zero; + weights_dequantized[n][k + 1] = (float(second_val) - 8.0) * scale + zero; + } + } + + return at::linear(x, weights_dequantized); +} + +// +// Test functions +// + +void test_reference_linear_int4( + const int B, + const int M, + const int K, + const int N, + const int group_size = 32, + const int inner_k_tiles = 8) { + assert(K % group_size == 0); + + at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor weights_4x2 = + at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte)); + + const int k_groups = K / group_size; + at::Tensor scales_and_zeros = + at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat)); + + at::Tensor out = linear_weight_int4_reference_impl( + x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles); + + at::Tensor out_ref = dequantize_and_linear( + x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles); + + ASSERT_TRUE(at::allclose(out, out_ref)); +} + +vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { + using namespace vkcompute; + switch (at_scalartype) { + case c10::kFloat: + return vkapi::kFloat; + case c10::kHalf: + return vkapi::kHalf; + case c10::kInt: + return vkapi::kInt; + case c10::kLong: + return vkapi::kInt; + case c10::kChar: + return vkapi::kChar; + case c10::kByte: + return vkapi::kByte; + default: + VK_THROW("Unsupported at::ScalarType!"); + } +} + +void test_vulkan_linear_int4( + const int B, + const int M, + const int K, + const int N, + const int group_size = 32, + const int inner_k_tiles = 8) { + assert(K % group_size == 0); + + at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor weights_4x2 = + at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte)); + + const int k_groups = K / group_size; + at::Tensor scales_and_zeros = + at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat)); + + at::Tensor out_ref = dequantize_and_linear( + x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles); + + // Build Vulkan graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(utils::kTexture3D); + ComputeGraph graph(config); + +#define MAKE_TENSORREF_FOR(x) \ + ValueRef r_##x = graph.add_tensorref( \ + x.sizes().vec(), \ + from_at_scalartype(x.scalar_type()), \ + x.const_data_ptr()); + + MAKE_TENSORREF_FOR(weights_4x2); + MAKE_TENSORREF_FOR(scales_and_zeros); + +#define MAKE_INPUT_FOR(x) \ + IOValueRef r_##x = graph.add_input_tensor( \ + x.sizes().vec(), from_at_scalartype(x.scalar_type())); + + MAKE_INPUT_FOR(x); + + const ValueRef r_out = graph.add_tensor( + out_ref.sizes().vec(), from_at_scalartype(out_ref.scalar_type())); + + VK_GET_OP_FN("et_vk.linear_weight_int4.default") + (graph, + {r_x.value, + r_weights_4x2, + graph.add_scalar(group_size), + r_scales_and_zeros, + r_out}); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // + // Run model + // + + graph.propagate_resize(); + graph.copy_into_staging(r_x.staging, x.const_data_ptr(), x.numel()); + + graph.execute(); + + at::Tensor vk_out = at::empty_like(out_ref); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + ASSERT_TRUE(at::allclose(vk_out, out_ref, 1e-4, 1e-4)); +} + +TEST(VulkanSDPATest, test_reference_impl) { + test_reference_linear_int4( + /*B = */ 1, + /*M = */ 4, + /*K = */ 128, + /*N = */ 32); +} + +TEST(VulkanSDPATest, test_vulkan_impl) { + test_vulkan_linear_int4( + /*B = */ 1, + /*M = */ 4, + /*K = */ 128, + /*N = */ 32); +} diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index 3acf1debe50..270e1b768a8 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -186,3 +186,41 @@ def define_common_targets(is_fbcode = False): runtime.external_dep_location("libtorch"), ], ) + + runtime.cxx_binary( + name = "linear_weight_int4_test_bin", + srcs = [ + "linear_weight_int4_test.cpp", + ], + compiler_flags = [ + "-Wno-unused-variable", + ], + define_static_target = False, + deps = [ + "//third-party/googletest:gtest_main", + "//executorch/backends/vulkan:vulkan_graph_runtime", + runtime.external_dep_location("libtorch"), + ], + ) + + runtime.cxx_test( + name = "linear_weight_int4_test", + srcs = [ + "linear_weight_int4_test.cpp", + ], + contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"], + fbandroid_additional_loaded_sonames = [ + "torch-code-gen", + "vulkan_graph_runtime", + "vulkan_graph_runtime_shaderlib", + ], + platforms = [ANDROID], + use_instrumentation_test = True, + deps = [ + "//third-party/googletest:gtest_main", + "//executorch/backends/vulkan:vulkan_graph_runtime", + "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", + "//executorch/extension/tensor:tensor", + runtime.external_dep_location("libtorch"), + ], + ) diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index d6126ba78d2..c6cae91b2d7 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -3025,142 +3025,6 @@ TEST(VulkanComputeGraphOpsTest, grid_priors_test) { /*data_out_expected = */ {4, 4, 12, 4, 20, 4, 4, 12, 12, 12, 20, 12}); } -void test_int4pack_mm( - std::vector MKN, - uint32_t group_size, - utils::StorageType storage_type) { - GraphConfig config; - ComputeGraph graph(config); - - const uint32_t M = MKN[0]; - const uint32_t K = MKN[1]; - const uint32_t N = MKN[2]; - - const std::vector mat1_size = {M, K}; - const std::vector mat2_size = {K, N}; - const std::vector mat2_q_size = {N, K / 2}; // Transposed and packed - const std::vector out_size = {M, N}; - - std::vector A_data = create_random_float_buffer(M * K); - IOValueRef A = graph.add_input_tensor(mat1_size, vkapi::kFloat, storage_type); - graph.copy_into_staging(A.staging, A_data.data(), A_data.size()); - - // Quantized but un-packed weights - std::vector B_quant_data = create_random_uint8_buffer(K * N, 0, 16); - - // Pack and transpose weights to correspond to int4 weight format - std::vector B_int4_data = - int4mm_pack_weights(mat2_size, B_quant_data.data()); - - IOValueRef B_int4 = - graph.add_input_tensor(mat2_q_size, vkapi::kQInt8, utils::kBuffer); - graph.copy_into_staging( - B_int4.staging, B_int4_data.data(), B_int4_data.size()); - - const int k_groups = K / group_size; - - // Random scales and zeroes. Keep scales small to avoid overflow and zeroes in - // int4 range - IOValueRef scales_and_zeros; - - if (storage_type == utils::kBuffer) { - scales_and_zeros.value = graph.add_tensor( - {2, N, k_groups}, vkapi::kFloat, storage_type, utils::kWidthPacked); - } else { - scales_and_zeros.value = graph.add_tensor( - {2, N, k_groups}, vkapi::kFloat, storage_type, utils::kChannelsPacked); - } - - scales_and_zeros.staging = graph.set_input_tensor(scales_and_zeros.value); - - std::vector s_data(graph.numel_of(scales_and_zeros.value)); - const int zeros_stride = s_data.size() / 2; - for (size_t i = 0; i < zeros_stride; i++) { - s_data[i] = rand() % 100; - s_data[i + zeros_stride] = rand() % 16; - } - - graph.copy_into_staging( - scales_and_zeros.staging, s_data.data(), s_data.size()); - - IOValueRef out_int4; - - if (storage_type == utils::kBuffer) { - out_int4.value = graph.add_tensor(out_size, vkapi::kFloat, utils::kBuffer); - } else { - out_int4.value = - graph.add_tensor(out_size, vkapi::kFloat, utils::kChannelsPacked); - } - - VK_GET_OP_FN("aten._weight_int4pack_mm.default") - (graph, - {A.value, - B_int4.value, - graph.add_scalar(group_size), - scales_and_zeros.value, - out_int4.value}); - - out_int4.staging = graph.set_output_tensor(out_int4.value); - - // Dequantized matmul for comparison - IOValueRef B_deq = - graph.add_input_tensor(mat2_size, vkapi::kFloat, storage_type); - std::vector B_deq_data = int4mm_dequantize_weights( - mat2_size, B_quant_data.data(), group_size, s_data.data()); - graph.copy_into_staging(B_deq.staging, B_deq_data.data(), B_deq_data.size()); - - IOValueRef out_deq; - out_deq.value = graph.add_tensor(out_size, vkapi::kFloat, storage_type); - - VK_GET_OP_FN("aten.mm.default") - (graph, {A.value, B_deq.value, out_deq.value}); - - out_deq.staging = graph.set_output_tensor(out_deq.value); - - graph.prepare(); - graph.encode_prepack(); - graph.prepack(); - graph.encode_execute(); - graph.propagate_resize(); - graph.execute(); - - // Compare outputs - std::vector out_int4_data(graph.numel_of(out_int4.value)); - graph.copy_from_staging( - out_int4.staging, out_int4_data.data(), out_int4_data.size()); - - std::vector out_deq_data(graph.numel_of(out_deq.value)); - graph.copy_from_staging( - out_deq.staging, out_deq_data.data(), out_deq_data.size()); - - for (int i = 0; i < out_int4_data.size(); i++) { - EXPECT_TRUE(check_close(out_int4_data[i], out_deq_data[i])); - } -} - -TEST(VulkanComputeGraphOpsTest, int4pack_mm_test) { - if (!context()->adapter_ptr()->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - - for (auto storage_type : {utils::kBuffer, utils::kTexture3D}) { - // Vector multiplication, single group per row - test_int4pack_mm({1, 32, 1}, 32, storage_type); - - // Vector multiplication, multiple groups per row - test_int4pack_mm({1, 256, 1}, 64, storage_type); - - // Square matrices, single group per row - test_int4pack_mm({32, 32, 32}, 32, storage_type); - - // Irregular matrices, single group per row - test_int4pack_mm({37, 32, 19}, 32, storage_type); - - // Irregular matrices, multiple groups per row - test_int4pack_mm({37, 256, 19}, 64, storage_type); - } -} - void test_transpose_view_mm( const int B, const int M, From 2bff3c76b49a55acc7a78df3210a2636d1bc792b Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Mon, 14 Oct 2024 16:14:49 -0700 Subject: [PATCH 2/2] Update on "[ET-VK] Fix implementation of int4 quantized linear" ## Context Fix the existing implementation of int4 weight quantized linear to conform with how the `_weight_int4packed_mm` op works in the ATen library. For some additional context, the current op implementation does not actually match the behaviour of `_weight_int4packed_mm`. The ATen op expects that the weights have already been packed into a specific format, with `inner_k_tiles` as a packing parameter. The packing is accomplished via calling the `_convert_weight_to_int4pack` operator. Thus the current implementation in vulkan is equivalent to calling `_convert_weight_to_int4pack` + `_weight_int4packed_mm`. To address this discrepancy, the operator implementation is registered under the `linear_weight_int4` custom op as of this diff. The problems with the existing implementation were as follows: * The expected sizes of the scales and zeros tensor was incorrect. Previously, the sizes were assumed to be `(2, N, num_groups)` but the correct size is `(num_groups, N, 2)` * Previously, when unpacking a uint8_t into 2 unpacked int4 values, it was assumed that the LSB was the first value and the MSB was the second value. However, this ordering should be flipped * The original implementation expected the output tensor to be channels packed, but in practice we want the output tensor to be width packed This diff addresses the above issues, and introduces a dedicated test binary to test against an equivalent reference implementation expressed with ATen functions. Differential Revision: [D64354773](https://our.internmc.facebook.com/intern/diff/D64354773/) [ghstack-poisoned] --- .../vulkan/runtime/api/containers/Tensor.cpp | 2 +- .../vulkan/runtime/api/containers/Tensor.h | 2 +- backends/vulkan/runtime/graph/ComputeGraph.h | 8 ++++-- .../graph/ops/impl/QuantizedLinear.cpp | 28 ++++++++----------- backends/vulkan/runtime/vk_api/Adapter.h | 16 +++++++---- backends/vulkan/runtime/vk_api/Device.cpp | 5 ++++ backends/vulkan/runtime/vk_api/Device.h | 1 + .../test/op_tests/linear_weight_int4_test.cpp | 4 +++ 8 files changed, 40 insertions(+), 26 deletions(-) diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp index dcc982add19..d3d32266d8b 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.cpp +++ b/backends/vulkan/runtime/api/containers/Tensor.cpp @@ -474,7 +474,7 @@ vTensor::vTensor( if (dtype == vkapi::kHalf) { VK_CHECK_COND( - api::context()->adapter_ptr()->has_16bit_storage(), + api::context()->adapter_ptr()->supports_16bit_storage_buffers(), "Half dtype is only available if the physical device supports float16 " "storage buffers!"); } diff --git a/backends/vulkan/runtime/api/containers/Tensor.h b/backends/vulkan/runtime/api/containers/Tensor.h index 3873aeaace7..bd83e600385 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.h +++ b/backends/vulkan/runtime/api/containers/Tensor.h @@ -436,7 +436,7 @@ class vTensor final { * dim is mapped to the height axis of the texture, the channels dim is mapped * to the depth axis of the texture. */ - inline bool is_standard_axis_map() const { + inline bool has_standard_axis_map() const { return axis_map_.at(0) == 0 && axis_map_.at(1) == 1 && axis_map_.at(2) == 2; } diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 1c44b9f2f33..f2d971a56b3 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -342,8 +342,8 @@ class ComputeGraph final { return values_.at(idx).toTensor().axis_map_ubo(); } - inline bool is_standard_axis_map(const ValueRef idx) { - return values_.at(idx).toTensor().is_standard_axis_map(); + inline bool has_standard_axis_map(const ValueRef idx) { + return values_.at(idx).toTensor().has_standard_axis_map(); } inline vkapi::BufferBindInfo logical_limits_ubo(const ValueRef idx) { @@ -694,6 +694,10 @@ class ComputeGraph final { // Miscellaneous Utilities // + inline bool int16_shader_types_enabled() const { + return context_->adapter_ptr()->supports_int16_shader_types(); + } + /* * Check whether the GPU supports 8 bit buffers. */ diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index ba58bc1ef60..66972a4c60d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -127,6 +127,8 @@ void check_q_4w_linear_args( const ValueRef group_size, const ValueRef scales_and_zeros, const ValueRef out) { + VK_CHECK_COND(graph.int16_shader_types_enabled()); + VK_CHECK_COND(graph.val_is_tensor(mat1)); VK_CHECK_COND(graph.val_is_tref(mat2_data)); VK_CHECK_COND(graph.val_is_tref(scales_and_zeros)); @@ -145,8 +147,8 @@ void check_q_4w_linear_args( VK_CHECK_COND(graph.packed_dim_of(mat1) == WHCN::kWidthDim); VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_standard_axis_map(mat1)); - VK_CHECK_COND(graph.is_standard_axis_map(out)); + VK_CHECK_COND(graph.has_standard_axis_map(mat1)); + VK_CHECK_COND(graph.has_standard_axis_map(out)); } void resize_q_4w_linear_node( @@ -201,19 +203,10 @@ void add_q_4w_linear_node( const uint32_t group_size_val = graph.extract_scalar(group_size); vkapi::ParamsBindList ubos({}); - if (storage_type == utils::kBuffer) { - ubos.append(graph.sizes_ubo(out)); - ubos.append(graph.strides_ubo(out)); - ubos.append(graph.sizes_ubo(mat1)); - ubos.append(graph.strides_ubo(mat1)); - ubos.append(graph.strides_ubo(mat2)); - ubos.append(graph.strides_ubo(scales_and_zeros)); - } else { - ubos.append(graph.logical_limits_ubo(out)); - ubos.append(graph.sizes_ubo(mat1)); - ubos.append(graph.strides_ubo(mat2)); - ubos.append(graph.strides_ubo(scales_and_zeros)); - } + ubos.append(graph.logical_limits_ubo(out)); + ubos.append(graph.sizes_ubo(mat1)); + ubos.append(graph.strides_ubo(mat2)); + ubos.append(graph.strides_ubo(scales_and_zeros)); auto out_sizes = graph.sizes_of(out); uint32_t N = utils::val_at(-1, out_sizes); @@ -248,7 +241,10 @@ void linear_weight_int4( args[1], // mat2 args[2], // group_size args[3], // scales_and_zeros - args[4] // out + // There is an unused variable inner_k_tiles which is used to call + // _convert_weight_to_int4pack in the AOT custom op, which is why the 4th + // argument is skipped. + args[5] // out ); } diff --git a/backends/vulkan/runtime/vk_api/Adapter.h b/backends/vulkan/runtime/vk_api/Adapter.h index f03f06e1f48..545f59502ef 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.h +++ b/backends/vulkan/runtime/vk_api/Adapter.h @@ -155,30 +155,34 @@ class Adapter final { // Physical Device Features - inline bool has_16bit_storage() { + inline bool supports_16bit_storage_buffers() { return physical_device_.shader_16bit_storage.storageBuffer16BitAccess == VK_TRUE; } - inline bool has_8bit_storage() { + inline bool supports_8bit_storage_buffers() { return physical_device_.shader_8bit_storage.storageBuffer8BitAccess == VK_TRUE; } - inline bool has_16bit_compute() { + inline bool supports_float16_shader_types() { return physical_device_.shader_float16_int8_types.shaderFloat16 == VK_TRUE; } - inline bool has_8bit_compute() { + inline bool supports_int8_shader_types() { return physical_device_.shader_float16_int8_types.shaderInt8 == VK_TRUE; } + inline bool supports_int16_shader_types() { + return physical_device_.supports_int16_shader_types; + } + inline bool has_full_float16_buffers_support() { - return has_16bit_storage() && has_16bit_compute(); + return supports_16bit_storage_buffers() && supports_float16_shader_types(); } inline bool has_full_int8_buffers_support() { - return has_8bit_storage() && has_8bit_compute(); + return supports_16bit_storage_buffers() && supports_int8_shader_types(); } // Command Buffer Submission diff --git a/backends/vulkan/runtime/vk_api/Device.cpp b/backends/vulkan/runtime/vk_api/Device.cpp index 46e534f09f3..08d4565dbab 100644 --- a/backends/vulkan/runtime/vk_api/Device.cpp +++ b/backends/vulkan/runtime/vk_api/Device.cpp @@ -30,6 +30,7 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle) VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES_KHR}, queue_families{}, num_compute_queues(0), + supports_int16_shader_types(false), has_unified_memory(false), has_timestamps(properties.limits.timestampComputeAndGraphics), timestamp_period(properties.limits.timestampPeriod), @@ -49,6 +50,10 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle) vkGetPhysicalDeviceFeatures2(handle, &features2); + if (features2.features.shaderInt16 == VK_TRUE) { + supports_int16_shader_types = true; + } + // Check if there are any memory types have both the HOST_VISIBLE and the // DEVICE_LOCAL property flags const VkMemoryPropertyFlags unified_memory_flags = diff --git a/backends/vulkan/runtime/vk_api/Device.h b/backends/vulkan/runtime/vk_api/Device.h index 9f4b83540e5..6d6e28857af 100644 --- a/backends/vulkan/runtime/vk_api/Device.h +++ b/backends/vulkan/runtime/vk_api/Device.h @@ -35,6 +35,7 @@ struct PhysicalDevice final { // Metadata uint32_t num_compute_queues; + bool supports_int16_shader_types; bool has_unified_memory; bool has_timestamps; float timestamp_period; diff --git a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp index d9444b50ed4..047c09b8067 100644 --- a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp +++ b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp @@ -176,6 +176,7 @@ void test_vulkan_linear_int4( r_weights_4x2, graph.add_scalar(group_size), r_scales_and_zeros, + kDummyValueRef, r_out}); ValueRef staging_out = graph.set_output_tensor(r_out); @@ -210,6 +211,9 @@ TEST(VulkanSDPATest, test_reference_impl) { } TEST(VulkanSDPATest, test_vulkan_impl) { + if (!vkcompute::api::context()->adapter_ptr()->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } test_vulkan_linear_int4( /*B = */ 1, /*M = */ 4,