From a5d8a02db8b6e81b731fa563cb23d3139da2e1be Mon Sep 17 00:00:00 2001 From: ssjia Date: Fri, 29 Aug 2025 09:52:33 -0700 Subject: [PATCH] [ET-VK] Quantized Int8 Linear Title says it all! This PR adds implementations for int8 linear layers. Convolution is implemented in a later step, computing convolution as matrix multiplication via the im2col procedure. For both linear and convolution, two versions are implemented: 1. `q8ta_q8csw` variant which quantized the input tensor and then performs integer accumulation via the int8 dot product extension 2. `q8csw` variant which dequantized the weight tensor in-shader and performs floating point accumulation. The second one is needed to provide an alternative path for executing quantized models if the target GPU does not support int8 dot product extension. These new ops are tested via the custom op testing + benchmarking framework introduced in the previous diff. Differential Revision: [D81323424](https://our.internmc.facebook.com/intern/diff/D81323424/) [ghstack-poisoned] --- .github/workflows/pull.yml | 2 + .../runtime/graph/ops/glsl/common.glslh | 40 ++ .../graph/ops/glsl/linear_bias_load.glslh | 30 + .../graph/ops/glsl/linear_common.glslh | 41 ++ .../graph/ops/glsl/linear_fp_input_tile.glslh | 43 ++ .../ops/glsl/linear_fp_input_tile_load.glslh | 91 +++ .../ops/glsl/linear_fp_output_tile.glslh | 60 ++ .../linear_fp_output_tile_fp_compute.glslh | 96 +++ .../linear_fp_output_tile_int8_compute.glslh | 124 ++++ .../glsl/linear_fp_output_tile_store.glslh | 114 ++++ .../ops/glsl/linear_fp_weight_tile.glslh | 100 ++++ .../ops/glsl/linear_int8_input_block.glslh | 77 +++ .../ops/glsl/linear_int8_input_tile.glslh | 93 +++ .../glsl/linear_int8_input_tile_load.glslh | 75 +++ .../ops/glsl/linear_int8_weight_block.glslh | 140 +++++ .../ops/glsl/linear_int8_weight_tile.glslh | 45 ++ .../glsl/linear_int8_weight_tile_load.glslh | 75 +++ .../graph/ops/glsl/linear_q8csw_tiled.glsl | 117 ++++ .../graph/ops/glsl/linear_q8csw_tiled.yaml | 30 + .../ops/glsl/linear_q8ta_q8csw_tiled.glsl | 117 ++++ .../ops/glsl/linear_q8ta_q8csw_tiled.yaml | 30 + .../graph/ops/glsl/linear_scales_load.glslh | 30 + .../ops/glsl/linear_weight_sums_load.glslh | 30 + .../graph/ops/glsl/pack_q8_linear_weight.glsl | 62 ++ .../graph/ops/glsl/pack_q8_linear_weight.yaml | 14 + .../glsl/quantize_and_pack_linear_input.glsl | 79 +++ .../glsl/quantize_and_pack_linear_input.yaml | 24 + .../graph/ops/impl/QuantizedLinear.cpp | 548 ++++++++++++++++++ .../runtime/graph/ops/impl/QuantizedLinear.h | 35 ++ .../vulkan/test/custom_ops/CMakeLists.txt | 1 + .../test/custom_ops/quantized_linear.cpp | 352 +++++++++++ backends/vulkan/test/custom_ops/targets.bzl | 1 + 32 files changed, 2716 insertions(+) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/common.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_bias_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_compute.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh create mode 100644 
backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_scales_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_weight_sums_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp create mode 100644 backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.h create mode 100644 backends/vulkan/test/custom_ops/quantized_linear.cpp diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index fead532acf2..330f9521f60 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -928,7 +928,9 @@ jobs: CMAKE_ARGS="-DEXECUTORCH_BUILD_VULKAN=ON" \ .ci/scripts/setup-linux.sh --build-tool "cmake" + # Custom operator tests PYTHON_EXECUTABLE=python bash backends/vulkan/test/custom_ops/build_and_run.sh add + ./cmake-out/backends/vulkan/test/custom_ops/quantized_linear nxp-build-test: name: nxp-build-test diff --git a/backends/vulkan/runtime/graph/ops/glsl/common.glslh b/backends/vulkan/runtime/graph/ops/glsl/common.glslh new file mode 100644 index 00000000000..c96392792b2 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/common.glslh @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifndef COMMON_GLSLH +#define COMMON_GLSLH + +#define align_up_4(x) ((x + 3) & -4) + +#define div_up_4(x) (((x) + 3) >> 2) + +#define mul_4(x) ((x) << 2) +#define div_4(x) ((x) >> 2) + +#define mod_4(x) ((x) & 3) + +struct TensorIndex4D { + ivec4 data; +}; + +#ifdef DEBUG_MODE + +#extension GL_EXT_debug_printf : require + +void printTensorIndex4D(const TensorIndex4D index) { + debugPrintfEXT( + "tensor_idx: %d, %d, %d, %d\\n", + index.data.x, + index.data.y, + index.data.z, + index.data.w); +} + +#endif // DEBUG_MODE + +#endif // COMMON_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_bias_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_bias_load.glslh new file mode 100644 index 00000000000..346ed2b0a87 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_bias_load.glslh @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_BIAS_LOAD_GLSLH +#define LINEAR_BIAS_LOAD_GLSLH + +#include "linear_common.glslh" + +VEC4_T load_bias_x4(const uint n4) { + return t_bias[n4]; +} + +void load_bias_tile(out FPPerOutChannelParams bias, const uint n4_start) { +#if TILE_N4 == 1 + bias.data[0] = load_bias_x4(n4_start); + +#else + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + bias.data[n4] = load_bias_x4[n4_start + n4]; + } + +#endif +} + +#endif // LINEAR_BIAS_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh new file mode 100644 index 00000000000..e1717bc5e18 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh @@ -0,0 +1,41 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines common functions and structs to be used across matrix multiplication + * operators. + */ + +#ifndef LINEAR_COMMON_GLSLH +#define LINEAR_COMMON_GLSLH + +#include "common.glslh" + +// Represents floating point parameter tensors where each element is associated +// with an output channel, such as weight scales, biases, etc. +struct FPPerOutChannelParams { + VEC4_T data[TILE_N4]; +}; + +#ifdef DEBUG_MODE + +void printFPPerOutChannelParams(const FPPerOutChannelParams params) { + debugPrintfEXT("per_out_channel_params: \\n"); + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + " %f, %f, %f, %f, \\n", + params.data[n4].x, + params.data[n4].y, + params.data[n4].z, + params.data[n4].w); + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_COMMON_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh new file mode 100644 index 00000000000..492dab8239d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh @@ -0,0 +1,43 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifndef LINEAR_FP_INPUT_TILE_GLSLH +#define LINEAR_FP_INPUT_TILE_GLSLH + +/* + * Defines the FPInputTile struct, which is used to represent a tile of the + * input matrix of a matrix multiplication operation. + * + * Settings: + * - TILE_M: number of rows in the tile + * - TILE_K4: number of (groups of 4) columns in the tile + */ + +struct FPInputTile { + VEC4_T data[TILE_M][TILE_K4]; +}; + +#ifdef DEBUG_MODE + +void printFPInputTile(const FPInputTile in_tile) { + debugPrintfEXT("input_tile: \\n"); + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + debugPrintfEXT( + " %f, %f, %f, %f, \\n", + in_tile.data[m][k4].x, + in_tile.data[m][k4].y, + in_tile.data[m][k4].z, + in_tile.data[m][k4].w); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_FP_INPUT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh new file mode 100644 index 00000000000..a98f07b042a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh @@ -0,0 +1,91 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to load a FPInputTile from input buffer/texture. + * + * Requires: + * - t_input to be declared in the shader layout (input buffer/texture) + * + * Settings: + * - INPUT_BUFFER to indicate input resource is a buffer, otherwise texture is + * assumed. + */ + +#ifndef LINEAR_FP_INPUT_TILE_LOAD_GLSLH +#define LINEAR_FP_INPUT_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_input_tile.glslh" + +#ifdef INPUT_BUFFER + +VEC4_T load_input_x4(const uint k4, const uint m, const uint ntexels_k) { + return t_input[(m * ntexels_k) + k4]; +} + +#else + +VEC4_T load_input_x4(const uint k4, const uint m, const uint ntexels_k) { + return texelFetch(t_input, ivec3(k4, m, 0), 0); +} + +#endif // INPUT_BUFFER + +// To be used if (M - m_start >= TILE_M) || (K4 - k4_start >= TILE_K4) +void load_input_tile_no_checks( + out FPInputTile in_tile, + const uint k4_start, + const uint m_start, + const uint K4, + const uint M) { +#if TILE_K4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + in_tile.data[m][0] = load_input_x4(k4_start, m_start + m, K4); + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + in_tile.data[m][k4] = load_input_x4(k4_start + k4, m_start + m, K4); + } + } +#endif +} + +// To be used if near tensor boundaries +void load_input_tile_with_checks( + out FPInputTile in_tile, + const uint k4_start, + const uint m_start, + const uint K4, + const uint M) { +#if TILE_K4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + if (m_start + m < M) { + in_tile.data[m][0] = load_input_x4(k4_start, m_start + m, K4); + } else { + in_tile.data[m][0] = VEC4_T(0.0); + } + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + if (m_start + m < M && k4_start + k4 < K4) { + in_tile.data[m][k4] = load_input_x4(k4_start + k4, m_start + m, K4); + } else { + in_tile.data[m][k4] = VEC4_T(0.0); + } + } + } +#endif +} + +#endif // LINEAR_FP_INPUT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh 
b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh new file mode 100644 index 00000000000..c4571315bdd --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh @@ -0,0 +1,60 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines the FPOutTile struct, which is used to represent a tile of the output + * matrix of a matrix multiplication operation. + * + * Settings: + * - TILE_M: number of rows in the output tile + * - TILE_N4: number of (groups of 4) columns in the output tile + */ + +#ifndef LINEAR_FP_OUTPUT_TILE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +struct FPOutTile { + VEC4_T data[TILE_M][TILE_N4]; +}; + +void initialize(out FPOutTile out_tile) { +#if TILE_M > 1 && TILE_N4 == 1 + [[unroll]] for (int y = 0; y < TILE_M; ++y) { + out_tile.data[y][0] = VEC4_T(0); + } + +#else + [[unroll]] for (int y = 0; y < TILE_M; ++y) { + [[unroll]] for (int x4 = 0; x4 < TILE_K4; ++x4) { + out_tile.data[y][x4] = VEC4_T(0); + } + } +#endif +} + +#ifdef DEBUG_MODE + +void printFPOutputTile(const FPOutTile tile) { + debugPrintfEXT("output_tile: \\n"); + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + " %f, %f, %f, %f, \\n", + tile.data[m][n4].x, + tile.data[m][n4].y, + tile.data[m][n4].z, + tile.data[m][n4].w); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_FP_OUTPUT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh new file mode 100644 index 00000000000..470db8b529a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh @@ -0,0 +1,96 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to compute a FPOutTile using fp input and weight tiles. + */ + +#ifndef LINEAR_FP_OUTPUT_TILE_FP_COMPUTE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_FP_COMPUTE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_common.glslh" +#include "linear_fp_input_tile.glslh" +#include "linear_fp_output_tile.glslh" +#include "linear_fp_weight_tile.glslh" + +/* + * Accumulates floating point input tile and floating point weight tile into + * floating point output tile. + */ +void update(inout FPOutTile accum, FPInputTile in_tile, FPWeightTile w_tile) { + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + const int n = mul_4(n4); + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + accum.data[m][n4] = + fma(VEC4_T(in_tile.data[m][k4][0]), + w_tile.data[k4][n4], + accum.data[m][n4]); + + accum.data[m][n4] = + fma(VEC4_T(in_tile.data[m][k4][1]), + w_tile.data[k4 + 1][n4], + accum.data[m][n4]); + + accum.data[m][n4] = + fma(VEC4_T(in_tile.data[m][k4][2]), + w_tile.data[k4 + 2][n4], + accum.data[m][n4]); + + accum.data[m][n4] = + fma(VEC4_T(in_tile.data[m][k4][3]), + w_tile.data[k4 + 3][n4], + accum.data[m][n4]); + } + } + } +} + +/* + * Applies per output channel weight scales to the output tile. 
+ */ +void apply_scales(inout FPOutTile tile, const FPPerOutChannelParams scales) { +#if TILE_M > 1 && TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + tile.data[m][0] = tile.data[m][0] * scales.data[0]; + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + tile.data[m][n4] = tile.data[m][n4] * scales.data[n4]; + } + } +#endif +} + +/* + * Applies per output channel weight scales and per output channel biases to the + * output tile. + */ +void apply_scales_and_biases( + inout FPOutTile tile, + const FPPerOutChannelParams scales, + const FPPerOutChannelParams bias) { +#if TILE_M > 1 && TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + tile.data[m][0] = tile.data[m][0] * scales.data[0] + bias.data[0]; + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + tile.data[m][n4] = tile.data[m][n4] * scales.data[n4] + bias.data[n4]; + } + } +#endif +} + +#endif // LINEAR_FP_OUTPUT_TILE_FP_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_compute.glslh new file mode 100644 index 00000000000..58fd9086266 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_compute.glslh @@ -0,0 +1,124 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to compute a FPOutTile using int8 input and weight tiles. + * + * Settings: + * - TILE_M: The number of rows in the output tile. + * - TILE_N4: The number of (groups of 4) columns in the output tile. + */ + +#ifndef LINEAR_FP_OUTPUT_TILE_INT8_COMPUTE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_INT8_COMPUTE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_common.glslh" +#include "linear_fp_output_tile.glslh" +#include "linear_int8_input_tile.glslh" +#include "linear_int8_weight_tile.glslh" + +// Stores integer accumulators for an output tile. 
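The integer accumulation path relies on the identity sum_k (q_in[k] - input_zp) * q_w[k] = sum_k q_in[k] * q_w[k] - input_zp * sum_k q_w[k]; this is why the compute functions further down only need a precomputed per-output-channel weight sum rather than re-centering every input element. Below is a minimal scalar C++ sketch of that recovery step; the helper names are illustrative and not part of this patch.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative scalar model of the q8ta_q8csw dequantization step.
// Inputs are asymmetrically quantized (scale + zero point); weights are
// symmetrically quantized per output channel (scale only).
int8_t quantize_input(float v, float inv_scale, int zp) {
  // Mirrors the shader: round, add zero point, clamp to the symmetric range.
  const int q = int(std::round(v * inv_scale)) + zp;
  return int8_t(std::max(-127, std::min(127, q)));
}

float dequant_accum(
    const std::vector<int8_t>& q_in,
    const std::vector<int8_t>& q_w,
    float input_scale,
    int32_t input_zp,
    float weight_scale) {
  int32_t acc = 0;        // what the packed int8 dot product accumulates
  int32_t weight_sum = 0; // precomputed per output channel at prepack time
  for (size_t k = 0; k < q_in.size(); ++k) {
    acc += int32_t(q_in[k]) * int32_t(q_w[k]);
    weight_sum += int32_t(q_w[k]);
  }
  // Subtracting input_zp * weight_sum undoes the input zero point in one step.
  return (float(acc) - float(input_zp) * float(weight_sum)) *
      weight_scale * input_scale;
}

int main() {
  const float input_scale = 0.05f;
  const int32_t input_zp = 3;
  std::vector<float> x = {0.6f, -1.7f, 2.8f, -3.9f};
  std::vector<int8_t> q_in;
  for (float v : x) {
    q_in.push_back(quantize_input(v, 1.0f / input_scale, input_zp));
  }
  std::vector<int8_t> q_w = {7, -3, 11, 5};
  const float out = dequant_accum(q_in, q_w, input_scale, input_zp, 0.02f);
  std::printf("dequantized output: %f\n", out);
  return 0;
}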
+struct Int8OutAccum { + ivec4 data[TILE_M][TILE_N4]; +}; + +// Initialize values to 0 +void initialize(out Int8OutAccum out_accum) { +#if TILE_M > 1 && TILE_N4 == 1 + [[unroll]] for (int y = 0; y < TILE_M; ++y) { + out_accum.data[y][0] = ivec4(0); + } + +#else + [[unroll]] for (int y = 0; y < TILE_M; ++y) { + [[unroll]] for (int x4 = 0; x4 < TILE_K4; ++x4) { + out_accum.data[y][x4] = ivec4(0); + } + } +#endif +} + +// Accumulate int8 input and weight tiles into accumulator tile +void accumulate( + inout Int8OutAccum accum, + Int8InputTile in_tile, + Int8WeightTile w_tile) { + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + const int m4 = div_4(m); + const int m4i = mod_4(m); + [[unroll]] for (int n = 0; n < TILE_N; ++n) { + const int n4 = div_4(n); + const int n4i = mod_4(n); + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + accum.data[m][n4][n4i] = dotPacked4x8AccSatEXT( + in_tile.data[m4][k4][m4i], + w_tile.data[k4][n4][n4i], + accum.data[m][n4][n4i]); + } + } + } +} + +/* + * Computes final weight matrix output tile using: + * - int8 accumulator tile + * - per output channel weight sums + * - per output channel scales + */ +void compute( + out FPOutTile out_tile, + const Int8OutAccum out_accum, + const FPPerOutChannelParams sums, + const FPPerOutChannelParams scales) { +#if TILE_M > 1 && TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + out_tile.data[m][0] = + (VEC4_T(out_accum.data[m][0]) - input_zp * sums.data[0]) * + scales.data[0] * input_scale; + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + out_tile.data[m][n4] = + (VEC4_T(out_accum.data[m][n4]) - input_zp * sums.data[n4]) * + scales.data[n4] * input_scale; + } + } +#endif +} + +void compute( + out FPOutTile out_tile, + const Int8OutAccum out_accum, + const FPPerOutChannelParams sums, + const FPPerOutChannelParams scales, + const FPPerOutChannelParams bias) { +#if TILE_M > 1 && TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + out_tile.data[m][0] = + (VEC4_T(out_accum.data[m][0]) - input_zp * sums.data[0]) * + scales.data[0] * input_scale + + bias.data[0]; + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + out_tile.data[m][n4] = + (VEC4_T(out_accum.data[m][n4]) - input_zp * sums.data[n4]) * + scales.data[n4] * input_scale + + bias.data[n4]; + } + } +#endif +} + +#endif // LINEAR_FP_OUTPUT_TILE_INT8_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh new file mode 100644 index 00000000000..d40a0fe98cc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh @@ -0,0 +1,114 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions store a FpOutTile to output buffer/texture. + * + * Requires: + * - t_output to be declared in the shader layout + * + * Settings: + * - OUTPUT_BUFFER to indicate t_output is a vec4 buffer, otherwise texture + * storage is assumed. 
+ */ + +#ifndef LINEAR_FP_OUTPUT_TILE_STORE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_STORE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_output_tile.glslh" + +#ifdef OUTPUT_BUFFER + +void write_output_x4( + const VEC4_T out_texel, + const uint n4, + const uint m, + const uint N4) { + t_output[m * N4 + n4] = out_texel; +} + +#else + +void write_output_x4( + const VEC4_T out_texel, + const uint n4, + const uint m, + const uint N4) { + imageStore(t_output, ivec3(n4, m, 0), out_texel); +} + +#endif // OUTPUT_BUFFER + +void write_output_tile( + const FPOutTile out_tile, + const uint n4_start, + const uint m_start, + const uint N4) { +#if TILE_K4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + write_output_x4(out_tile.data[m][0], n4_start, m_start + m, N4); + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + write_output_x4(out_tile.data[m][n4], n4_start + n4, m_start + m, N4); + } + } +#endif +} + +// To be used if M - m >= TILE_M && N4 - n4 >= TILE_N4 +void write_output_tile_no_checks( + const FPOutTile out_tile, + const uint n4_start, + const uint m_start, + const uint N4, + const uint M) { +#if TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + write_output_x4(out_tile.data[m][0], n4_start, m_start + m, N4); + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + write_output_x4(out_tile.data[m][n4], n4_start + n4, m_start + m, N4); + } + } +#endif +} + +// To be used if close to tensor boundaries +void write_output_tile_with_checks( + const FPOutTile out_tile, + const uint n4_start, + const uint m_start, + const uint N4, + const uint M) { +#if TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + if (m_start + m < M) { + write_output_x4(out_tile.data[m][0], n4_start, m_start + m, N4); + } + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + if (m_start + m < M && n4_start + n4 < N4) { + write_output_x4(out_tile.data[m][n4], n4_start + n4, m_start + m, N4); + } + } + } +#endif +} + +#endif // LINEAR_FP_OUTPUT_TILE_STORE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh new file mode 100644 index 00000000000..fb50911fb98 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh @@ -0,0 +1,100 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines the FPWeightTile struct, which is used to represent a fp tile of a + * weight matrix in matrix multiplication. 
+ * + * Settings: + * - TILE_K: number of rows in the output tile + * - TILE_N4: number of (groups of 4) columns in the output tile + */ + +#ifndef LINEAR_FP_WEIGHT_TILE_GLSLH +#define LINEAR_FP_WEIGHT_TILE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +struct FPWeightTile { + VEC4_T data[TILE_K][TILE_N4]; +}; + +#ifdef LINEAR_INT8_WEIGHT_TILE_GLSLH + +int sign_extend(const int val) { + if ((val & 0x80) != 0) { + return val | (~0xFF); + } + return val; +} + +T extract_8bit_value(const Int8WeightTile w_tile, const uint k, const uint n) { +#if TILE_K4 == 1 && TILE_N4 == 1 + const uint k4i = k; + const uint n4i = n; + ivec4 block = w_tile.data[0][0]; + +#else + const uint k4 = div_4(k); + const uint k4i = mod_4(k); + + const uint n4 = div_4(n); + const uint n4i = mod_4(n); + + ivec4 block = w_tile.data[k4][n4]; +#endif + + int col = block[n4i]; + int val = (col >> ((3 - k4i) * 8)) & 0xFF; + + return T(sign_extend(val)); +} + +void unpack(out FPWeightTile fp_w_tile, const Int8WeightTile w_tile) { +#if TILE_K > 1 && TILE_N4 == 1 + [[unroll]] for (int k = 0; k < TILE_K; ++k) { + fp_w_tile.data[k][0][0] = extract_8bit_value(w_tile, k, 0); + fp_w_tile.data[k][0][1] = extract_8bit_value(w_tile, k, 1); + fp_w_tile.data[k][0][2] = extract_8bit_value(w_tile, k, 2); + fp_w_tile.data[k][0][3] = extract_8bit_value(w_tile, k, 3); + } + +#else + [[unroll]] for (int k = 0; k < TILE_M; ++k) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + const uint n = mul_4(n4); + fp_w_tile.data[k][n4][0] = extract_8bit_value(w_tile, k, n); + fp_w_tile.data[k][n4][1] = extract_8bit_value(w_tile, k, n + 1); + fp_w_tile.data[k][n4][2] = extract_8bit_value(w_tile, k, n + 2); + fp_w_tile.data[k][n4][3] = extract_8bit_value(w_tile, k, n + 3); + } + } +#endif +} + +#endif // LINEAR_INT8_WEIGHT_TILE_GLSLH + +#ifdef DEBUG_MODE + +void printFPWeightTile(const FPWeightTile tile) { + debugPrintfEXT("weight_tile: \\n"); + [[unroll]] for (int k = 0; k < TILE_K; ++k) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + " %f, %f, %f, %f, \\n", + tile.data[k][n4].x, + tile.data[k][n4].y, + tile.data[k][n4].z, + tile.data[k][n4].w); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_FP_WEIGHT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh new file mode 100644 index 00000000000..5b3a86b77d7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh @@ -0,0 +1,77 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * This file defines utilties to perform int8 quantization and block packing of + * matrix multiplation inputs. It also defines utilities to store packed block + * data to an output buffer or texture. + * + * Requires: + * - t_output to be defined in shader layout (output buffer/texture) + * + * Settings: + * - OUTPUT_BUFFER to indicate if output resource is a buffer. Otherwise texture + * is assumed. 
+ */ + +#ifndef LINEAR_INT8_INPUT_BLOCK_GLSLH +#define LINEAR_INT8_INPUT_BLOCK_GLSLH + +#define TILE_M 4 +#define TILE_K4 1 + +#include "linear_fp_input_tile.glslh" + +struct Int8InputBlock { + ivec4 data; +}; + +ivec4 quantize(const VEC4_T val) { + vec4 quantized = round(vec4(val) * inv_scale) + zp; + + // hard-code 8 bit quantization range + return clamp(ivec4(quantized), -127, 127); +} + +int pack_into_int32(const ivec4 quant_vals) { + int packed = ((quant_vals[3] & 0xFF) << 0) | ((quant_vals[2] & 0xFF) << 8) | + ((quant_vals[1] & 0xFF) << 16) | ((quant_vals[0] & 0xFF) << 24); + + return packed; +} + +void quantize_and_pack(out Int8InputBlock packed, const FPInputTile in_block) { + for (int row = 0; row < 4; ++row) { + ivec4 quantized_inputs = quantize(in_block.data[row][0]); + packed.data[row] = pack_into_int32(quantized_inputs); + } +} + +#ifdef OUTPUT_BUFFER + +void write_block( + const Int8InputBlock block, + const uint block_x, + const uint block_y, + const uint nblocks_x) { + t_output[block_y * nblocks_x + block_x] = block.data; +} + +#else // OUTPUT_TEXTURE + +void write_block( + const Int8InputBlock block, + const uint block_x, + const uint block_y, + const uint nblocks_x) { + imageStore(t_output, ivec3(block_x, block_y, 0), block.data); +} + +#endif // OUTPUT_BUFFER + +#endif // LINEAR_INT8_INPUT_BLOCK_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh new file mode 100644 index 00000000000..21e8ba031c5 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh @@ -0,0 +1,93 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines the Int8InputTile struct, which is used to represent a tile of the + * quantized int8 input matrix of a quantized matrix multiplication operation. 
+ * + * Settings: + * - TILE_M4: number of (groups of 4) rows in the tile + * - TILE_K4: number of (groups of 4) columns in the tile + */ + +#ifndef LINEAR_INT8_INPUT_TILE_GLSLH +#define LINEAR_INT8_INPUT_TILE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +struct Int8InputTile { + ivec4 data[TILE_M4][TILE_K4]; +}; + +#ifdef DEBUG_MODE + +int extract_8bit_from_packed_int_le(const int packed, const uint i) { + // account for little endian, extract 8-bit value at position i + int byte = int(uint(packed) >> (8 * i) & 255u); + // convert unsigned byte to signed byte + if (byte > 127) { + byte = byte - 256; + } + return byte; +} + +void printInt8InputTile(const Int8InputTile tile) { + debugPrintfEXT( + "Int8InputTile [TILE_M4=%d][TILE_K4=%d]:\\n", TILE_M4, TILE_K4); + + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + debugPrintfEXT(" tile[%d][%d] (ivec4): ", m4, k4); + + // Each ivec4 contains 4 packed integers, each integer contains 4 8-bit + // values + [[unroll]] for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { + int packed_int = tile.data[m4][k4][vec_idx]; + debugPrintfEXT("packed_int[%d]=%d -> [", vec_idx, packed_int); + + // Extract 4 8-bit values from this packed integer + [[unroll]] for (int byte_idx = 0; byte_idx < 4; ++byte_idx) { + int val = extract_8bit_from_packed_int_le(packed_int, byte_idx); + if (byte_idx < 3) { + debugPrintfEXT("%d, ", val); + } else { + debugPrintfEXT("%d] ", val); + } + } + } + debugPrintfEXT("\\n"); + } + } +} + +void printInt8InputTileCompact(const Int8InputTile tile) { + debugPrintfEXT( + "Int8InputTile [%dx%d] (showing extracted 8-bit values):\\n", + TILE_M4 * 4, + TILE_K4 * 4); + + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + // Print 4 rows at a time (since each m4 represents 4 rows) + [[unroll]] for (int row_in_m4 = 0; row_in_m4 < 4; ++row_in_m4) { + debugPrintfEXT(" row %d: ", m4 * 4 + row_in_m4); + + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + [[unroll]] for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { + int packed_int = tile.data[m4][k4][vec_idx]; + int val = extract_8bit_from_packed_int_le(packed_int, row_in_m4); + debugPrintfEXT("%4d ", val); + } + } + debugPrintfEXT("\\n"); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT8_INPUT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh new file mode 100644 index 00000000000..ea302ab4f40 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh @@ -0,0 +1,75 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to load a Int8InputTile from input buffer/texture. + * + * Requires: + * - t_input to be declared in the shader layout (input buffer/texture) + * + * Settings: + * - INPUT_BUFFER to indicate resource is a buffer, otherwise texture storage is + * assumed. 
+ */ + +#ifndef LINEAR_INT8_INPUT_TILE_LOAD_GLSLH +#define LINEAR_INT8_INPUT_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_int8_input_tile.glslh" + +#ifdef INPUT_BUFFER + +ivec4 load_input_block( + const uint block_x, + const uint block_y, + const uint nblocks_x) { + return t_input[(block_y * nblocks_x) + block_x]; +} + +#else + +ivec4 load_input_block( + const uint block_x, + const uint block_y, + const uint nblocks_x) { + return texelFetch(t_input, ivec3(block_x, block_y, 0), 0); +} + +#endif // INPUT_BUFFER + +void load_input_tile( + out Int8InputTile in_tile, + const uint block_x, + const uint block_y, + const uint nblocks_x) { +#if TILE_M4 == 1 && TILE_K4 == 1 + in_tile.data[0][0] = load_input_block(block_x, block_y, nblocks_x); + +#elif TILE_M4 == 1 && TILE_K4 > 1 + [[unroll]] for (int x = 0; x < TILE_K4; ++x) { + in_tile.data[0][x] = load_input_block(block_x + x, block_y, nblocks_x); + } + +#elif TILE_M4 > 1 && TILE_K4 == 1 + [[unroll]] for (int y = 0; y < TILE_M4; ++y) { + in_tile.data[y][0] = load_input_block(block_x, block_y + y, nblocks_x); + } + +#else + [[unroll]] for (int y = 0; y < TILE_M4; ++y) { + [[unroll]] for (int x = 0; x < TILE_K4; ++x) { + in_tile.data[y][x] = + load_input_block(block_x + x, block_y + y, nblocks_x); + } + } +#endif +} + +#endif // LINEAR_INT8_INPUT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh new file mode 100644 index 00000000000..c7a2022730b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh @@ -0,0 +1,140 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT8_WEIGHT_BLOCK_GLSLH +#define LINEAR_INT8_WEIGHT_BLOCK_GLSLH + +/* + * This file defines utilties to perform weight prepacking of quantized int8 + * matrix multiplation weights. It also defines utilities to load source + * weight data from inputbuffer, and write out a packed weight block to output + * texture/buffer. + * + * Requires: + * - t_qmat2 to be defined in shader layout (output texture/buffer) + * - t_input to be defined in shader layout (input buffer) + * + * Settings: + * - USING_BUFFER to indicate if output resource is a buffer. Otherwise texture + * is assumed. + */ + +#extension GL_EXT_control_flow_attributes : require + +// Represents source data for a 4x4 block of the weight matrix read from the +// input buffer. +struct Int8WeightBlockSourceData { + int data[4]; +}; + +// Represents data for a packed 4x4 block of the weight matrix to be written out +// to output texture/buffer. 
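For reference, the byte layout produced by pack_4x8bit_signed_into_int below places row offset 0 of a block in the highest byte of each packed word, and the matching consumer (extract_8bit_value / sign_extend in linear_fp_weight_tile.glslh) recovers it with a shift, a mask, and sign extension. The following host-side C++ round trip is a hedged sketch of that layout and is not part of the patch.

#include <cassert>
#include <cstdint>

// Pack four signed 8-bit values so that row offset 0 occupies the highest byte,
// matching the layout of pack_4x8bit_signed_into_int.
uint32_t pack4(int8_t v0, int8_t v1, int8_t v2, int8_t v3) {
  return (uint32_t(uint8_t(v0)) << 24) | (uint32_t(uint8_t(v1)) << 16) |
      (uint32_t(uint8_t(v2)) << 8) | uint32_t(uint8_t(v3));
}

// Recover the value at row offset r (0..3), re-applying the sign.
int32_t unpack4(uint32_t packed, int r) {
  const int32_t val = int32_t((packed >> ((3 - r) * 8)) & 0xFFu);
  return (val & 0x80) != 0 ? (val | ~0xFF) : val; // sign extend
}

int main() {
  const int8_t vals[4] = {-128, -1, 0, 127};
  const uint32_t packed = pack4(vals[0], vals[1], vals[2], vals[3]);
  for (int r = 0; r < 4; ++r) {
    assert(unpack4(packed, r) == vals[r]);
  }
  return 0;
}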
+struct Int8WeightBlockPacked { + ivec4 data; +}; + +// To be used if K - k_start >= 4 +void load_block_source_data_no_checks( + out Int8WeightBlockSourceData src_data, + const uint n4, + const uint k_start, + const uint ntexels_N, + const uint K) { + [[unroll]] for (int k = 0; k < 4; ++k) { + src_data.data[k] = t_input[(k_start + k) * ntexels_N + n4]; + } +} + +// To be used if K - k_start < 4 +void load_block_source_data_with_checks( + out Int8WeightBlockSourceData src_data, + const uint n4, + const uint k_start, + const uint ntexels_N, + const uint K) { + [[unroll]] for (int k = 0; k < 4; ++k) { + if (k_start + k < K) { + src_data.data[k] = t_input[(k_start + k) * ntexels_N + n4]; + } else { + src_data.data[k] = 0; + } + } +} + +int extract_8bit_from_packed_uint_le(const uint packed, const uint i) { + // account for little endian + int byte = int(packed >> (8 * i) & 255); + return byte; +} + +int pack_4x8bit_signed_into_int( + const int val0, + const int val1, + const int val2, + const int val3) { + return int( + ((val0 & 0xFF) << 24) | ((val1 & 0xFF) << 16) | ((val2 & 0xFF) << 8) | + ((val3 & 0xFF))); +} + +void create_packed_block( + out Int8WeightBlockPacked block, + const Int8WeightBlockSourceData src_data) { + [[unroll]] for (int col = 0; col < 4; ++col) { + block.data[col] = pack_4x8bit_signed_into_int( + extract_8bit_from_packed_uint_le(src_data.data[0], col), + extract_8bit_from_packed_uint_le(src_data.data[1], col), + extract_8bit_from_packed_uint_le(src_data.data[2], col), + extract_8bit_from_packed_uint_le(src_data.data[3], col)); + } +} + +#ifdef USING_BUFFER + +void write_packed_block( + const Int8WeightBlockPacked block, + const uint block_x, + const uint block_y, + const uint nblocks_x) { + t_qmat2[block_y * nblocks_x + block_x] = block.data; +} + +#else // USING_TEXTURE + +void write_packed_block( + const Int8WeightBlockPacked block, + const uint block_x, + const uint block_y, + const uint nblocks_w) { + imageStore(t_qmat2, ivec2(block_x, block_y), block.data); +} + +#endif // USING_BUFFER + +#ifdef DEBUG_MODE + +void printInt8WeightBlockSourceData(const Int8WeightBlockSourceData src_data) { + debugPrintfEXT("int8_weight_block_source_data: \\n"); + [[unroll]] for (int row = 0; row < 4; ++row) { + debugPrintfEXT("row %i: %u \\n", row, src_data.data[row]); + } +} + +void printInt8WeightBlockPacked(const Int8WeightBlockPacked block) { + debugPrintfEXT("int8_weight_block_packed: \\n"); + debugPrintfEXT( + "%i %i %i %i \\n", + block.data[0], + block.data[1], + block.data[2], + block.data[3]); +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT8_WEIGHT_BLOCK_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh new file mode 100644 index 00000000000..2711f1d3174 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh @@ -0,0 +1,45 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT8_WEIGHT_TILE_GLSLH +#define LINEAR_INT8_WEIGHT_TILE_GLSLH + +/* + * Defines the Int8WeightTile struct, which is used to represent a tile of the + * quantized int8 weight matrix of a quantized matrix multiplication operation. 
+ * + * Settings: + * - TILE_K4: number of (groups of 4) rows in the weight tile + * - TILE_N4: number of (groups of 4) columns in the weight tile + */ + +#extension GL_EXT_control_flow_attributes : require + +struct Int8WeightTile { + ivec4 data[TILE_K4][TILE_N4]; +}; + +#ifdef DEBUG_MODE + +void printInt8WeightTile(const Int8WeightTile tile) { + debugPrintfEXT("int8_weight_tile: \\n"); + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + "%i %i %i %i \\n", + tile.data[k4][n4][0], + tile.data[k4][n4][1], + tile.data[k4][n4][2], + tile.data[k4][n4][3]); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT8_WEIGHT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh new file mode 100644 index 00000000000..2b9baa84356 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh @@ -0,0 +1,75 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT8_WEIGHT_TILE_LOAD_GLSLH +#define LINEAR_INT8_WEIGHT_TILE_LOAD_GLSLH + +/* + * Defines functions to load a Int8WeightTile from input buffer/texture. + * + * Requires: + * - t_qmat2 to be declared in the shader layout (input buffer/texture) + * + * Settings: + * - WEIGHT_BUFFER to indicate t_qmat2 is a buffer, otherwise texture storage is + * assumed. + */ + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_int8_weight_tile.glslh" + +#ifdef WEIGHT_BUFFER + +ivec4 load_weight_block( + const uint block_x, + const uint block_y, + const uint nblocks_x) { + return t_qmat2[(block_y * nblocks_x) + block_x]; +} + +#else // WEIGHT_TEXTURE + +ivec4 load_weight_block( + const uint block_x, + const uint block_y, + const uint nblocks_x) { + return texelFetch(t_qmat2, ivec2(block_x, block_y), 0); +} + +#endif // WEIGHT_BUFFER + +void load_weight_tile( + out Int8WeightTile weight_tile, + const uint block_x, + const uint block_y, + const uint nblocks_x) { +#if TILE_K4 == 1 && TILE_N4 == 1 + weight_tile.data[0][0] = load_weight_block(block_x, block_y, nblocks_x); + +#elif TILE_K4 == 1 && TILE_N4 > 1 + [[unroll]] for (int x = 0; x < TILE_N4; ++x) { + weight_tile.data[0][x] = load_weight_block(block_x + x, block_y, nblocks_x); + } + +#elif TILE_K4 > 1 && TILE_N4 == 1 + [[unroll]] for (int y = 0; y < TILE_M4; ++y) { + weight_tile.data[y][0] = load_weight_block(block_x, block_y + y, nblocks_x); + } + +#else + [[unroll]] for (int y = 0; y < TILE_K4; ++y) { + [[unroll]] for (int x = 0; x < TILE_N4; ++x) { + weight_tile.data[y][x] = + load_weight_block(block_x + x, block_y + y, nblocks_x); + } + } +#endif +} + +#endif // LINEAR_INT8_WEIGHT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl new file mode 100644 index 00000000000..49d880f732f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl @@ -0,0 +1,117 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, OUTPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, OUTPUT_STORAGE)} + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_N4 ${TILE_N4} +#define TILE_K4 ${TILE_K4} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_N ${TILE_N4 * 4} +#define TILE_K ${TILE_K4 * 4} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qmat2", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "uint", "apply_bias", "0")} + +#include "linear_fp_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_weight_tile.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "linear_fp_output_tile_store.glslh" +#include "linear_scales_load.glslh" +#include "linear_bias_load.glslh" + +void main() { + // Each thread writes out a 4 wide x 4 high tile of output values + const uint out_tile_x = gl_GlobalInvocationID.x; + const uint out_tile_y = gl_GlobalInvocationID.y; + + const uint n = out_tile_x * TILE_N; + const uint m = out_tile_y * TILE_M; + + const uint n4 = div_4(n); + const uint m4 = div_4(m); + + if (n >= output_sizes.x || m >= output_sizes.y) { + return; + } + + const uint M = uint(input_sizes.y); + const uint K4 = div_up_4(input_sizes.x); + const uint N4 = div_up_4(output_sizes.x); + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile in_tile; + Int8WeightTile weight_tile; + FPWeightTile fp_weight_tile; + + const bool dont_check_bounds = (M - m) >= TILE_M; + + if (dont_check_bounds) { + for (int k4 = 0; k4 < K4; k4++) { + load_input_tile_no_checks(in_tile, k4, m, K4, M); + load_weight_tile(weight_tile, n4, k4, N4); + unpack(fp_weight_tile, weight_tile); + update(out_tile, in_tile, fp_weight_tile); + } + } else { + for (int k4 = 0; k4 < K4; k4++) { + load_input_tile_with_checks(in_tile, k4, m, K4, M); + load_weight_tile(weight_tile, n4, k4, N4); + unpack(fp_weight_tile, weight_tile); + update(out_tile, in_tile, fp_weight_tile); + } + } + + FPPerOutChannelParams scales_tile; + load_scales_tile(scales_tile, n4); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, n4); + + apply_scales_and_biases(out_tile, scales_tile, bias_tile); + } + else { + apply_scales(out_tile, scales_tile); + } + + if (dont_check_bounds) { + write_output_tile_no_checks(out_tile, n4, m, N4, M); + } else { + write_output_tile_with_checks(out_tile, n4, m, N4, M); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.yaml new file mode 100644 index 00000000000..2356fcdb251 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.yaml @@ -0,0 +1,30 @@ +# Copyright (c) Meta 
Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +linear_q8csw_tiled: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + INPUT_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + TILE_M4: 1 + TILE_N4: 1 + TILE_K4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: linear_q8csw_tiled_texture3d_texture3d_texture2d + - NAME: linear_q8csw_tiled_texture3d_texture3d_buffer + WEIGHT_STORAGE: buffer + - NAME: linear_q8csw_tiled_buffer_buffer_texture2d + OUTPUT_STORAGE: buffer + INPUT_STORAGE: buffer + WEIGHT_STORAGE: texture2d + - NAME: linear_q8csw_tiled_buffer_buffer_buffer + OUTPUT_STORAGE: buffer + INPUT_STORAGE: buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.glsl new file mode 100644 index 00000000000..a4bd4b4a115 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.glsl @@ -0,0 +1,117 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, OUTPUT_STORAGE)} +#define T int + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_N4 ${TILE_N4} +#define TILE_K4 ${TILE_K4} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_N ${TILE_N4 * 4} +#define TILE_K ${TILE_K4 * 4} + +${define_required_extensions(DTYPE)} + +#extension GL_EXT_integer_dot_product : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", "int", INPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qmat2", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "float", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_spec_const(C, "uint", "apply_bias", "0")} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "linear_int8_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_output_tile_int8_compute.glslh" +#include "linear_fp_output_tile_store.glslh" +#include "linear_scales_load.glslh" +#include "linear_weight_sums_load.glslh" +#include "linear_bias_load.glslh" + +void main() { + // Each thread writes out a 4 wide x 4 high tile of output values + const uint out_tile_x = gl_GlobalInvocationID.x; + const uint out_tile_y = gl_GlobalInvocationID.y; + + const uint n = out_tile_x * TILE_N; + const uint m = out_tile_y * TILE_M; + + const uint n4 = div_4(n); + const uint m4 = div_4(m); + + if (n >= output_sizes.x || m >= output_sizes.y) { + 
return; + } + + const uint M = output_sizes.y; + const uint K4 = div_up_4(input_sizes.x); + const uint N4 = div_up_4(output_sizes.x); + + Int8OutAccum out_accum; + initialize(out_accum); + + Int8InputTile in_tile; + Int8WeightTile weight_tile; + + for (int k4 = 0; k4 < K4; k4++) { + load_input_tile(in_tile, k4, m4, K4); + load_weight_tile(weight_tile, n4, k4, N4); + + accumulate(out_accum, in_tile, weight_tile); + } + + FPPerOutChannelParams scales_tile; + load_scales_tile(scales_tile, n4); + + FPPerOutChannelParams sums_tile; + load_sums_tile(sums_tile, n4); + + FPOutTile out_tile; + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, uint(n4)); + + compute(out_tile, out_accum, sums_tile, scales_tile, bias_tile); + } + else { + compute(out_tile, out_accum, sums_tile, scales_tile); + } + + if (M - m >= TILE_M) { + write_output_tile_no_checks(out_tile, n4, m, N4, M); + } else { + write_output_tile_with_checks(out_tile, n4, m, N4, M); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml new file mode 100644 index 00000000000..dfaa839e02e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +linear_q8ta_q8csw_tiled: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + INPUT_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + TILE_M4: 1 + TILE_N4: 1 + TILE_K4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: linear_q8ta_q8csw_tiled_texture3d_texture3d_texture2d + - NAME: linear_q8ta_q8csw_tiled_texture3d_texture3d_buffer + WEIGHT_STORAGE: buffer + - NAME: linear_q8ta_q8csw_tiled_buffer_buffer_texture2d + OUTPUT_STORAGE: buffer + INPUT_STORAGE: buffer + WEIGHT_STORAGE: texture2d + - NAME: linear_q8ta_q8csw_tiled_buffer_buffer_buffer + OUTPUT_STORAGE: buffer + INPUT_STORAGE: buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_scales_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_scales_load.glslh new file mode 100644 index 00000000000..47f6d318008 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_scales_load.glslh @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifndef LINEAR_SCALES_LOAD_GLSLH +#define LINEAR_SCALES_LOAD_GLSLH + +#include "linear_common.glslh" + +VEC4_T load_scale_x4(const uint n4) { + return t_weight_scales[n4]; +} + +void load_scales_tile(out FPPerOutChannelParams scales, const uint n4_start) { +#if TILE_N4 == 1 + scales.data[0] = load_scale_x4(n4_start); + +#else + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + scales.data[n4] = load_scale_x4[n4_start + n4]; + } + +#endif +} + +#endif // LINEAR_SCALES_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_weight_sums_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_weight_sums_load.glslh new file mode 100644 index 00000000000..8c13315d50d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_weight_sums_load.glslh @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_WEIGHT_SUMS_LOAD_GLSLH +#define LINEAR_WEIGHT_SUMS_LOAD_GLSLH + +#include "linear_common.glslh" + +VEC4_T load_sum_x4(const uint n4) { + return VEC4_T(t_weight_sums[n4]); +} + +void load_sums_tile(out FPPerOutChannelParams sums, const uint n4_start) { +#if TILE_N4 == 1 + sums.data[0] = load_sum_x4(n4_start); + +#else + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + sums.data[n4] = load_sum_x4[n4_start + n4]; + } + +#endif +} + +#endif // LINEAR_WEIGHT_SUMS_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.glsl new file mode 100644 index 00000000000..e731aa596a7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.glsl @@ -0,0 +1,62 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type(STORAGE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_qmat2", "int", STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", "int", "buffer")} + +layout(push_constant) uniform restrict Block { + ivec4 qmat2_sizes; + ivec2 orig_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "common.glslh" +#include "linear_int8_weight_block.glslh" + +void main() { + uint block_x = gl_GlobalInvocationID.x; + uint block_y = gl_GlobalInvocationID.y; + + const int N = orig_sizes.y; + const int K = orig_sizes.x; + + // Each group of 4 8bit values are packed into each uint in the input tensor. 
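As a hedged host-side model of what this shader computes overall: the source [K, N] int8 weight is read as K rows of N/4 packed uints, and block (block_x = n4, block_y = k4) gathers a 4 (K) x 4 (N) patch and transposes it so that each packed word of the output holds the four K-values of one output column, highest byte first. The C++ below is illustrative only; the function name and the std::vector representation are assumptions, not part of the patch.

#include <cstdint>
#include <vector>

// Illustrative host-side model of pack_q8_linear_weight. 'src' is the original
// [K, N] int8 weight matrix in row-major order; K and N are multiples of 4.
// The result holds K/4 x N/4 blocks; each block is four packed uint32 words,
// one per output column n = 4*n4 + c, with the K-values stored high byte first.
std::vector<uint32_t> pack_weights(
    const std::vector<int8_t>& src, int K, int N) {
  const int K4 = K / 4;
  const int N4 = N / 4;
  std::vector<uint32_t> packed(size_t(K4) * N4 * 4, 0u);
  for (int k4 = 0; k4 < K4; ++k4) {
    for (int n4 = 0; n4 < N4; ++n4) {
      for (int c = 0; c < 4; ++c) {   // column within the block
        uint32_t word = 0;
        for (int r = 0; r < 4; ++r) { // K offset within the block
          const int8_t val = src[size_t(4 * k4 + r) * N + (4 * n4 + c)];
          word |= uint32_t(uint8_t(val)) << ((3 - r) * 8);
        }
        // Block (n4, k4) is stored at linear block index k4 * N4 + n4.
        packed[(size_t(k4) * N4 + n4) * 4 + c] = word;
      }
    }
  }
  return packed;
}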
+ const int N4 = div_up_4(N); + const int K4 = div_up_4(K); + + // Check bounds + if (block_x >= N4 || block_y >= K4) { + return; + } + + Int8WeightBlockSourceData src_data; + const uint k = mul_4(block_y); + if (K - k >= 4) { + load_block_source_data_no_checks(src_data, block_x, mul_4(block_y), N4, K); + } else { + load_block_source_data_with_checks(src_data, block_x, mul_4(block_y), N4, K); + } + + Int8WeightBlockPacked packed_block; + create_packed_block(packed_block, src_data); + + write_packed_block( + packed_block, + block_x, + block_y, + N4); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.yaml new file mode 100644 index 00000000000..13e6d43b2c5 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.yaml @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +pack_q8_linear_weight: + parameter_names_with_default_values: + STORAGE: buffer + shader_variants: + - NAME: pack_q8_linear_weight_buffer + STORAGE: buffer + - NAME: pack_q8_linear_weight_texture2d + STORAGE: texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl new file mode 100644 index 00000000000..5a9b9f30ce4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl @@ -0,0 +1,79 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, INPUT_STORAGE)} + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_output", "int", OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} + +$if GRANULARITY == "per_channel": + ${layout_declare_tensor(B, "r", "t_scale", DTYPE, "buffer")} + +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(push_constant) uniform restrict Block { + float inv_scale; + int zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "linear_int8_input_block.glslh" +#include "linear_fp_input_tile_load.glslh" + +void main() { + // Each input block contains 4x4 int8 quantized values, which are packed into + // a ivec4. k4 and m4 represent the "block index" of the current block being + // processed. + uint k4 = gl_GlobalInvocationID.x; + uint m4 = gl_GlobalInvocationID.y; + + const int K = input_sizes.x; + const int M = input_sizes.y; + + // K4 and M4 represent the number of blocks in each dimension. + const int K4 = div_up_4(K); + const int M4 = div_up_4(M); + + if (k4 >= K4 || m4 >= M4) { + return; + } + + // row of the input tensor to start loading from. 
Note the input tensor is + // interpreted as a t + const uint m = mul_4(m4); + + const bool dont_check_bounds = (M - m) >= 4; + + FPInputTile in_tile; + if (dont_check_bounds) { + load_input_tile_no_checks(in_tile, k4, m, K4, M); + } else { + load_input_tile_with_checks(in_tile, k4, m, K4, M); + } + + Int8InputBlock packed_block; + quantize_and_pack(packed_block, in_tile); + + write_block(packed_block, k4, m4, K4); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml new file mode 100644 index 00000000000..37721db1ba8 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +quantize_and_pack_linear_input: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + INPUT_STORAGE: texture3d + STORAGE: texture3d + GRANULARITY: per_tensor + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: quantize_and_pack_linear_input_per_tensor_texture3d_texture3d + - NAME: quantize_and_pack_linear_input_per_tensor_buffer_texture3d + OUTPUT_STORAGE: buffer + - NAME: quantize_and_pack_linear_input_per_tensor_buffer_buffer + OUTPUT_STORAGE: buffer + INPUT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp new file mode 100644 index 00000000000..7cbaf6cc409 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -0,0 +1,548 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include + +namespace vkcompute { + +utils::uvec3 quantized_linear_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + + std::vector out_sizes = graph->sizes_of(out); + // height + const uint32_t M = utils::val_at(-2, out_sizes); + // width + const uint32_t N = utils::val_at(-1, out_sizes); + + // 1 output tile is 4x4 elements + const uint32_t M4 = utils::div_up(M, 4u); + const uint32_t N4 = utils::div_up(N, 4u); + + return {N4, M4, 1}; +} + +utils::uvec3 quantized_linear_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)args; + (void)resize_args; + + // Optimize local workgroup size for linear operations + uint32_t local_wg_size_x = 1; + uint32_t local_wg_size_y = 1; + + if (global_workgroup_size[1] % 8 == 0) { + local_wg_size_y = 8; + } else if (global_workgroup_size[1] % 4 == 0) { + local_wg_size_y = 4; + } else if (global_workgroup_size[1] % 2 == 0) { + local_wg_size_y = 2; + } + + // Adjust x dimension to maintain reasonable total workgroup size + local_wg_size_x = std::min(64u / local_wg_size_y, global_workgroup_size[0]); + + return {local_wg_size_x, local_wg_size_y, 1}; +} + +ValueRef prepack_q8_linear_weight( + ComputeGraph& graph, + const ValueRef qmat2_data) { + std::vector qmat2_orig_sizes = graph.sizes_of(qmat2_data); + const int64_t ndim = graph.dim_of(qmat2_data); + + // Input is [K, N] + const int64_t K = qmat2_orig_sizes.at(ndim - 2); + const int64_t N = qmat2_orig_sizes.at(ndim - 1); + + // N must be a multiple of 4 so data data loads are aligned nicely with texel + // boundaries. + VK_CHECK_COND(N % 4 == 0); + + // This packing format partitions the weight tensor into 4 wide x 4 high + // blocks. To figure out the size of the output tensor, determine the number + // of blocks along the width and height dims. + const int64_t num_blocks_K = utils::div_up(K, int64_t(4)); + const int64_t num_blocks_N = utils::div_up(N, int64_t(4)); + + // Each transposed block is 4 wide x 4 high. To maximize memory loading + // efficiency, the packed weight tensor will use a base data type of uint32_t; + // in terms of uint32_t, each block is 1 wide x 4 high. However, each block is + // also flattened as it is stored, so that the whole block can be loaded at + // once. As a result, the stored block will be 4 wide x 1 high. 
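+  // Worked example (illustrative): with K = 8 and N = 12, num_blocks_K = 2
+  // and num_blocks_N = 3, so the packed tensor below is 2 rows tall and
+  // 3 * 4 = 12 ints wide; each group of 4 ints along a row holds one
+  // flattened 4x4 int8 block.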
+ const int64_t output_height = num_blocks_K; + const int64_t output_width = num_blocks_N * 4; + + // Store the original sizes of the tensor to pass to the shader + utils::ivec2 orig_sizes{ + utils::safe_downcast(K), utils::safe_downcast(N)}; + + std::vector qmat2_sizes{output_height, output_width}; + + utils::StorageType storage_type = utils::kTexture2D; + uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); + if (output_width > max_extent * 4 || output_height > max_extent) { + storage_type = utils::kBuffer; + } + + ValueRef qmat2 = graph.add_tensor( + qmat2_sizes, vkcompute::vkapi::kInt, storage_type, utils::kWidthPacked); + + // Global workgroup size: each thread writes out two adjacent blocks + utils::uvec3 global_wg_size{ + utils::safe_downcast(num_blocks_N), + utils::safe_downcast(num_blocks_K), + 1u}; + + std::string kernel_name = "pack_q8_linear_weight"; + add_storage_type_suffix(kernel_name, storage_type); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + graph.create_local_wg_size(global_wg_size), + // Inputs and Outputs + qmat2_data, + qmat2, + // UBOs + {}, + // Specialization Constants + {}, + // Push Constants + {graph.sizes_pc_of(qmat2), + PushConstantDataInfo(&orig_sizes, sizeof(utils::ivec2))})); + + return qmat2; +} + +struct InputQuantConstants { + alignas(16) float inv_scale; + alignas(16) int32_t zp; +}; + +std::tuple get_quantized_input_num_blocks( + ComputeGraph& graph, + const ValueRef input) { + std::vector input_sizes = graph.sizes_of(input); + const int64_t ndim = graph.dim_of(input); + + const int64_t M = input_sizes.at(ndim - 2); + const int64_t K = input_sizes.at(ndim - 1); + + const int64_t num_blocks_M = utils::div_up(M, int64_t(4)); + const int64_t num_blocks_K = utils::div_up(K, int64_t(4)); + + return std::make_tuple(num_blocks_M, num_blocks_K); +} + +utils::uvec3 quant_pack_input_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef input = args.at(1).refs.at(0); + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(*graph, input); + + return { + utils::safe_downcast(num_blocks_K), + utils::safe_downcast(num_blocks_M), + 1u}; +} + +DynamicDispatchNode make_quantize_and_pack_linear_input_node( + ComputeGraph& graph, + const ValueRef input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef quantized_input) { + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(graph, input); + + bool is_per_channel = graph.val_is_tensor(input_scale); + + float inv_scale = 1.0f; + int32_t zp = 0; + if (!is_per_channel) { + inv_scale = 1.0f / graph.extract_scalar(input_scale); + zp = graph.extract_scalar(input_zp); + } + + std::string shader_name = "quantize_and_pack_linear_input"; + if (is_per_channel) { + shader_name += "_per_channel"; + } else { + shader_name += "_per_tensor"; + } + add_storage_type_suffix(shader_name, graph.storage_type_of(quantized_input)); + add_storage_type_suffix(shader_name, graph.storage_type_of(input)); + add_dtype_suffix(shader_name, graph.dtype_of(input)); + + vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(input)}; + + std::vector push_constants = { + PushConstantDataInfo(&inv_scale, sizeof(inv_scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + return DynamicDispatchNode( + graph, + 
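+      // The assembled shader_name is expected to resolve to one of the
+      // quantize_and_pack_linear_input variants declared in
+      // quantize_and_pack_linear_input.yaml (granularity, storage type, and
+      // dtype suffixes).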
VK_KERNEL_FROM_STR(shader_name), + quant_pack_input_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{quantized_input, vkapi::kWrite}, {input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize args + {}); +} + +DynamicDispatchNode make_linear_q8ta_q8csw_tiled_node( + ComputeGraph& graph, + const std::vector& args) { + // Extract arguments + int32_t idx = 0; + const ValueRef input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef packed_weight = args.at(idx++); + const ValueRef packed_weight_sums = args.at(idx++); + const ValueRef packed_weight_scales = args.at(idx++); + const ValueRef bias = args.at(idx++); + const ValueRef packed_bias = args.at(idx++); + const ValueRef output = args.at(idx++); + const ValueRef original_weight = args.at(idx++); // For resize args + + bool is_per_channel = graph.val_is_tensor(input_scale); + + float scale = 1.0f; + int32_t zp = 0; + if (!is_per_channel) { + scale = graph.extract_scalar(input_scale); + zp = graph.extract_scalar(input_zp); + } + + // Get shader for quantized linear + std::string kernel_name = "linear_q8ta_q8csw_tiled"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(output)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(output), graph.sizes_ubo(input)}; + + std::vector push_constants = { + PushConstantDataInfo(&scale, sizeof(scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + uint32_t apply_bias = 0; + if (!graph.val_is_none(bias)) { + apply_bias = 1; + } + + // Add the compute node + return DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + quantized_linear_global_wg_size, + quantized_linear_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {{input, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {apply_bias}, + // Resize args + {original_weight}, + // Resizing Logic + nullptr); +} + +DynamicDispatchNode make_linear_q8csw_node( + ComputeGraph& graph, + const std::vector& args) { + // Extract arguments + int32_t idx = 0; + const ValueRef input = args.at(idx++); + const ValueRef packed_weight = args.at(idx++); + const ValueRef packed_weight_scales = args.at(idx++); + const ValueRef bias = args.at(idx++); + const ValueRef packed_bias = args.at(idx++); + const ValueRef output = args.at(idx++); + const ValueRef original_weight = args.at(idx++); // For resize args + + // Get shader for quantized linear + std::string kernel_name = "linear_q8csw_tiled"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(output)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(output), graph.sizes_ubo(input)}; + + uint32_t apply_bias = 0; + if (!graph.val_is_none(bias)) { + apply_bias = 1; + } + 
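+  // apply_bias is passed below as a specialization constant rather than a
+  // push constant, presumably so that shader variants which do not use a bias
+  // can have the bias load compiled out.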
+
+  return DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      quantized_linear_global_wg_size,
+      quantized_linear_local_wg_size,
+      // Inputs and Outputs
+      {{output, vkapi::kWrite},
+       {{input, packed_weight, packed_weight_scales, packed_bias},
+        vkapi::kRead}},
+      // Shader params buffers
+      param_buffers,
+      // Push Constants
+      {},
+      // Specialization Constants
+      {apply_bias},
+      // Resize args
+      {original_weight},
+      // Resizing Logic
+      nullptr);
+}
+
+/*
+ * Allows orchestration of two compute shader dispatch paths:
+ * 1. quantize & pack input to int8, execute linear_q8ta_q8csw
+ * 2. execute linear_q8csw with fp inputs
+ *
+ * The reason for this split is twofold:
+ * - Some devices may not support accelerated int8 dot product. In that case,
+ *   there is no benefit to quantizing the input tensor, so linear_q8csw is
+ *   required.
+ * - LLMs switch between GEMM and GEMV input conditions when going from
+ *   prefill to decode. GEMM is typically a compute bound operation, which
+ *   will benefit from accelerated int8 accumulation. On the other hand, GEMV
+ *   is usually memory bound, which means it may actually suffer from the extra
+ *   cost of having to quantize and pack the input tensor. Therefore,
+ *   linear_q8ta_q8csw is preferred for GEMM and linear_q8csw is preferred for
+ *   GEMV.
+ *
+ * Note that dynamic shape is currently not supported, so switching paths
+ * when input conditions go between GEMM -> GEMV is currently not implemented.
+ * This will be implemented at a later date.
+ */
+struct QuantizedLinearNode : public ExecuteNode {
+  friend class ComputeGraph;
+
+  bool can_use_int8_dot_product = false;
+  DynamicDispatchNode quantize_and_pack_input_node;
+  DynamicDispatchNode linear_q8ta_q8csw_tiled_node;
+  DynamicDispatchNode linear_q8csw_node;
+
+  explicit QuantizedLinearNode(
+      ComputeGraph& graph,
+      const std::vector& args,
+      DynamicDispatchNode&& quant_pack_input,
+      DynamicDispatchNode&& qaqw_tiled_linear,
+      DynamicDispatchNode&& linear_q8csw,
+      bool int8_dot_product_enabled)
+      : ExecuteNode(),
+        quantize_and_pack_input_node(std::move(quant_pack_input)),
+        linear_q8ta_q8csw_tiled_node(std::move(qaqw_tiled_linear)),
+        linear_q8csw_node(std::move(linear_q8csw)) {
+    if (int8_dot_product_enabled) {
+      can_use_int8_dot_product = graph.can_use_int8_dot_product();
+    }
+  }
+
+  void prepare_pipelines(ComputeGraph* graph) override {
+    if (can_use_int8_dot_product) {
+      quantize_and_pack_input_node.prepare_pipelines(graph);
+      linear_q8ta_q8csw_tiled_node.prepare_pipelines(graph);
+    }
+    linear_q8csw_node.prepare_pipelines(graph);
+  }
+
+  void encode(ComputeGraph* graph) override {
+    if (can_use_int8_dot_product) {
+      quantize_and_pack_input_node.encode(graph);
+      linear_q8ta_q8csw_tiled_node.encode(graph);
+    } else {
+      linear_q8csw_node.encode(graph);
+    }
+  }
+};
+
+/*
+ * Implements activation and weight quantized linear. Currently, only the
+ * following quantization configurations are supported:
+ * - activation quantized to int8 with per tensor quant params
+ * - weight quantized to int8 with per channel quant params
+ */
+void linear_q8ta_q8csw_impl(
+    ComputeGraph& graph,
+    const std::vector& args,
+    const bool use_int8_dot_product = true) {
+  int32_t idx = 0;
+  const ValueRef input = args.at(idx++);
+  const ValueRef input_scale = args.at(idx++);
+  const ValueRef input_zp = args.at(idx++);
+  const ValueRef weight = args.at(idx++);
+  const ValueRef weight_sums = args.at(idx++);
+  const ValueRef weight_scales = args.at(idx++);
+  const ValueRef bias = args.at(idx++);
+  const ValueRef output = args.at(idx++);
+
+  bool is_per_channel = graph.val_is_tensor(input_scale);
+
+  // Input validation
+  std::vector input_sizes = graph.sizes_of(input);
+  std::vector weight_sizes = graph.sizes_of(weight);
+
+  const int64_t K = utils::val_at(-1, input_sizes);
+  // K (input channels) must be a multiple of 4 to ensure that reading a group
+  // of 4 input channels from the input tensor will be aligned on a texel
+  // boundary.
+  VK_CHECK_COND(K % 4 == 0);
+
+  const int64_t N = utils::val_at(-1, weight_sizes);
+  // N (output channels) must be a multiple of 4 to ensure that reading a group
+  // of 4 output channels from the weight/output tensor will be aligned on a
+  // texel boundary.
+  VK_CHECK_COND(N % 4 == 0);
+
+  // Prepacking
+  const ValueRef packed_weight = prepack_q8_linear_weight(graph, weight);
+  ValueRef packed_weight_scales = prepack_standard(
+      graph, weight_scales, utils::kBuffer, utils::kWidthPacked);
+  ValueRef packed_weight_sums =
+      prepack_standard(graph, weight_sums, utils::kBuffer, utils::kWidthPacked);
+
+  ValueRef packed_input_scale = input_scale;
+  ValueRef packed_input_zp = input_zp;
+  if (is_per_channel) {
+    packed_input_scale = prepack_standard(
+        graph, input_scale, utils::kBuffer, utils::kWidthPacked);
+    packed_input_zp =
+        prepack_standard(graph, input_zp, utils::kBuffer, utils::kWidthPacked);
+  }
+
+  // Create a dummy tensor to fill the binding slot of the bias tensor if it is
+  // not provided. This helps simplify dispatch logic and makes it so that
+  // fewer shader variants need to be generated.
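+  // When no bias is provided, the apply_bias specialization constant stays 0,
+  // so the dummy tensor below is presumably never read at runtime; it only
+  // occupies the binding slot.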
+ TmpTensor dummy_bias( + &graph, {}, graph.dtype_of(output), utils::kBuffer, utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (!graph.val_is_none(bias)) { + packed_bias = + prepack_standard(graph, bias, utils::kBuffer, utils::kWidthPacked); + } + + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(graph, input); + + const int64_t quantized_input_height = num_blocks_M; + const int64_t quantized_input_width = num_blocks_K * 4; + + TmpTensor quantized_packed_input( + &graph, + {quantized_input_height, quantized_input_width}, + vkapi::kInt, + graph.storage_type_of(input), + utils::kWidthPacked); + + DynamicDispatchNode quantize_and_pack_linear_node( + make_quantize_and_pack_linear_input_node( + graph, + input, + packed_input_scale, + packed_input_zp, + quantized_packed_input)); + + std::vector linear_args = { + quantized_packed_input, + packed_input_scale, + packed_input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + bias, + packed_bias, + output, + weight}; + + DynamicDispatchNode linear_q8ta_q8csw_tiled_node( + make_linear_q8ta_q8csw_tiled_node(graph, linear_args)); + + linear_args = { + input, + packed_weight, + packed_weight_scales, + bias, + packed_bias, + output, + weight}; + + DynamicDispatchNode linear_q8csw_node( + make_linear_q8csw_node(graph, linear_args)); + + graph.execute_nodes().emplace_back(new QuantizedLinearNode( + graph, + linear_args, + std::move(quantize_and_pack_linear_node), + std::move(linear_q8ta_q8csw_tiled_node), + std::move(linear_q8csw_node), + use_int8_dot_product)); +} + +void linear_q8ta_q8csw(ComputeGraph& graph, const std::vector& args) { + linear_q8ta_q8csw_impl(graph, args, true); +} + +void linear_q8ta_q8csw_no_int8( + ComputeGraph& graph, + const std::vector& args) { + linear_q8ta_q8csw_impl(graph, args, false); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.linear_q8ta_q8csw.default, linear_q8ta_q8csw); + VK_REGISTER_OP(et_vk.linear_q8ta_q8csw.noint8, linear_q8ta_q8csw_no_int8); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.h b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.h new file mode 100644 index 00000000000..11af0b4a0f5 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace vkcompute { + +utils::uvec3 quantized_linear_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args); + +ValueRef prepack_q8_linear_weight( + ComputeGraph& graph, + const ValueRef qmat2_data); + +DynamicDispatchNode make_linear_q8ta_q8csw_tiled_node( + ComputeGraph& graph, + const std::vector& args); + +DynamicDispatchNode make_linear_q8csw_node( + ComputeGraph& graph, + const std::vector& args); + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/CMakeLists.txt b/backends/vulkan/test/custom_ops/CMakeLists.txt index bb8554e1a91..23f50ad0a98 100644 --- a/backends/vulkan/test/custom_ops/CMakeLists.txt +++ b/backends/vulkan/test/custom_ops/CMakeLists.txt @@ -93,4 +93,5 @@ if(TARGET vulkan_backend) # Define operator prototypes add_operator_prototype(add) + add_operator_prototype(quantized_linear) endif() diff --git a/backends/vulkan/test/custom_ops/quantized_linear.cpp b/backends/vulkan/test/custom_ops/quantized_linear.cpp new file mode 100644 index 00000000000..d081f3b621c --- /dev/null +++ b/backends/vulkan/test/custom_ops/quantized_linear.cpp @@ -0,0 +1,352 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include "utils.h" + +#include + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +// Linear configuration struct +struct LinearConfig { + int64_t M; // Batch size / number of rows in input + int64_t K; // Input features / columns in input, rows in weight + int64_t N; // Output features / columns in weight + std::string name_suffix; + std::string shader_variant_name = "default"; +}; + +// Utility function to create a test case from a LinearConfig +TestCase create_test_case_from_config( + const LinearConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype) { + TestCase test_case; + + // Create a descriptive name for the test case + std::string storage_str = + (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer"; + std::string dtype_str = (input_dtype == vkapi::kFloat) ? 
"Float" : "Half"; + + std::string test_name = "QuantizedLinear_" + config.name_suffix + "_" + + storage_str + "_" + dtype_str; + test_case.set_name(test_name); + + // Set the operator name for the test case + std::string operator_name = "et_vk.linear_q8ta_q8csw."; + operator_name += config.shader_variant_name; + test_case.set_operator_name(operator_name); + + // Derive sizes from M, K, N + std::vector input_size = {config.M, config.K}; + std::vector weight_size = {config.K, config.N}; + + // Input tensor (float/half) - [M, K] + ValueSpec input_tensor( + input_size, + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDINT); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor"); + } + + float input_scale_val = 0.5f; + ValueSpec input_scale(input_scale_val); + + int32_t input_zero_point_val = -4; + ValueSpec input_zero_point(input_zero_point_val); + + // Quantized weight tensor (int8) - [K, N] + ValueSpec quantized_weight( + weight_size, + vkapi::kChar, // int8 for quantized weights + storage_type, + utils::kWidthPacked, + DataGenType::RANDINT8); + quantized_weight.set_constant(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor"); + } + + // Weight quantization scales (float/half, per-channel) + ValueSpec weight_scales( + {config.N}, // Per output feature + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + + ValueSpec weight_sums( + {config.N}, // Per output features + vkapi::kFloat, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_sums.set_constant(true); + + // Compute weight_sums data based on quantized weights + int64_t in_features = config.K; + int64_t out_features = config.N; + compute_weight_sums(weight_sums, quantized_weight, out_features, in_features); + + // Bias (optional, float/half) - [N] + ValueSpec bias( + {config.N}, // Per output feature + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + bias.set_constant(true); + + // Output tensor (float/half) - [M, N] + ValueSpec output( + {config.M, config.N}, + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + + // Add all specs to test case + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_sums); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(bias); + + test_case.add_output_spec(output); + + return test_case; +} + +// Generate easy test cases for quantized linear operation (for debugging) +std::vector generate_quantized_linear_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging + int M = 16; + int K = 128; + int N = 64; + + LinearConfig config = { + M, // Batch size + K, // Input features + N, // Output features + "simple" // descriptive name + }; + + // Test with both storage types and data types for completeness + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + std::vector float_types = {vkapi::kFloat}; + + // Generate test cases for each combination + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, input_dtype)); + } + } + + return test_cases; +} + +// Generate test cases for quantized linear operation +std::vector 
generate_quantized_linear_test_cases() { + std::vector test_cases; + + std::vector configs = {// Small linear layers + {1, 64, 32, "64to32_single"}, + {1, 128, 64, "128to64_single"}, + {1, 256, 128, "256to128_single"}, + + // Larger batch sizes + {32, 64, 32, "64to32_batch32"}, + {32, 128, 64, "128to64_batch32"}, + {32, 256, 128, "256to128_batch32"}, + + // Performance test cases + {128, 2048, 2048, "perf_K2048"}, + {16384, 576, 128, "perf_conv"} + + }; + + // Test with different storage types and data types + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + + // Generate test cases for each combination + for (const auto& config : configs) { + for (const auto& storage_type : storage_types) { + // Test both with and without shader int8 dot product + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + + LinearConfig no_int_config = config; + no_int_config.name_suffix = config.name_suffix + "_noint8"; + no_int_config.shader_variant_name = "noint8"; + + test_cases.push_back(create_test_case_from_config( + no_int_config, storage_type, vkapi::kFloat)); + } + } + + return test_cases; +} + +// Reference implementation for quantized linear operation +void quantized_linear_reference_impl(TestCase& test_case) { + static constexpr int64_t kRefDimSizeLimit = 300; + // Extract input specifications + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [batch_size, in_features] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [out_features, in_features] + auto output_sizes = + output_spec.get_tensor_sizes(); // [batch_size, out_features] + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = weight_sizes[1]; + + // Skip for large tensors since computation time will be extremely slow + if (batch_size > kRefDimSizeLimit || in_features > kRefDimSizeLimit || + out_features > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions (batch_size, in_features, out_features) exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + // Calculate number of output elements + int64_t num_output_elements = batch_size * out_features; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + // Perform quantized linear transformation (matrix multiplication) + for (int64_t b = 0; b < batch_size; ++b) { + for (int64_t out_f = 0; out_f < out_features; 
++out_f) { + float sum = 0.0f; + + // Matrix multiplication: output[b][out_f] = sum(input[b][in_f] * + // weight[out_f][in_f]) + for (int64_t in_f = 0; in_f < in_features; ++in_f) { + // Get input value and dequantize + int64_t input_idx = b * in_features + in_f; + + float quant_input = + std::round(input_data[input_idx] / input_scale) + input_zero_point; + quant_input = std::min(std::max(quant_input, -128.0f), 127.0f); + float dequant_input = (quant_input - input_zero_point) * input_scale; + + // Get weight value and dequantize + int64_t weight_idx = in_f * out_features + out_f; + float dequant_weight = (static_cast(weight_data[weight_idx])) * + weight_scales_data[out_f]; + + sum += dequant_input * dequant_weight; + } + + // Add bias and store result + sum += bias_data[out_f]; + int64_t output_idx = b * out_features + out_f; + ref_data[output_idx] = sum; + } + } +} + +// Custom FLOP calculator for quantized linear operation +int64_t quantized_linear_flop_calculator(const TestCase& test_case) { + if (test_case.num_inputs() < 5 || test_case.num_outputs() < 1) { + return 0; + } + + // Get input and weight dimensions + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& weight_sizes = test_case.inputs()[3].get_tensor_sizes(); + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = weight_sizes[0]; + + // Calculate FLOPs for quantized linear operation + // Each output element requires: + // - in_features multiply-accumulate operations + // - Additional operations for quantization/dequantization + int64_t output_elements = batch_size * out_features; + int64_t ops_per_output = in_features; + + // Add quantization overhead (approximate) + // - Dequantize input: 1 op per input element used + // - Dequantize weight: 1 op per weight element used + // - Add bias: 1 op per output element + int64_t quantization_ops = ops_per_output + 1; // Simplified estimate + + int64_t flop = output_elements * (ops_per_output + quantization_ops); + + return flop; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Quantized Linear Operation Prototyping Framework" << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = quantized_linear_reference_impl; + + // Execute easy test cases using the new framework with custom FLOP calculator + auto results = execute_test_cases( + generate_quantized_linear_test_cases, + quantized_linear_flop_calculator, + "QuantizedLinear", + 0, + 10, + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index 2ddf49834e1..c2c87d182c9 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -97,3 +97,4 @@ def define_common_targets(is_fbcode = False): ) define_custom_op_test_binary("add") + define_custom_op_test_binary("quantized_linear")
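For context on how the prepacked weight_sums tensor is meant to be consumed by the q8ta_q8csw path, the identity it exploits can be sketched in a few lines of standalone C++. This is an illustrative sketch only, not code from this patch; the function name q8ta_q8csw_dot and its parameters are made up for the example, and it assumes per-tensor affine int8 activations with symmetric per-channel int8 weights, matching the configurations listed in QuantizedLinear.cpp.

#include <cstdint>
#include <vector>

// Computes one output element of an int8-activation, int8-weight linear layer.
// The activation zero point is folded out using a precomputed sum of the
// quantized weights for this output channel, which is what the prepacked
// weight_sums tensor provides per output channel.
float q8ta_q8csw_dot(
    const std::vector<int8_t>& q_input, // quantized activations, length K
    float input_scale,
    int32_t input_zp,
    const std::vector<int8_t>& q_weight_col, // quantized weights for one output channel, length K
    float weight_scale, // per-channel weight scale
    int32_t weight_sum) { // precomputed sum of q_weight_col
  int32_t acc = 0;
  for (size_t k = 0; k < q_input.size(); ++k) {
    // Integer accumulation; on device this is where the int8 dot product
    // extension would be used.
    acc += int32_t(q_input[k]) * int32_t(q_weight_col[k]);
  }
  // (q_x - zp) . q_w == q_x . q_w - zp * sum(q_w)
  return input_scale * weight_scale * float(acc - input_zp * weight_sum);
}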