diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp index dcc982add19..d3d32266d8b 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.cpp +++ b/backends/vulkan/runtime/api/containers/Tensor.cpp @@ -474,7 +474,7 @@ vTensor::vTensor( if (dtype == vkapi::kHalf) { VK_CHECK_COND( - api::context()->adapter_ptr()->has_16bit_storage(), + api::context()->adapter_ptr()->supports_16bit_storage_buffers(), "Half dtype is only available if the physical device supports float16 " "storage buffers!"); } diff --git a/backends/vulkan/runtime/api/containers/Tensor.h b/backends/vulkan/runtime/api/containers/Tensor.h index 7a113c939f2..bd83e600385 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.h +++ b/backends/vulkan/runtime/api/containers/Tensor.h @@ -430,6 +430,16 @@ class vTensor final { return axis_map_; } + /* + * Return true if the tensor's axis map is {0, 1, 2, concat_dim}. This means + * that the width dim is mapped to the width axis of the texture, the height + * dim is mapped to the height axis of the texture, the channels dim is mapped + * to the depth axis of the texture. + */ + inline bool has_standard_axis_map() const { + return axis_map_.at(0) == 0 && axis_map_.at(1) == 1 && axis_map_.at(2) == 2; + } + inline const std::vector& strides() const { return strides_; } diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 6ee29d45f18..c133094dbfb 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -319,21 +319,26 @@ def define_active_storage_type(storage_type: str): raise AssertionError(f"Invalid storage type: {storage_type}") -def define_required_extensions(dtype: str): +def define_required_extensions(dtypes: Union[str, List[str]]): out_str = "\n" - nbit = None - glsl_type = None - - if dtype == "half": - nbit = "16bit" - glsl_type = "float16" - if dtype == "int8": - nbit = "8bit" - glsl_type = "int8" - - if nbit is not None and glsl_type is not None: - out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n" - out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n" + dtype_list = dtypes if isinstance(dtypes, list) else [dtypes] + + for dtype in dtype_list: + nbit = None + glsl_type = None + if dtype == "half": + nbit = "16bit" + glsl_type = "float16" + elif dtype == "int16" or dtype == "uint16": + nbit = "16bit" + glsl_type = "int16" + elif dtype == "int8" or dtype == "uint8": + nbit = "8bit" + glsl_type = "int8" + + if nbit is not None and glsl_type is not None: + out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n" + out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n" return out_str diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index a4bb714e38e..f2d971a56b3 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -342,6 +342,10 @@ class ComputeGraph final { return values_.at(idx).toTensor().axis_map_ubo(); } + inline bool has_standard_axis_map(const ValueRef idx) { + return values_.at(idx).toTensor().has_standard_axis_map(); + } + inline vkapi::BufferBindInfo logical_limits_ubo(const ValueRef idx) { return values_.at(idx).toTensor().logical_limits_ubo(); } @@ -690,6 +694,10 @@ class ComputeGraph final { // Miscellaneous Utilities // + inline bool int16_shader_types_enabled() const { + return 
context_->adapter_ptr()->supports_int16_shader_types(); + } + /* * Check whether the GPU supports 8 bit buffers. */ diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml index 8ea4cbe561e..9abd9c1deac 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml @@ -14,5 +14,6 @@ buffer_to_buffer: - VALUE: float - VALUE: int - VALUE: int8 + - VALUE: uint8 shader_variants: - NAME: buffer_to_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml index 825da11b24e..e64e1bd260a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml @@ -14,6 +14,7 @@ no_op: - VALUE: float - VALUE: int - VALUE: int8 + - VALUE: uint8 STORAGE: - VALUE: texture3d - VALUE: texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl index de42f9ed996..b702a110a65 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl @@ -19,117 +19,94 @@ ${define_active_storage_type(STORAGE)} -${define_required_extensions(DTYPE)} -${define_required_extensions("int8")} +${define_required_extensions([DTYPE, "uint8", "uint16"])} +#extension GL_EXT_control_flow_attributes : require layout(std430) buffer; -${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)} -${layout_declare_tensor(1, "r", "t_mat1", DTYPE, STORAGE)} -${layout_declare_tensor(2, "r", "t_mat2", "int8", "buffer")} -${layout_declare_tensor(3, "r", "t_scales_and_zeros", DTYPE, STORAGE)} - -$if STORAGE == "texture3d": - ${layout_declare_ubo(4, "ivec4", "out_sizes")} - ${layout_declare_ubo(5, "ivec4", "mat1_sizes")} - ${layout_declare_ubo(6, "ivec4", "mat2_strides")} - ${layout_declare_ubo(7, "ivec4", "scales_strides")} -$else: - ${layout_declare_ubo(4, "ivec4", "out_sizes")} - ${layout_declare_ubo(5, "ivec4", "out_strides")} - ${layout_declare_ubo(6, "ivec4", "mat1_sizes")} - ${layout_declare_ubo(7, "ivec4", "mat1_strides")} - ${layout_declare_ubo(8, "ivec4", "mat2_strides")} - ${layout_declare_ubo(9, "ivec4", "scales_strides")} +${layout_declare_tensor(B, "w", "ret", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "x", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "weights", "uint8", "buffer")} +${layout_declare_tensor(B, "r", "qparams", DTYPE, STORAGE)} +${layout_declare_ubo(B, "ivec3", "ret_limits")} +${layout_declare_ubo(B, "ivec4", "x_sizes")} +${layout_declare_ubo(B, "ivec4", "weights_strides")} +${layout_declare_ubo(B, "ivec4", "qparams_strides")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; layout(constant_id = 3) const int group_size = 1; +/* + * This shader computes a linear operator between a floating point input matrix + * x and a weights matrix that is quantized to 4 bits. + * + * The (W, H, C) shape of each tensor is: + * - x: (K, M) + * - weights: (K / 2, N) + * - The weights tensor has a data type of `uint8`. Each element in the tensor + * contains 2 4-bit values packed into a uint8. + * - qparams: (2, N, number_of_groups) + * - This tensor contains the scales and zeros quantization parameters for the + * weights tensor. 
The weight tensor is quantized group-wise, which means + * that every `group_size` elements along the K dimension of the weights + * tensor has independent quantization parameters. Along the width dim, the + * first value contains the scale for the group and the second value + * contains the zero point for the group. + * + * Note that this shader assumes that all tensors are width packed. + */ void main() { - - const ivec4 out_pos = ivec4( - gl_GlobalInvocationID.x, // n = 0..N-1 - gl_GlobalInvocationID.y, // m = 0..M-1 - gl_GlobalInvocationID.z % out_sizes.z, - gl_GlobalInvocationID.z / out_sizes.z); - - if (any(greaterThanEqual(out_pos, out_sizes))) { - return; + // output positions being calculated are (n, m), (n + 1, m), ... + // This means multiplying the m-th row of x with the n-th, (n+1)-th, ... rows + // of the weights tensor. + const u16vec3 ret_pos = u16vec3(gl_GlobalInvocationID); + if (any(greaterThanEqual(ret_pos, ret_limits))) { + return; + } + + // Since ret is width packed, need to multiply by 4 + const uint16_t n = uint16_t(ret_pos.x * 4); + + // K is guaranteed to be a multiple of group size + const uint16_t num_blocks = uint16_t(x_sizes.x / group_size); + + uint16_t k_texel_i = uint16_t(0); + vec4 sums = vec4(0.0); + for (uint16_t block_idx = uint16_t(0); block_idx < num_blocks; block_idx++) { + vec4 scales; + vec4 zeros; + + [[unroll]] for (int comp = 0; comp < 4; comp++) { + const vec4 scale_and_zero = load_texel( + qparams, u16vec3(0, n + comp, block_idx)); + scales[comp] = scale_and_zero.x; + zeros[comp] = scale_and_zero.y; } - const uint K = mat1_sizes.x; - const uint n = out_pos.x; - const uint m = out_pos.y; - const uint mask = uint(0x0f); - - float rc = 0.0; - int k = 0; - const uint k_block = (K + group_size - 1) / group_size; - - #ifdef USING_BUFFER - ivec4 mat1_pos = ivec4(0, m, out_pos.z, out_pos.w); - ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w); - ivec4 scale_pos = ivec4(0, n, 0, out_pos.w); - ivec4 zero_pos = ivec4(0, n, 1, out_pos.w); - - for (int kb = 0; kb < k_block; kb++) { - scale_pos.x = kb; - const int scale_bufi = tidx_to_bufi(scale_pos, scales_strides); - const float scale = float(t_scales_and_zeros[scale_bufi]); - - zero_pos.x = kb; - const int zero_bufi = tidx_to_bufi(zero_pos, scales_strides); - const float zero = float(t_scales_and_zeros[zero_bufi]) - scale * 8.0; - - for(uint idx = 0; idx < group_size && k < K; idx++, k++) { - mat1_pos.x = k; - const int mat1_bufi = tidx_to_bufi(mat1_pos, mat1_strides); - const float mat1_val = float(t_mat1[mat1_bufi]); - - mat2_pos.x = k / 2; - const int mat2_bufi = tidx_to_bufi(mat2_pos, mat2_strides); - // Bitwise op treats sign bit from int8 as a value bit instead, - // since there is no uint8_t datatype - uint mat2_val = (t_mat2[mat2_bufi] & 0xFF); - mat2_val = (k & 1) == 0 ? 
mat2_val & mask : (mat2_val >> 4); - - rc += mat1_val * (scale * float(mat2_val) + zero); - } - } - - const int out_bufi = tidx_to_bufi(out_pos, out_strides); - t_out[out_bufi] = FLOAT_T(rc); - - #else // Using texture - ivec3 mat1_pos = ivec3(0, m, out_pos.z); - ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w); - ivec3 scale_zero_pos = ivec3(0, n, 0); - uint K_texel = K / FOUR; - - for (int kb = 0; kb < k_block; kb++) { - scale_zero_pos.x = kb; - const vec4 scale_zero = load_texel(t_scales_and_zeros, scale_zero_pos); - const float scale = scale_zero.x; - const float zero = scale_zero.y - scale * 8.0; - - for(uint idx = 0; idx < group_size && k < K_texel; idx += FOUR, k++) { - mat1_pos.x = k; - const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos); - - mat2_pos.x = k * 2; // k * FOUR / 2 - const int mat2_id = tidx_to_bufi(mat2_pos, mat2_strides); - - for (int texel_pos = 0; texel_pos < FOUR; texel_pos++) { - // Bitwise op treats sign bit from int8 as a value bit instead, - // since there is no uint8_t datatype - uint mat2_val = (t_mat2[mat2_id + texel_pos / 2] & 0xFF); - mat2_val = (texel_pos & 1) == 0 ? mat2_val & mask : (mat2_val >> 4); - rc += mat1_tex[texel_pos] * (scale * float(mat2_val) + zero); - } - } + for (uint16_t i = uint16_t(0); i < group_size; i += uint16_t(4), k_texel_i++) { + const VEC4_T x_texel = load_texel( + x, u16vec3(k_texel_i, ret_pos.y, ret_pos.z)); + + [[unroll]] for (int comp = 0; comp < 4; comp++) { + const int weights_bufi = (n + comp) * weights_strides.y + (k_texel_i * 2); + // Need to read 4 unpacked values, which corresponds to 2 packed values + const uint8_t weights_val_1 = weights[weights_bufi]; + const uint8_t weights_val_2 = weights[weights_bufi + 1]; + + const u8vec4 weights_texel = u8vec4( + (weights_val_1 & 0xF0) >> 4, + weights_val_1 & 0x0F, + (weights_val_2 & 0xF0) >> 4, + weights_val_2 & 0x0F); + + // Note that the unpacked 4-bit values are unsigned, therefore they must + // first be "centered" around 0 by subtracting 8 before applying the + // scale and zero point. 
+ sums[comp] += dot( + x_texel, (vec4(weights_texel) - 8.0) * scales[comp] + zeros[comp]); } - write_texel(t_out, out_pos.xyz, vec4(rc, 0, 0, 0)); - - #endif + } + } + write_texel(ret, ret_pos, sums); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml index fd65068080a..40d95d4a05f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml @@ -7,13 +7,10 @@ q_4w_linear: parameter_names_with_default_values: DTYPE: float - STORAGE: buffer + STORAGE: texture3d generate_variant_forall: DTYPE: - VALUE: float - VALUE: half - STORAGE: - - VALUE: buffer - - VALUE: texture3d shader_variants: - - NAME: q_4w_linear + - NAME: q_4w_linear_texture3d diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index 838605f05f3..4dd55be4692 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -16,7 +16,7 @@ namespace vkcompute { -void check_qlinear_args( +void check_q_8w_linear_args( const ComputeGraph& graph, const ValueRef mat1, const ValueRef qmat2_data, @@ -38,7 +38,7 @@ void check_qlinear_args( utils::val_at(-1, scales_sizes) == utils::val_at(-2, qmat2_sizes)); } -void resize_qlinear_node( +void resize_q_8w_linear_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { @@ -123,7 +123,7 @@ void add_q_8w_linear_node( // Specialization Constants {}, // Resizing Logic - resize_qlinear_node)); + resize_q_8w_linear_node)); if (!graph.is_buffer_storage(out) && graph.packed_dim_of(out) != WHCN::kWidthDim) { viewFn(graph, {out_W_packed, graph.add_none(), out}); @@ -133,12 +133,138 @@ void add_q_8w_linear_node( void weight_int8pack_mm( ComputeGraph& graph, const std::vector& args) { - check_qlinear_args(graph, args[0], args[1], args[2], args[3]); + check_q_8w_linear_args(graph, args[0], args[1], args[2], args[3]); return add_q_8w_linear_node(graph, args[0], args[1], args[2], args[3]); } +void check_q_4w_linear_args( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat2_data, + const ValueRef group_size, + const ValueRef scales_and_zeros, + const ValueRef out) { + VK_CHECK_COND(graph.int16_shader_types_enabled()); + VK_CHECK_COND(graph.int8_buffers_enabled()); + + VK_CHECK_COND(graph.val_is_tensor(mat1)); + VK_CHECK_COND(graph.val_is_tref(mat2_data)); + VK_CHECK_COND(graph.val_is_tref(scales_and_zeros)); + + VK_CHECK_COND(graph.dim_of(mat1) <= 3); + VK_CHECK_COND(graph.dim_of(mat2_data) == 2); + VK_CHECK_COND(graph.dim_of(scales_and_zeros) == 3); + + VK_CHECK_COND(graph.size_at(-3, mat1) == 1); + const int K = graph.size_at(-1, mat1); + VK_CHECK_COND(graph.size_at(-1, mat2_data) * 2 == K); + + const int group_size_val = graph.extract_scalar(group_size); + VK_CHECK_COND(K % group_size_val == 0); + + VK_CHECK_COND(graph.packed_dim_of(mat1) == WHCN::kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim); + + VK_CHECK_COND(graph.has_standard_axis_map(mat1)); + VK_CHECK_COND(graph.has_standard_axis_map(out)); +} + +void resize_q_4w_linear_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); + vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); + + const int out_cols = 
utils::val_at(-2, mat1->sizes()); + const int out_rows = utils::val_at(-2, mat2->sizes()); + + std::vector new_out_sizes(3); + if (mat1->sizes().size() == 2) { + new_out_sizes.resize(2); + new_out_sizes.at(0) = out_cols; + new_out_sizes.at(1) = out_rows; + } else { + new_out_sizes.at(0) = mat1->sizes().at(0); + new_out_sizes.at(1) = out_cols; + new_out_sizes.at(2) = out_rows; + } + + out->virtual_resize(new_out_sizes); +} + +void add_q_4w_linear_node( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat2_data, + const ValueRef group_size, + const ValueRef scales_and_zeros_data, + const ValueRef out) { + check_q_4w_linear_args( + graph, mat1, mat2_data, group_size, scales_and_zeros_data, out); + + utils::StorageType storage_type = graph.storage_type_of(out); + + ValueRef mat2 = + prepack_buffer_if_tensor_ref(graph, mat2_data, utils::kWidthPacked); + + ValueRef scales_and_zeros = + prepack_if_tensor_ref(graph, scales_and_zeros_data, utils::kWidthPacked); + + std::string kernel_name = "q_4w_linear"; + add_storage_type_suffix(kernel_name, storage_type); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + const uint32_t group_size_val = graph.extract_scalar(group_size); + + vkapi::ParamsBindList ubos({}); + ubos.append(graph.logical_limits_ubo(out)); + ubos.append(graph.sizes_ubo(mat1)); + ubos.append(graph.strides_ubo(mat2)); + ubos.append(graph.strides_ubo(scales_and_zeros)); + + utils::uvec3 global_wg_size = graph.logical_limits_of(out); + utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); + + graph.execute_nodes().emplace_back(new DispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + local_wg_size, + // Inputs and Outputs + {{out, vkapi::MemoryAccessType::WRITE}, + {{mat1, mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}}, + // Shader params buffers + ubos, + // Specialization Constants + {SV(group_size_val)}, + // Resizing Logic + resize_q_4w_linear_node, + {})); +} + +void linear_weight_int4( + ComputeGraph& graph, + const std::vector& args) { + return add_q_4w_linear_node( + graph, + args[0], // mat1 + args[1], // mat2 + args[2], // group_size + args[3], // scales_and_zeros + // There is an unused variable inner_k_tiles which is used to call + // _convert_weight_to_int4pack in the AOT custom op, which is why the 4th + // argument is skipped. + args[5] // out + ); +} + REGISTER_OPERATORS { VK_REGISTER_OP(aten._weight_int8pack_mm.default, weight_int8pack_mm); + VK_REGISTER_OP(et_vk.linear_weight_int4.default, linear_weight_int4); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp deleted file mode 100644 index 17291d292a3..00000000000 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp +++ /dev/null @@ -1,184 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include -#include -#include -#include - -namespace vkcompute { - -void check_q_matmul_args( - ComputeGraph& graph, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef group_size_data, - const ValueRef scales_and_zeros, - const ValueRef out) { - const std::vector mat1_sizes = graph.sizes_of(mat1); - const std::vector mat2_sizes = graph.sizes_of(mat2_data); - const std::vector scales_and_zeros_sizes = - graph.sizes_of(scales_and_zeros); - - const uint32_t group_size = graph.extract_scalar(group_size_data); - - VK_CHECK_COND(mat1_sizes.size() == 2); - VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size()); - - using namespace WHCN; - VK_CHECK_COND(graph.packed_dim_of(mat1) == kWidthDim); - VK_CHECK_COND(graph.packed_dim_of(mat2_data) == kWidthDim); - // VK_CHECK_COND(graph.packed_dim_of(scales_and_zeros) == kWidthDim); - - if (graph.storage_type_of(scales_and_zeros) == utils::kBuffer) { - VK_CHECK_COND(graph.packed_dim_of(scales_and_zeros) == kWidthDim); - } else { - VK_CHECK_COND(graph.packed_dim_of(scales_and_zeros) == kChannelsDim); - } - - if (graph.storage_type_of(out) == utils::kBuffer) { - VK_CHECK_COND(graph.packed_dim_of(out) == kWidthDim); - } else { - VK_CHECK_COND(graph.packed_dim_of(out) == kChannelsDim); - } - - const int mat1_K = utils::val_at(-1, mat1_sizes); - const int mat2_K = utils::val_at(-1, mat2_sizes) * 2; - const int N = utils::val_at(-2, mat2_sizes); - - VK_CHECK_COND(mat1_K == mat2_K); - - VK_CHECK_COND(mat2_K % group_size == 0); - - const uint32_t k_groups = mat2_K / group_size; - - VK_CHECK_COND(scales_and_zeros_sizes.size() == 3); - VK_CHECK_COND(utils::val_at(-1, scales_and_zeros_sizes) == k_groups); - VK_CHECK_COND(utils::val_at(-2, scales_and_zeros_sizes) == N); - VK_CHECK_COND(utils::val_at(-3, scales_and_zeros_sizes) == 2); - - // Match https://fburl.com/code/6ostkknm - std::vector valid_group_sizes = {32, 64, 128, 256}; - - bool is_valid_group_size = false; - for (auto valid_group_size : valid_group_sizes) { - if (group_size == valid_group_size) { - is_valid_group_size = true; - break; - } - } - - VK_CHECK_COND(is_valid_group_size); -} - -void resize_q_matmul_node( - ComputeGraph* graph, - const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; - - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); - vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); - - const int out_cols = utils::val_at(-2, mat1->sizes()); - const int out_rows = utils::val_at(-2, mat2->sizes()); - - std::vector new_out_sizes(3); - if (mat1->sizes().size() == 2) { - new_out_sizes.resize(2); - new_out_sizes.at(0) = out_cols; - new_out_sizes.at(1) = out_rows; - } else { - new_out_sizes.at(0) = mat1->sizes().at(0); - new_out_sizes.at(1) = out_cols; - new_out_sizes.at(2) = out_rows; - } - - out->virtual_resize(new_out_sizes); -} - -void add_q_matmul_node( - ComputeGraph& graph, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef group_size, - const ValueRef scales_and_zeros_data, - const ValueRef out) { - auto storage_type = graph.storage_type_of(out); - - ValueRef mat2 = - prepack_buffer_if_tensor_ref(graph, mat2_data, utils::kWidthPacked); - - ValueRef scales_and_zeros = - prepack_if_tensor_ref(graph, scales_and_zeros_data, utils::kWidthPacked); - - std::string kernel_name = "q_4w_linear"; - - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - add_storage_type_suffix(kernel_name, storage_type); - - const uint32_t group_size_val = graph.extract_scalar(group_size); - 
- vkapi::ParamsBindList ubos({}); - if (storage_type == utils::kBuffer) { - ubos.append(graph.sizes_ubo(out)); - ubos.append(graph.strides_ubo(out)); - ubos.append(graph.sizes_ubo(mat1)); - ubos.append(graph.strides_ubo(mat1)); - ubos.append(graph.strides_ubo(mat2)); - ubos.append(graph.strides_ubo(scales_and_zeros)); - } else { - ubos.append(graph.sizes_ubo(out)); - ubos.append(graph.sizes_ubo(mat1)); - ubos.append(graph.strides_ubo(mat2)); - ubos.append(graph.strides_ubo(scales_and_zeros)); - } - - auto out_sizes = graph.sizes_of(out); - uint32_t N = utils::val_at(-1, out_sizes); - uint32_t M = utils::val_at(-2, out_sizes); - - utils::uvec3 global_wg_size = {N, M, 1}; - - utils::uvec3 local_wg_size = adaptive_work_group_size(global_wg_size); - - graph.execute_nodes().emplace_back(new DispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - local_wg_size, - // Inputs and Outputs - {{out, vkapi::MemoryAccessType::WRITE}, - {{mat1, mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}}, - // Shader params buffers - ubos, - // Specialization Constants - {SV(group_size_val)}, - // Resizing Logic - resize_q_matmul_node, - {})); -} - -void int4pack_mm(ComputeGraph& graph, const std::vector& args) { - check_q_matmul_args(graph, args[0], args[1], args[2], args[3], args[4]); - return add_q_matmul_node( - graph, - args[0], // mat1 - args[1], // mat2 - args[2], // group_size - args[3], // scales_and_zeros - args[4] // out - ); -} - -REGISTER_OPERATORS { - VK_REGISTER_OP(aten._weight_int4pack_mm.default, int4pack_mm); -} - -} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index d634947a510..4a709fce994 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -147,7 +148,9 @@ ValueRef prepack_buffer( const utils::GPUMemoryLayout layout) { ValueRef v = graph.add_tensor_like(vref, utils::kBuffer, layout); - vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR("buffer_to_buffer"); + std::string kernel_name = "buffer_to_buffer"; + add_dtype_suffix(kernel_name, graph.dtype_of(vref)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); vkapi::ParamsBindList ubos; ubos.append({graph.numel_ubo(v)}); diff --git a/backends/vulkan/runtime/vk_api/Adapter.h b/backends/vulkan/runtime/vk_api/Adapter.h index f03f06e1f48..545f59502ef 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.h +++ b/backends/vulkan/runtime/vk_api/Adapter.h @@ -155,30 +155,34 @@ class Adapter final { // Physical Device Features - inline bool has_16bit_storage() { + inline bool supports_16bit_storage_buffers() { return physical_device_.shader_16bit_storage.storageBuffer16BitAccess == VK_TRUE; } - inline bool has_8bit_storage() { + inline bool supports_8bit_storage_buffers() { return physical_device_.shader_8bit_storage.storageBuffer8BitAccess == VK_TRUE; } - inline bool has_16bit_compute() { + inline bool supports_float16_shader_types() { return physical_device_.shader_float16_int8_types.shaderFloat16 == VK_TRUE; } - inline bool has_8bit_compute() { + inline bool supports_int8_shader_types() { return physical_device_.shader_float16_int8_types.shaderInt8 == VK_TRUE; } + inline bool supports_int16_shader_types() { + return physical_device_.supports_int16_shader_types; + } + inline bool has_full_float16_buffers_support() { - return has_16bit_storage() && has_16bit_compute(); + return 
supports_16bit_storage_buffers() && supports_float16_shader_types();
   }
 
   inline bool has_full_int8_buffers_support() {
-    return has_8bit_storage() && has_8bit_compute();
+    return supports_8bit_storage_buffers() && supports_int8_shader_types();
   }
 
   // Command Buffer Submission
diff --git a/backends/vulkan/runtime/vk_api/Device.cpp b/backends/vulkan/runtime/vk_api/Device.cpp
index 46e534f09f3..08d4565dbab 100644
--- a/backends/vulkan/runtime/vk_api/Device.cpp
+++ b/backends/vulkan/runtime/vk_api/Device.cpp
@@ -30,6 +30,7 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle)
           VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES_KHR},
       queue_families{},
       num_compute_queues(0),
+      supports_int16_shader_types(false),
       has_unified_memory(false),
       has_timestamps(properties.limits.timestampComputeAndGraphics),
       timestamp_period(properties.limits.timestampPeriod),
@@ -49,6 +50,10 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle)
 
   vkGetPhysicalDeviceFeatures2(handle, &features2);
 
+  if (features2.features.shaderInt16 == VK_TRUE) {
+    supports_int16_shader_types = true;
+  }
+
   // Check if there are any memory types have both the HOST_VISIBLE and the
   // DEVICE_LOCAL property flags
   const VkMemoryPropertyFlags unified_memory_flags =
diff --git a/backends/vulkan/runtime/vk_api/Device.h b/backends/vulkan/runtime/vk_api/Device.h
index 9f4b83540e5..6d6e28857af 100644
--- a/backends/vulkan/runtime/vk_api/Device.h
+++ b/backends/vulkan/runtime/vk_api/Device.h
@@ -35,6 +35,7 @@ struct PhysicalDevice final {
 
   // Metadata
   uint32_t num_compute_queues;
+  bool supports_int16_shader_types;
   bool has_unified_memory;
   bool has_timestamps;
   float timestamp_period;
diff --git a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp
new file mode 100644
index 00000000000..63ebb96cfaa
--- /dev/null
+++ b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp
@@ -0,0 +1,224 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +#include + +#include + +#include +#include +#include + +#include + +// +// Reference Implementations +// + +at::Tensor linear_weight_int4_reference_impl( + const at::Tensor& x, + const at::Tensor& weights_4x2, + const int64_t groupsize, + const at::Tensor& scales_and_zeros, + const int64_t inner_k_tiles) { + const std::vector original_x_size(x.sizes().vec()); + const size_t ndim = original_x_size.size(); + const int64_t out_features = weights_4x2.size(0); + const at::Tensor x_flattened = x.reshape({-1, original_x_size[ndim - 1]}); + const at::Tensor packed_weights = + at::_convert_weight_to_int4pack(weights_4x2, inner_k_tiles); + at::Tensor out = at::_weight_int4pack_mm( + x_flattened, packed_weights, groupsize, scales_and_zeros); + std::vector out_shape( + original_x_size.begin(), original_x_size.end()); + out_shape.at(ndim - 1) = out_features; + return out.reshape(out_shape); +} + +at::Tensor dequantize_and_linear( + const at::Tensor& x, + const at::Tensor& weights_4x2, + const int64_t groupsize, + const at::Tensor& scales_and_zeros, + const int64_t inner_k_tiles) { + std::vector weights_shape(weights_4x2.sizes().vec()); + weights_shape[1] *= 2; + + at::Tensor weights_dequantized = + at::empty(weights_shape, at::device(at::kCPU).dtype(at::kFloat)); + + const int64_t N = weights_dequantized.size(0); + const int64_t K = weights_dequantized.size(1); + + const int k_groups = K / groupsize; + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k += 2) { + const int group_idx = k / groupsize; + // const int scale_idx = k_groups * n + group_idx; + const uint8_t packed_val = weights_4x2[n][k / 2].item().to(); + const uint8_t second_val = packed_val & 0x0F; + const uint8_t first_val = (packed_val & 0xF0) >> 4; + + const float scale = scales_and_zeros[group_idx][n][0].item().to(); + const float zero = scales_and_zeros[group_idx][n][1].item().to(); + + weights_dequantized[n][k] = (float(first_val) - 8.0) * scale + zero; + weights_dequantized[n][k + 1] = (float(second_val) - 8.0) * scale + zero; + } + } + + return at::linear(x, weights_dequantized); +} + +// +// Test functions +// + +void test_reference_linear_int4( + const int B, + const int M, + const int K, + const int N, + const int group_size = 32, + const int inner_k_tiles = 8) { + assert(K % group_size == 0); + + at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor weights_4x2 = + at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte)); + + const int k_groups = K / group_size; + at::Tensor scales_and_zeros = + at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat)); + + at::Tensor out = linear_weight_int4_reference_impl( + x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles); + + at::Tensor out_ref = dequantize_and_linear( + x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles); + + ASSERT_TRUE(at::allclose(out, out_ref)); +} + +vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { + using namespace vkcompute; + switch (at_scalartype) { + case c10::kFloat: + return vkapi::kFloat; + case c10::kHalf: + return vkapi::kHalf; + case c10::kInt: + return vkapi::kInt; + case c10::kLong: + return vkapi::kInt; + case c10::kChar: + return vkapi::kChar; + case c10::kByte: + return vkapi::kByte; + default: + VK_THROW("Unsupported at::ScalarType!"); + } +} + +void test_vulkan_linear_int4( + const int B, + const int M, + const int K, + const int N, + const int group_size = 32, + const int inner_k_tiles = 8) { + assert(K % group_size == 0); + + 
at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor weights_4x2 = + at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte)); + + const int k_groups = K / group_size; + at::Tensor scales_and_zeros = + at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat)); + + at::Tensor out_ref = dequantize_and_linear( + x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles); + + // Build Vulkan graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(utils::kTexture3D); + ComputeGraph graph(config); + +#define MAKE_TENSORREF_FOR(x) \ + ValueRef r_##x = graph.add_tensorref( \ + x.sizes().vec(), \ + from_at_scalartype(x.scalar_type()), \ + x.const_data_ptr()); + + MAKE_TENSORREF_FOR(weights_4x2); + MAKE_TENSORREF_FOR(scales_and_zeros); + +#define MAKE_INPUT_FOR(x) \ + IOValueRef r_##x = graph.add_input_tensor( \ + x.sizes().vec(), from_at_scalartype(x.scalar_type())); + + MAKE_INPUT_FOR(x); + + const ValueRef r_out = graph.add_tensor( + out_ref.sizes().vec(), from_at_scalartype(out_ref.scalar_type())); + + VK_GET_OP_FN("et_vk.linear_weight_int4.default") + (graph, + {r_x.value, + r_weights_4x2, + graph.add_scalar(group_size), + r_scales_and_zeros, + kDummyValueRef, + r_out}); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // + // Run model + // + + graph.propagate_resize(); + graph.copy_into_staging(r_x.staging, x.const_data_ptr(), x.numel()); + + graph.execute(); + + at::Tensor vk_out = at::empty_like(out_ref); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + ASSERT_TRUE(at::allclose(vk_out, out_ref, 1e-4, 1e-4)); +} + +TEST(VulkanInt4LinearTest, test_reference_impl) { + test_reference_linear_int4( + /*B = */ 1, + /*M = */ 4, + /*K = */ 128, + /*N = */ 32); +} + +TEST(VulkanInt4LinearTest, test_vulkan_impl) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_linear_int4( + /*B = */ 1, + /*M = */ 4, + /*K = */ 128, + /*N = */ 32); +} diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index 3acf1debe50..270e1b768a8 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -186,3 +186,41 @@ def define_common_targets(is_fbcode = False): runtime.external_dep_location("libtorch"), ], ) + + runtime.cxx_binary( + name = "linear_weight_int4_test_bin", + srcs = [ + "linear_weight_int4_test.cpp", + ], + compiler_flags = [ + "-Wno-unused-variable", + ], + define_static_target = False, + deps = [ + "//third-party/googletest:gtest_main", + "//executorch/backends/vulkan:vulkan_graph_runtime", + runtime.external_dep_location("libtorch"), + ], + ) + + runtime.cxx_test( + name = "linear_weight_int4_test", + srcs = [ + "linear_weight_int4_test.cpp", + ], + contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"], + fbandroid_additional_loaded_sonames = [ + "torch-code-gen", + "vulkan_graph_runtime", + "vulkan_graph_runtime_shaderlib", + ], + platforms = [ANDROID], + use_instrumentation_test = True, + deps = [ + "//third-party/googletest:gtest_main", + "//executorch/backends/vulkan:vulkan_graph_runtime", + "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", + "//executorch/extension/tensor:tensor", + runtime.external_dep_location("libtorch"), + ], + ) diff --git 
a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index ca8558fe0ea..694eeebecee 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -2379,7 +2379,8 @@ void run_from_gpu_test( utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, vkapi::ScalarType dtype = vkapi::kFloat, utils::StorageType storage_type = utils::StorageType::TEXTURE_3D) { - if (dtype == vkapi::kHalf && !context()->adapter_ptr()->has_16bit_storage()) { + if (dtype == vkapi::kHalf && + !context()->adapter_ptr()->supports_16bit_storage_buffers()) { return; } vTensor vten = vTensor(context(), sizes, dtype, storage_type, memory_layout); @@ -2433,7 +2434,8 @@ void round_trip_test( utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, vkapi::ScalarType dtype = vkapi::kFloat, utils::StorageType storage_type = utils::StorageType::TEXTURE_3D) { - if (dtype == vkapi::kHalf && !context()->adapter_ptr()->has_16bit_storage()) { + if (dtype == vkapi::kHalf && + !context()->adapter_ptr()->supports_16bit_storage_buffers()) { return; } @@ -2484,7 +2486,8 @@ void compute_graph_round_trip_test( utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, vkapi::ScalarType dtype = vkapi::kFloat, utils::StorageType storage_type = utils::StorageType::TEXTURE_3D) { - if (dtype == vkapi::kHalf && !context()->adapter_ptr()->has_16bit_storage()) { + if (dtype == vkapi::kHalf && + !context()->adapter_ptr()->supports_16bit_storage_buffers()) { return; } @@ -3026,142 +3029,6 @@ TEST(VulkanComputeGraphOpsTest, grid_priors_test) { /*data_out_expected = */ {4, 4, 12, 4, 20, 4, 4, 12, 12, 12, 20, 12}); } -void test_int4pack_mm( - std::vector MKN, - uint32_t group_size, - utils::StorageType storage_type) { - GraphConfig config; - ComputeGraph graph(config); - - const uint32_t M = MKN[0]; - const uint32_t K = MKN[1]; - const uint32_t N = MKN[2]; - - const std::vector mat1_size = {M, K}; - const std::vector mat2_size = {K, N}; - const std::vector mat2_q_size = {N, K / 2}; // Transposed and packed - const std::vector out_size = {M, N}; - - std::vector A_data = create_random_float_buffer(M * K); - IOValueRef A = graph.add_input_tensor(mat1_size, vkapi::kFloat, storage_type); - graph.copy_into_staging(A.staging, A_data.data(), A_data.size()); - - // Quantized but un-packed weights - std::vector B_quant_data = create_random_uint8_buffer(K * N, 0, 16); - - // Pack and transpose weights to correspond to int4 weight format - std::vector B_int4_data = - int4mm_pack_weights(mat2_size, B_quant_data.data()); - - IOValueRef B_int4 = - graph.add_input_tensor(mat2_q_size, vkapi::kQInt8, utils::kBuffer); - graph.copy_into_staging( - B_int4.staging, B_int4_data.data(), B_int4_data.size()); - - const int k_groups = K / group_size; - - // Random scales and zeroes. 
Keep scales small to avoid overflow and zeroes in - // int4 range - IOValueRef scales_and_zeros; - - if (storage_type == utils::kBuffer) { - scales_and_zeros.value = graph.add_tensor( - {2, N, k_groups}, vkapi::kFloat, storage_type, utils::kWidthPacked); - } else { - scales_and_zeros.value = graph.add_tensor( - {2, N, k_groups}, vkapi::kFloat, storage_type, utils::kChannelsPacked); - } - - scales_and_zeros.staging = graph.set_input_tensor(scales_and_zeros.value); - - std::vector s_data(graph.numel_of(scales_and_zeros.value)); - const int zeros_stride = s_data.size() / 2; - for (size_t i = 0; i < zeros_stride; i++) { - s_data[i] = rand() % 100; - s_data[i + zeros_stride] = rand() % 16; - } - - graph.copy_into_staging( - scales_and_zeros.staging, s_data.data(), s_data.size()); - - IOValueRef out_int4; - - if (storage_type == utils::kBuffer) { - out_int4.value = graph.add_tensor(out_size, vkapi::kFloat, utils::kBuffer); - } else { - out_int4.value = - graph.add_tensor(out_size, vkapi::kFloat, utils::kChannelsPacked); - } - - VK_GET_OP_FN("aten._weight_int4pack_mm.default") - (graph, - {A.value, - B_int4.value, - graph.add_scalar(group_size), - scales_and_zeros.value, - out_int4.value}); - - out_int4.staging = graph.set_output_tensor(out_int4.value); - - // Dequantized matmul for comparison - IOValueRef B_deq = - graph.add_input_tensor(mat2_size, vkapi::kFloat, storage_type); - std::vector B_deq_data = int4mm_dequantize_weights( - mat2_size, B_quant_data.data(), group_size, s_data.data()); - graph.copy_into_staging(B_deq.staging, B_deq_data.data(), B_deq_data.size()); - - IOValueRef out_deq; - out_deq.value = graph.add_tensor(out_size, vkapi::kFloat, storage_type); - - VK_GET_OP_FN("aten.mm.default") - (graph, {A.value, B_deq.value, out_deq.value}); - - out_deq.staging = graph.set_output_tensor(out_deq.value); - - graph.prepare(); - graph.encode_prepack(); - graph.prepack(); - graph.encode_execute(); - graph.propagate_resize(); - graph.execute(); - - // Compare outputs - std::vector out_int4_data(graph.numel_of(out_int4.value)); - graph.copy_from_staging( - out_int4.staging, out_int4_data.data(), out_int4_data.size()); - - std::vector out_deq_data(graph.numel_of(out_deq.value)); - graph.copy_from_staging( - out_deq.staging, out_deq_data.data(), out_deq_data.size()); - - for (int i = 0; i < out_int4_data.size(); i++) { - EXPECT_TRUE(check_close(out_int4_data[i], out_deq_data[i])); - } -} - -TEST(VulkanComputeGraphOpsTest, int4pack_mm_test) { - if (!context()->adapter_ptr()->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - - for (auto storage_type : {utils::kBuffer, utils::kTexture3D}) { - // Vector multiplication, single group per row - test_int4pack_mm({1, 32, 1}, 32, storage_type); - - // Vector multiplication, multiple groups per row - test_int4pack_mm({1, 256, 1}, 64, storage_type); - - // Square matrices, single group per row - test_int4pack_mm({32, 32, 32}, 32, storage_type); - - // Irregular matrices, single group per row - test_int4pack_mm({37, 32, 19}, 32, storage_type); - - // Irregular matrices, multiple groups per row - test_int4pack_mm({37, 256, 19}, 64, storage_type); - } -} - void test_transpose_view_mm( const int B, const int M, @@ -3355,7 +3222,7 @@ void test_to_copy() { } TEST(VulkanComputeGraphOpsTest, test_to_copy) { - if (context()->adapter_ptr()->has_16bit_storage()) { + if (context()->adapter_ptr()->supports_16bit_storage_buffers()) { test_to_copy(); } }
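For reference, the 4-bit packing and dequantization convention that the new q_4w_linear_texture3d shader and the dequantize_and_linear reference in linear_weight_int4_test.cpp both rely on can be exercised in isolation: the even-indexed element of each weight pair sits in the high nibble of the packed byte, the odd-indexed element in the low nibble, and the unsigned 4-bit values are centered by subtracting 8 before the group's scale and zero point are applied. The standalone sketch below illustrates that convention; the helper names pack_4bit_pair and dequant_4bit are illustrative and do not appear in the patch.

// Standalone sketch of the int4 packing/dequantization convention assumed by
// this patch (helper names are illustrative only, not part of the codebase).
#include <cstdint>
#include <cstdio>

// Pack two 4-bit values into one byte: the even-indexed element goes in the
// high nibble, the odd-indexed element in the low nibble.
uint8_t pack_4bit_pair(uint8_t first, uint8_t second) {
  return static_cast<uint8_t>(((first & 0x0F) << 4) | (second & 0x0F));
}

// Dequantize one unpacked 4-bit value: values are unsigned in [0, 15], so
// they are centered by subtracting 8 before applying the group's scale and
// zero point, matching (val - 8) * scale + zero in the shader and reference.
float dequant_4bit(uint8_t val, float scale, float zero) {
  return (static_cast<float>(val) - 8.0f) * scale + zero;
}

int main() {
  const float scale = 0.5f;
  const float zero = 1.0f;

  const uint8_t packed = pack_4bit_pair(/*first=*/3, /*second=*/12); // 0x3C

  const uint8_t first = (packed & 0xF0) >> 4; // 3
  const uint8_t second = packed & 0x0F;       // 12

  // Expected output: (3 - 8) * 0.5 + 1.0 = -1.5 and (12 - 8) * 0.5 + 1.0 = 3.0
  std::printf("%f %f\n",
              dequant_4bit(first, scale, zero),
              dequant_4bit(second, scale, zero));
  return 0;
}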