diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled_smem.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled_smem.glsl
new file mode 100644
index 00000000000..b5d4ced5bf4
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled_smem.glsl
@@ -0,0 +1,177 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+#define VEC4_T ${texel_load_type(DTYPE, OUTPUT_STORAGE)}
+#define T int
+
+$if OUTPUT_STORAGE == "buffer":
+  #define OUTPUT_BUFFER
+$if PACKED_INT8_INPUT_STORAGE == "buffer":
+  #define PACKED_INT8_INPUT_BUFFER
+$if WEIGHT_STORAGE == "buffer":
+  #define WEIGHT_BUFFER
+
+#define TILE_M4 ${TILE_M4}
+#define TILE_K4 ${TILE_K4}
+#define TILE_N4 ${TILE_N4}
+
+#define TILE_M ${TILE_M4 * 4}
+#define TILE_K ${TILE_K4 * 4}
+#define TILE_N ${TILE_N4 * 4}
+
+#define M_TILES_PER_WG 8
+#define N_TILES_PER_WG 8
+#define K_TILES_PER_WG 1
+
+${define_required_extensions(DTYPE)}
+
+layout(std430) buffer;
+
+${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", PACKED_INT8_INPUT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)}
+
+${layout_declare_spec_const(C, "int", "apply_bias", "0")}
+
+${layout_declare_ubo(B, "ivec4", "output_sizes")}
+${layout_declare_ubo(B, "ivec4", "input_sizes")}
+
+layout(push_constant) uniform restrict Block {
+  float input_scale;
+  int input_zp;
+};
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+#include "linear_int8_input_tile_load.glslh"
+#include "linear_int8_weight_tile_load.glslh"
+#include "linear_fp_output_tile_int8_int8_compute.glslh"
+#include "linear_fp_output_tile_store.glslh"
+#include "linear_fp_weight_scales_load.glslh"
+#include "linear_int_weight_sums_load.glslh"
+#include "linear_fp_bias_load.glslh"
+
+shared Int32Accum partial_sums[M_TILES_PER_WG][N_TILES_PER_WG][K_TILES_PER_WG];
+
+void add_into_first(inout Int32Accum first, const Int32Accum second) {
+  [[unroll]] for (int m = 0; m < TILE_M; ++m) {
+    [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
+      first.data[m][n4] += second.data[m][n4];
+    }
+  }
+}
+
+void main() {
+  const int m_tile_lid = int(gl_LocalInvocationID.x);
+  const int n_tile_lid = int(gl_LocalInvocationID.y);
+  const int k_tile_lid = int(gl_LocalInvocationID.z);
+
+  // Each invocation writes out a TILE_N wide x TILE_M high tile of outputs
+  const int out_tile_x = int(gl_GlobalInvocationID.x);
+  const int out_tile_y = int(gl_GlobalInvocationID.y);
+
+  const int n = out_tile_x * TILE_N;
+  const int m = out_tile_y * TILE_M;
+
+  const int n4 = div_4(n);
+  const int m4 = div_4(m);
+
+  // All invocations must reach the barriers below in uniform control flow,
+  // so out-of-range invocations are masked with a flag instead of returning
+  // early.
+  const bool in_bounds = n < output_sizes.x && m < output_sizes.y;
+
+  const int M = output_sizes.y;
+  const int K4 = div_up_4(input_sizes.x);
+  const int N4 = div_up_4(output_sizes.x);
+
+  Int32Accum out_accum;
+  initialize(out_accum);
+
+  Int8InputTile int8_in_tile;
+  Int8WeightTile int8_weight_tile;
+
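+  // Each invocation accumulates a partial output tile over a strided subset
+  // of the K dimension; the partial tiles are then combined across
+  // k_tile_lid by the shared memory tree reduction below.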
+  const int k4_per_iter = TILE_K4 * K_TILES_PER_WG;
+
+  // Out-of-range invocations are masked by in_bounds; no other checks are
+  // needed since packed input and weight are structured in units of 4x4 blocks.
+  for (int k4 = k_tile_lid; in_bounds && k4 < K4; k4 += k4_per_iter) {
+    load_int8_input_tile(int8_in_tile, k4, m4, K4);
+    load_int8_weight_tile(int8_weight_tile, n4, k4, N4);
+
+    int_accumulate_with_int8_weight(out_accum, int8_in_tile, int8_weight_tile);
+  }
+
+  partial_sums[m_tile_lid][n_tile_lid][k_tile_lid] = out_accum;
+
+  memoryBarrierShared();
+  barrier();
+
+  // Tree reduction to combine the partial sums along the K axis.
+  for (int i = K_TILES_PER_WG / 2; i > 0; i /= 2) {
+    if (k_tile_lid < i) {
+      add_into_first(
+          partial_sums[m_tile_lid][n_tile_lid][k_tile_lid],
+          partial_sums[m_tile_lid][n_tile_lid][k_tile_lid + i]);
+    }
+    memoryBarrierShared();
+    barrier();
+  }
+
+  if (k_tile_lid > 0 || !in_bounds) {
+    return;
+  }
+
+  out_accum = partial_sums[m_tile_lid][n_tile_lid][0];
+
+  FPPerOutChannelParams weight_scales_tile;
+  load_weight_scales_tile(weight_scales_tile, n4);
+
+  IntPerOutChannelParams weight_sums_tile;
+  load_weight_sums_tile(weight_sums_tile, n4);
+
+  FPOutTile out_tile;
+  initialize(out_tile);
+
+  if (apply_bias > 0) {
+    FPPerOutChannelParams bias_tile;
+    load_bias_tile(bias_tile, n4);
+
+    accumulate_out_tile_with_int_accum(
+        out_tile,
+        out_accum,
+        input_scale,
+        input_zp,
+        weight_sums_tile,
+        weight_scales_tile,
+        bias_tile);
+  }
+  else {
+    accumulate_out_tile_with_int_accum(
+        out_tile,
+        out_accum,
+        input_scale,
+        input_zp,
+        weight_sums_tile,
+        weight_scales_tile);
+  }
+
+  if (M - m >= TILE_M) {
+    write_output_tile_no_checks(out_tile, n4, m, N4, M);
+  } else {
+    write_output_tile_with_checks(out_tile, n4, m, N4, M);
+  }
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled_smem.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled_smem.yaml
new file mode 100644
index 00000000000..0b66b0020f7
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled_smem.yaml
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
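+#
+# With the defaults below (TILE_N4: 2, TILE_M4: 1), each invocation produces
+# an 8 wide x 4 high output tile. The storage combos must match the storage
+# type suffixes added to the kernel name in QuantizedLinear.cpp.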
+
+linear_q8ta_q8csw_tiled_smem:
+  parameter_names_with_default_values:
+    DTYPE: float
+    OUTPUT_STORAGE: texture3d
+    PACKED_INT8_INPUT_STORAGE: buffer
+    WEIGHT_STORAGE: texture2d
+    TILE_M4: 1
+    TILE_N4: 2
+    TILE_K4: 1
+  generate_variant_forall:
+    combination:
+      parameter_names: [OUTPUT_STORAGE, PACKED_INT8_INPUT_STORAGE, WEIGHT_STORAGE]
+      combos:
+        - parameter_values: [texture3d, buffer, texture2d]
+        - parameter_values: [buffer, buffer, texture2d]
+    DTYPE:
+      - VALUE: float
+  shader_variants:
+    - NAME: linear_q8ta_q8csw_tiled_smem
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
index 97566038501..8f5550e903b 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
@@ -100,8 +100,9 @@ utils::uvec3 quantized_linear_local_wg_size(
   if (use_coop_algorithm) {
     return {1, 1, 64};
   } else {
-    return pick_hw_square_wg_size(
-        graph, shader, global_workgroup_size, args, resize_args);
+    // The smem shader sizes its shared memory array for a fixed 8x8x1
+    // workgroup (M_TILES_PER_WG x N_TILES_PER_WG x K_TILES_PER_WG).
+    return {8, 8, 1};
   }
 }
 
@@ -595,7 +596,7 @@ DynamicDispatchNode make_linear_qa_qw_node(
   int32_t zp = graph.extract_scalar<int32_t>(input_zp_data);
 
   // Get shader for quantized linear
-  std::string kernel_name = "linear_q8ta_q8csw_tiled";
+  std::string kernel_name = "linear_q8ta_q8csw_tiled_smem";
   add_storage_type_suffix(kernel_name, graph.storage_type_of(output));
   add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_int_input));
   add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight));