From de2aab69de54dc48dca6f577d215c8b077aeff87 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Mon, 31 Mar 2025 12:18:12 -0700 Subject: [PATCH 1/2] [ET-VK] Store weights transposed for int8 linear Pull Request resolved: https://github.com/pytorch/executorch/pull/9765 ## Context The weight tensor of a linear layer is usually stored in a transposed manner, such that when computing the matrix multiplication, the reduction traverses along the rows of the weight tensor as opposed to the columns. This results in a better memory access pattern for CPUs. However, for GPUs, I have found that "un-transposing" the weight tensors result in better performance. This is likely due to the fact since GPUs can compute multiple output elements in parallel, reading along the columns allows for coalescing memory loads among threads in a work group. ## Changes * Introduce the ability to transpose height and weight dims when transferring tensor data to the GPU. * Prepackthe weight tensor "un-transposed" for the int8 quantized linear operator ghstack-source-id: 275180033 @exported-using-ghexport Differential Revision: [D72066588](https://our.internmc.facebook.com/intern/diff/D72066588/) --- .../nchw_to_bitw8_image_nobitw8buffer.glsl | 21 ++++++++++-- .../graph/ops/glsl/nchw_to_buffer.glsl | 9 ++++- .../runtime/graph/ops/glsl/nchw_to_image.glsl | 21 ++++++++++-- .../runtime/graph/ops/glsl/q_8w_linear.glsl | 31 ++++++++--------- .../graph/ops/impl/QuantizedLinear.cpp | 4 +-- .../vulkan/runtime/graph/ops/impl/Staging.cpp | 34 +++++++++++++++++-- .../vulkan/runtime/graph/ops/impl/Staging.h | 12 +++++++ backends/vulkan/test/op_tests/cases.py | 4 ++- 8 files changed, 110 insertions(+), 26 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl index 25113887dca..327c3868847 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl @@ -27,6 +27,8 @@ ${layout_declare_ubo(B, "ivec4", "sizes")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; ${layout_declare_spec_const(C, "int", "t_layout", "DEFAULT_LAYOUT")} +${layout_declare_spec_const(C, "int", "transpose_hw", "0")} + const lowp ivec4 axis_map = unhash_axis_map(t_layout); const lowp int packed_dim = unhash_packed_dim(t_layout); @@ -41,8 +43,23 @@ int extend_sign(int x) { } ivec4 read_texel(ivec4 tidx) { + ivec4 tidx_to_use = tidx; + ivec4 sizes_to_use = sizes; + int packed_dim_to_use = packed_dim; + if (transpose_hw == 1) { + sizes_to_use.xy = sizes_to_use.yx; + tidx_to_use.xy = tidx.yx; + + if (packed_dim == 1) { + packed_dim_to_use = 0; + } + if (packed_dim == 0) { + packed_dim_to_use = 1; + } + } + const ivec4 buf_indices = tidx_to_nchwi( - tidx, sizes, packed_dim); + tidx_to_use, sizes_to_use, packed_dim_to_use); int shift = (1 << 8) - 1; ivec4 masks; @@ -70,7 +87,7 @@ ivec4 read_texel(ivec4 tidx) { void main() { const ivec3 lpos = ivec3(gl_GlobalInvocationID); - const ivec4 tidx = lpos_to_tidx(lpos, sizes, axis_map.w, packed_dim); + ivec4 tidx = lpos_to_tidx(lpos, sizes, axis_map.w, packed_dim); if (any(greaterThanEqual(tidx, sizes))) { return; diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl index bf498f34d5b..32235a9ad65 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl @@ -21,6 +21,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; // This constant is unused in this shader but is kept so that the signature is // consistent with nchw_to_image. ${layout_declare_spec_const(C, "int", "UNUSED_layout", "0")} +${layout_declare_spec_const(C, "int", "transpose_hw", "0")} void main() { int out_bufi = int(gl_GlobalInvocationID.x); @@ -29,7 +30,13 @@ void main() { } ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides); - const int in_nchwi = tidx_to_nchwi(out_tidx, out_sizes); + + ivec4 sizes = out_sizes; + if (transpose_hw == 1) { + sizes.xy = sizes.yx; + out_tidx.xy = out_tidx.yx; + } + const int in_nchwi = tidx_to_nchwi(out_tidx, sizes); t_out[out_bufi] = nchw_in[in_nchwi]; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl index 3d2a102dac7..2f55535c82c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl @@ -30,14 +30,31 @@ $if not FROM_STAGING: layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; ${layout_declare_spec_const(C, "int", "t_layout", "DEFAULT_LAYOUT")} +${layout_declare_spec_const(C, "int", "transpose_hw", "0")} + const lowp ivec4 axis_map = unhash_axis_map(t_layout); const lowp int packed_dim = unhash_packed_dim(t_layout); VEC4_T read_texel(ivec4 tidx) { + ivec4 tidx_to_use = tidx; + ivec4 sizes_to_use = sizes; + int packed_dim_to_use = packed_dim; + if (transpose_hw == 1) { + sizes_to_use.xy = sizes_to_use.yx; + tidx_to_use.xy = tidx.yx; + + if (packed_dim == 1) { + packed_dim_to_use = 0; + } + if (packed_dim == 0) { + packed_dim_to_use = 1; + } + } + $if FROM_STAGING: - const ivec4 buf_indices = tidx_to_nchwi(tidx, sizes, packed_dim); + const ivec4 buf_indices = tidx_to_nchwi(tidx_to_use, sizes_to_use, packed_dim_to_use); $else: - const ivec4 buf_indices = tidx_to_4bufi(tidx, buf_strides, packed_dim); + const ivec4 buf_indices = tidx_to_4bufi(tidx_to_use, buf_strides, packed_dim_to_use); VEC4_T texel = VEC4_T(0); if (tidx[packed_dim] < sizes[packed_dim]) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl index 56bffaee675..228e2e8f870 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl @@ -64,24 +64,21 @@ void main() { FLOAT_T outval = FLOAT_T(0.0); - // Initial mat1 tensor idx will be (0, out_tidx.y, out_tidx.z, 0) int mat1_offset = out_tidx.y * mat1_strides.y + out_tidx.z * qmat2_strides.z; - // Initial qmat2 tensor idx wil be (0, out_tidx.x, 0, 0); note that the qmat2 - // tensor is transposed - int qmat2_offset = out_tidx.x * qmat2_strides.y; + int qmat2_offset = out_tidx.x; // TODO(ssjia): optimize memory access pattern by traversing mat1 x in inner loop for (int i = 0; i < mat1_sizes.x; i++) { const FLOAT_T mat1_val = t_mat1[mat1_offset]; - const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale; + const FLOAT_T mat2_val = FLOAT_T(t_qmat2[qmat2_offset]); outval += mat1_val * mat2_val; mat1_offset++; - qmat2_offset++; + qmat2_offset += qmat2_strides.y; } - t_out[out_bufi] = outval; + t_out[out_bufi] = outval * scale; } #else // USING_TEXTURE @@ -97,25 +94,27 @@ void main() { return; } - const uint16_t qmat2_pos_y = out_pos.x * uint16_t(4); + const uint16_t qmat2_pos_x = out_pos.x; VEC4_T outtex = VEC4_T(0); const VEC4_T scales = load_texel(t_scales, u16vec3(out_pos.x, 0, 0)); + VEC4_T mat1_tex; + VEC4_T mat2_tex[4]; for ( uint16_t i = uint16_t(0), x = uint16_t(0); i < uint16_t(mat1_sizes.x); i += uint16_t(4), x++) { - const VEC4_T mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0)); - const VEC4_T sums = VEC4_T( - dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y, 0))), - dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(1), 0))), - dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(2), 0))), - dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(3), 0)))); - - outtex += sums; + mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0)); + + mat2_tex[0] = load_texel(t_qmat2, u16vec3(out_pos.x, i, 0)); + mat2_tex[1] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(1), 0)); + mat2_tex[2] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(2), 0)); + mat2_tex[3] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(3), 0)); + + outtex += mat1_tex.x * mat2_tex[0] + mat1_tex.y * mat2_tex[1] + mat1_tex.z * mat2_tex[2] + mat1_tex.w * mat2_tex[3]; } outtex *= scales; diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index 59684d73bd2..2011331ec38 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -48,7 +48,7 @@ void resize_q_8w_linear_node( vTensorPtr qmat2 = graph->get_tensor(args[1].refs[1]); const int out_cols = utils::val_at(-2, mat1->sizes()); - const int out_rows = utils::val_at(-2, qmat2->sizes()); + const int out_rows = utils::val_at(-1, qmat2->sizes()); std::vector new_out_sizes(3); if (mat1->sizes().size() == 2) { @@ -86,7 +86,7 @@ void add_q_8w_linear_node( // Ensure out is packed correctly out_W_packed = out_tmp; } - ValueRef q_mat2 = prepack_standard( + ValueRef q_mat2 = prepack_standard_hw_transposed( graph, q_mat2_data, graph.storage_type_of(out), utils::kWidthPacked); ValueRef scales = prepack_standard( graph, scales_data, graph.storage_type_of(out), utils::kWidthPacked); diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 959d3974b73..f59d1cd65d9 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -113,7 +113,8 @@ void add_tensor_to_staging_node( void add_prepack_standard_node( ComputeGraph& graph, const ValueRef tensor_data, - const ValueRef tensor) { + const ValueRef tensor, + const bool transpose_hw = false) { vkapi::ShaderInfo shader = get_nchw_to_tensor_shader( *graph.get_tensor(tensor), graph.int8_buffers_enabled()); @@ -127,6 +128,8 @@ void add_prepack_standard_node( ubos.append({graph.sizes_ubo(tensor)}); } + int transpose_hw_spec = transpose_hw ? 1 : 0; + graph.prepack_nodes().emplace_back(new PrepackNode( graph, shader, @@ -138,7 +141,7 @@ void add_prepack_standard_node( // Parameter Buffers ubos, // Specialization Constants - {graph.hashed_layout_of(tensor)})); + {graph.hashed_layout_of(tensor), transpose_hw_spec})); } ValueRef prepack_standard( @@ -158,6 +161,33 @@ ValueRef prepack_standard( return tensor; } +ValueRef prepack_standard_hw_transposed( + ComputeGraph& graph, + const ValueRef tensor_data, + const utils::StorageType storage_type, + const utils::GPUMemoryLayout layout, + const bool passthrough, + const utils::AxisMapLayout axis_map_layout) { + (void)passthrough; + + VK_CHECK_COND(graph.val_is_tref(tensor_data)); + std::vector new_out_sizes = graph.sizes_of(tensor_data); + const int w_dim = new_out_sizes.size() - 1; + const int h_dim = new_out_sizes.size() - 2; + const int64_t tmp = new_out_sizes.at(w_dim); + new_out_sizes.at(w_dim) = new_out_sizes.at(h_dim); + new_out_sizes.at(h_dim) = tmp; + ValueRef tensor = graph.add_tensor( + new_out_sizes, + graph.dtype_of(tensor_data), + storage_type, + layout, + -1, + axis_map_layout); + add_prepack_standard_node(graph, tensor_data, tensor, true); + return tensor; +} + ValueRef prepack_standard_like( ComputeGraph& graph, const ValueRef tensor_data, diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.h b/backends/vulkan/runtime/graph/ops/impl/Staging.h index bc501d5d053..1b6f245bd34 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.h +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.h @@ -51,6 +51,18 @@ ValueRef prepack_standard( const bool passthrough = false, const utils::AxisMapLayout axis_map_layout = utils::kDefaultAxisMap); +/* + * Same as prepack_standard, but transpose the height and width dimensions of + * the tensor while packing. + */ +ValueRef prepack_standard_hw_transposed( + ComputeGraph& graph, + const ValueRef tensor_data, + const utils::StorageType storage_type, + const utils::GPUMemoryLayout layout, + const bool passthrough = false, + const utils::AxisMapLayout axis_map_layout = utils::kDefaultAxisMap); + /* * Equivalent to `prepack_standard()` function, except the `storage_type` and * `memory_layout` are set to match `to_copy`, which must be a `Tensor`. diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 41d8edf1f25..329d62c2285 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -157,12 +157,14 @@ def get_weight_int8pack_mm_inputs(): [6, 1024, 256], [6, 256, 256], [6, 256, 512], + [4, 768, 4096], + [1024, 1024, 1024], ] inputs_list = [((M, K), (N, K), (N)) for M, K, N in MKN_list] test_suite = VkTestSuite(inputs_list) - test_suite.dtypes = ["at::kFloat", "at::kHalf"] + test_suite.dtypes = ["at::kFloat"] test_suite.layouts = ["utils::kWidthPacked"] test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"] test_suite.prepacked_args = ["mat2", "scales"] From 1907ae282b5b8e89221ea1c01aa4bae2235baf4e Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Mon, 31 Mar 2025 12:18:13 -0700 Subject: [PATCH 2/2] [ET-VK] Efficient tiled int8 matmul Pull Request resolved: https://github.com/pytorch/executorch/pull/9766 ## Context Introduce a optimized tiled implementation for computing the weight int8-quantized linear operation. This implementation takes advantage of the following principles to squeeze out performance: * Compute an output tile with each thread, rather than a single output element. This allows for better memory re-use of loaded input tensor data. * Compute the output tile by iteratively loading tiles of the input matrices, caching them in registers, and then performing the `fma` accumulations to obtain a partial output. By splitting the data loading and computation into distinct steps, the GPU is able to perform latency hiding more effectively, i.e. switching to a warp that needs to perform compute when the current warp is waiting on data load * Use a work group size of `{N, 1, 1}`. This makes it so that all the threads in a work group load the same row of the input matrx, and consecutive columns of the weight matrix. This way, the row of the input is kept hot in the cache, and accesses to the weight matrix can be coalesced due to the previous diff un-transposing the weight matrix. Differential Revision: [D72066587](https://our.internmc.facebook.com/intern/diff/D72066587/) ghstack-source-id: 275180032 --- .../graph/ops/glsl/q_8w_linear_optimized.glsl | 212 ------------------ .../graph/ops/glsl/q_8w_linear_optimized.yaml | 35 --- .../graph/ops/glsl/q_8w_linear_tiled.glsl | 92 ++++++++ .../graph/ops/glsl/q_8w_linear_tiled.yaml | 18 ++ .../graph/ops/impl/QuantizedLinear.cpp | 137 +++++------ 5 files changed, 184 insertions(+), 310 deletions(-) delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl deleted file mode 100644 index b8d7622f94d..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} -#define FLOAT_T ${buffer_scalar_type(DTYPE)} - -${define_active_storage_type(STORAGE)} - -${define_required_extensions(DTYPE)} -$if STORAGE == "buffer": - ${define_required_extensions("int8")} - - -$if BATCH_MODE: - #define BATCH_MODE - -#define TILE_ROWS ${TILE_ROWS} -#define FOUR 4 - -// we avoid mat4 and vec4 usage here as they compile to much less efficient -// SPIR-V -struct FloatMatrix_2d { - float data[TILE_ROWS][FOUR]; -}; - -struct FloatMatrix_3d { - float data[TILE_ROWS][FOUR][FOUR]; -}; - -#ifdef BATCH_MODE - #define FloatMatrix FloatMatrix_3d -#else - #define FloatMatrix FloatMatrix_2d -#endif - -#include "indexing_utils.h" - -layout(std430) buffer; - -${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)} -${layout_declare_tensor(1, "r", "t_mat1", DTYPE, STORAGE)} -${layout_declare_tensor(2, "r", "t_qmat2", "int8", STORAGE)} -${layout_declare_tensor(3, "r", "t_scales", DTYPE, STORAGE)} - -$if STORAGE == "buffer": - ${layout_declare_ubo(4, "ivec4", "out_sizes")} - ${layout_declare_ubo(5, "ivec4", "out_strides")} - ${layout_declare_ubo(6, "int", "out_numel")} - ${layout_declare_ubo(7, "ivec4", "mat1_sizes")} - ${layout_declare_ubo(8, "ivec4", "mat1_strides")} - ${layout_declare_ubo(9, "ivec4", "qmat2_strides")} - ${layout_declare_ubo(10, "ivec4", "scales_strides")} -$else: - ${layout_declare_ubo(4, "ivec3", "out_limits")} - ${layout_declare_ubo(5, "ivec4", "mat1_sizes")} - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -// This header file must be defined after the layout descriptors have been -// declared because the functions in the header assume some variables have been -// declared as layout descriptors. - -#ifdef USING_BUFFER - -#ifndef FLOAT_T -#define FLOAT_T float -#endif - -FLOAT_T q_8w_linear(const ivec4 out_idx, const int K) { - const FLOAT_T scale = t_scales[out_idx.x]; - - FLOAT_T outval = FLOAT_T(0.0); - - // Initial mat1 tensor idx will be (0, out_idx.y, out_idx.z, 0) - int mat1_offset = out_idx.y * mat1_strides.y + out_idx.z * qmat2_strides.z; - // Initial qmat2 tensor idx wil be (0, out_idx.x, 0, 0); note that the qmat2 - // tensor is transposed - int qmat2_offset = out_idx.x * qmat2_strides.y; - - // TODO(ssjia): optimize memory access pattern by traversing K in inner loop - for (int i = 0; i < K; i++) { - const FLOAT_T mat1_val = t_mat1[mat1_offset]; - const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale; - - outval += mat1_val * mat2_val; - - mat1_offset++; - qmat2_offset++; - } - - return outval; -} - -void main() { - const int out_bufi = int(gl_GlobalInvocationID.x); - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, 0); - - t_out[out_bufi] = q_8w_linear(out_tidx, mat1_sizes.x); -} - -#else // USING_TEXTURE -FloatMatrix q_8w_linear_optimized(const ivec3 out_idx_tl) { - FloatMatrix results; - for (int i = 0; i < TILE_ROWS; i++) { - for (int j = 0; j < FOUR; j++) { -#ifdef BATCH_MODE - for (int k = 0; k < FOUR; k++) { - results.data[i][j][k] = 0.0f; - } -#else - results.data[i][j] = 0.0f; -#endif // BATCH_MODE - } - } - - VEC4_T im_mat1_partial_load[TILE_ROWS]; - VEC4_T im_mat2_partial_load[FOUR]; - -#ifdef BATCH_MODE - for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) { - if (out_idx_tl.z + batch_idx >= out_limits.z) { - break; - } -#endif - for (int k = 0; k < mat1_sizes.x; k++) { - for (int r = 0; r < TILE_ROWS; r++) { - ivec3 mat1_pos = ivec3(k, out_idx_tl.y * TILE_ROWS + r, 0); -#ifdef BATCH_MODE - mat1_pos[2] = out_idx_tl.z + batch_idx; -#endif - - im_mat1_partial_load[r] = texelFetch(t_mat1, mat1_pos, 0); - } - - for (int r = 0; r < FOUR; ++r) { - ivec3 qmat2_pos = ivec3(k, FOUR * out_idx_tl.x + r, 0); - - im_mat2_partial_load[r] = texelFetch(t_qmat2, qmat2_pos, 0); - } - - vec4 scales = texelFetch(t_scales, ivec3(out_idx_tl.x, 0, 0), 0); - - // perform partial dot products and add partial result to results - for (int out_row = 0; out_row < TILE_ROWS; out_row++) { - for (int out_col = 0; out_col < FOUR; out_col++) { -#ifdef BATCH_MODE - results.data[out_row][out_col][batch_idx] += -#else - results.data[out_row][out_col] += -#endif - dot(im_mat1_partial_load[out_row], - im_mat2_partial_load[out_col] * scales[out_col]); - } - } - } -#ifdef BATCH_MODE - } -#endif - return results; -} - -void main() { - const ivec3 out_idx = ivec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(out_idx, out_limits))) { - return; - } - - FloatMatrix results = q_8w_linear_optimized(out_idx); - - ivec3 out_pos = ivec3( - out_idx.x, - out_idx.y * TILE_ROWS, -#ifdef BATCH_MODE - out_idx.z * 4 -#else - out_idx.z -#endif -); - - for (int idx_c = 0; idx_c < TILE_ROWS; idx_c++, out_pos[1]++) { - out_pos.x = out_idx.x; - $if BATCH_MODE: - for (int idx_r = 0; idx_r < FOUR; idx_r++, out_pos[0]++) { - write_texel(t_out, out_pos, VEC4_T( - results.data[idx_c][idx_r][0], - results.data[idx_c][idx_r][1], - results.data[idx_c][idx_r][2], - results.data[idx_c][idx_r][3])); - } - $else: - write_texel(t_out, out_pos, VEC4_T( - results.data[idx_c][0], - results.data[idx_c][1], - results.data[idx_c][2], - results.data[idx_c][3])); - } -} - -#endif diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml deleted file mode 100644 index 52bebf90125..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -q_8w_linear_optimized: - parameter_names_with_default_values: - DTYPE: float - STORAGE: texture3d - MAT1_PACKING: W_packed - MAT2_PACKING: W_packed - BATCH_MODE: false - TILE_ROWS: 4 - generate_variant_forall: - TILE_ROWS: - - VALUE: 4 - SUFFIX: tile_row_4 - - VALUE: 2 - SUFFIX: tile_row_2 - DTYPE: - - VALUE: float - - VALUE: half - STORAGE: - - VALUE: texture3d - - VALUE: buffer - shader_variants: - - NAME: q_8w_linear_optimized_W_packed_W_packed - - NAME: q_8w_linear_optimized_W_packed_H_packed - MAT2_PACKING: H_packed - - NAME: batch_q_8w_linear_optimized_W_packed_W_packed - BATCH_MODE: true - - NAME: batch_q_8w_linear_optimized_W_packed_H_packed - MAT2_PACKING: H_packed - BATCH_MODE: true diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl new file mode 100644 index 00000000000..c3bd9f41af9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl @@ -0,0 +1,92 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} +#define VEC4_T ${buffer_gvec_type(DTYPE, 4)} + +#define TILE_ROWS ${TILE_ROWS} + +${define_required_extensions(DTYPE)} + +$if STORAGE == "buffer": + ${define_required_extensions("int8")} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight", "int8", STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_scales", DTYPE, STORAGE, is_scalar_array=False)} + + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; + ivec4 in_sizes; + ivec4 weight_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS; + const uint out_col = gl_GlobalInvocationID.x << 2; + + if (out_col >= out_sizes.x || out_row >= out_sizes.y) { + return; + } + + VEC4_T a[TILE_ROWS]; + VEC4_T b[4]; + VEC4_T c[TILE_ROWS]; + + $if STORAGE == "buffer": + const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]); + $else: + const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec3(out_col >> 2, 0, 0), 0)); + + [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { + c[i] = VEC4_T(0.0); + } + + for (int pos = 0; pos < in_sizes.x; pos += 4) { + // Preload weight tensor + [[unroll]] for (int i = 0; i < 4; i++) { + $if STORAGE == "buffer": + b[i] = t_weight[((pos + i) * B_sizes.x + out_col) >> 2]; + $else: + b[i] = VEC4_T(texelFetch(t_weight, ivec3(out_col >> 2, pos + i, 0), 0)); + } + + // Preload input tensor + [[unroll]] for (int i = 0; i < TILE_ROWS; i++) { + $if STORAGE == "buffer": + a[i] = t_in[((out_row + i) * in_sizes.x + (pos)) >> 2]; + $else: + a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0)); + } + + // Compute partial output + [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { + c[i] += a[i].x * b[0] + a[i].y * b[1] + a[i].z * b[2] + a[i].w * b[3]; + } + } + + // Store output tensor + [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { + $if STORAGE == "buffer": + t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales; + $else: + imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml new file mode 100644 index 00000000000..b01af47e179 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q_8w_linear_tiled: + parameter_names_with_default_values: + DTYPE: float + STORAGE: texture3d + TILE_ROWS: 4 + shader_variants: + - NAME: q_8w_linear_tiled_o4x4_texture3d_float + STORAGE: texture3d + TILE_ROWS: 4 + - NAME: q_8w_linear_tiled_o4x6_texture3d_float + STORAGE: texture3d + TILE_ROWS: 6 diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index 2011331ec38..f4f5c853ddd 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -160,100 +160,111 @@ void add_q_8w_linear_node( } } -void add_q_8w_linear_optimized_node( +void add_q_8w_linear_tiled_node( ComputeGraph& graph, const ValueRef mat1, const ValueRef q_mat2_data, const ValueRef scales_data, const ValueRef out) { - auto viewFn = VK_GET_OP_FN("aten.view_copy.default"); - ValueRef mat1_W_packed = mat1; - ValueRef out_W_packed = out; - if (!graph.is_buffer_storage(out) && - graph.packed_dim_of(mat1) != WHCN::kWidthDim) { - // Ensure mat1 is width packed - mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked); - viewFn(graph, {mat1, graph.add_none(), mat1_W_packed}); - // Ensure out is packed correctly - out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked); - } - utils::StorageType stype = graph.storage_type_of(out); - ValueRef q_mat2 = - prepack_standard(graph, q_mat2_data, stype, utils::kWidthPacked); + ValueRef q_mat2 = prepack_standard_hw_transposed( + graph, q_mat2_data, stype, utils::kWidthPacked); ValueRef scales = prepack_standard(graph, scales_data, stype, utils::kWidthPacked); - std::string kernel_name = "q_8w_linear_optimized"; + std::string kernel_name = "q_8w_linear_tiled"; kernel_name.reserve(kShaderNameReserve); - add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed)); - add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2)); - std::vector mat1_sizes = graph.sizes_of(mat1_W_packed); - const int mat1_dims = mat1_sizes.size(); - if (mat1_dims == 3) { - kernel_name = "batch_" + kernel_name; - } - if (mat1_sizes.at(mat1_dims - 2) < 8) { - kernel_name += "_tile_row_2"; + std::vector mat1_sizes = graph.sizes_of(mat1); + const int64_t M = utils::val_at(-2, mat1_sizes); + int out_tile_nrows = 4; + if (M % 6 == 0) { + kernel_name += "_o4x6"; + out_tile_nrows = 6; } else { - kernel_name += "_tile_row_4"; + kernel_name += "_o4x4"; + out_tile_nrows = 4; } - add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed)); - add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); - vkapi::ParamsBindList ubos({}); + utils::uvec3 global_wg_size = graph.logical_limits_of(out); + global_wg_size[1] = global_wg_size[1] / out_tile_nrows; - utils::uvec3 global_size; - utils::uvec3 local_size; - if (graph.is_buffer_storage(out)) { - ubos.append( - {graph.sizes_ubo(out_W_packed), - graph.strides_ubo(out_W_packed), - graph.numel_ubo(out_W_packed), - graph.sizes_ubo(mat1_W_packed), - graph.strides_ubo(mat1_W_packed), - graph.strides_ubo(q_mat2), - graph.strides_ubo(scales)}); - global_size = graph.create_global_wg_size(out_W_packed); - local_size = graph.create_local_wg_size(out_W_packed); - } else { - global_size = graph.logical_limits_of(out_W_packed); - ubos.append( - {graph.logical_limits_ubo(out_W_packed), - graph.sizes_ubo(mat1_W_packed)}); - if (mat1_sizes.at(mat1_dims - 2) < 8) { - global_size = global_size = utils::divup_vec(global_size, {1, 2, 1}); - } else { - global_size = utils::divup_vec(global_size, {1, 4, 1}); - } - local_size = {16, 3, 1}; - } + utils::uvec3 local_wg_size{64, 1, 1}; graph.execute_nodes().emplace_back(new DispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, + global_wg_size, + local_wg_size, // Inputs and Outputs - {{out_W_packed, vkapi::MemoryAccessType::WRITE}, - {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}}, + {{out, vkapi::kWrite}, {{mat1, q_mat2, scales}, vkapi::kRead}}, // Shader params buffers - ubos, + {}, // Specialization Constants - {}, // spec_vars, + {}, // Resizing Logic - resize_q_8w_linear_node)); + resize_q_8w_linear_node, + {}, + // Push Constants + {{graph.sizes_pc_of(out), graph.sizes_pc_of(mat1)}})); +} - if (!graph.is_buffer_storage(out)) { - viewFn(graph, {out_W_packed, graph.add_none(), out}); +bool can_use_tiled_impl( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef q_mat2_data, + const ValueRef scales_data, + const ValueRef out) { + (void)q_mat2_data; + (void)scales_data; + + // Check if mat1 is not a 3D tensor or that batches = 1 + // TODO(ssjia): Add support for batches in the tiled impl + if (graph.dim_of(mat1) == 3 && graph.size_at(-1, mat1) != 1) { + return false; + } + // Check that K is a multiple of 4 + if (graph.size_at(-1, mat1) % 4 != 0) { + return false; } + // Check that M is a multiple of 4 or 6 + if (graph.size_at(-2, mat1) % 4 != 0 && + graph.size_at(-2, mat1) % 6 != 0) { + return false; + } + // Check that the storage type is texture + // TODO(ssjia): Add support for buffer storage in the tiled impl + if (graph.storage_type_of(out) != utils::kTexture3D) { + return false; + } + // Check that the packed dim is the width dim + if (graph.packed_dim_of(mat1) != WHCN::kWidthDim) { + return false; + } + // Check that no special axis mapping is used for the input + // TODO(ssjia): Add support for non-standard axis mapping in the tiled impl + if (!graph.has_standard_axis_map(mat1)) { + return false; + } + // Check that no special axis mapping is used for the output + // TODO(ssjia): Add support for non-standard axis mapping in the tiled impl + if (!graph.has_standard_axis_map(out)) { + return false; + } + + return true; } void weight_int8pack_mm( ComputeGraph& graph, const std::vector& args) { check_q_8w_linear_args(graph, args[0], args[1], args[2], args[3]); + if (can_use_tiled_impl(graph, args[0], args[1], args[2], args[3])) { + return add_q_8w_linear_tiled_node( + graph, args[0], args[1], args[2], args[3]); + } return add_q_8w_linear_node(graph, args[0], args[1], args[2], args[3]); }