Skip to content

Commit c1ea7e9

Browse files
authored
[ET-VK] Quantized Int8 Linear (#14041)
Title says it all! This PR adds implementations for int8 linear layers. Convolution is implemented in a later step, computing convolution as matrix multiplication via the im2col procedure. For both linear and convolution, two versions are implemented: 1. the `q8ta_q8csw` variant, which quantizes the input tensor and then performs integer accumulation via the int8 dot product extension; 2. the `q8csw` variant, which dequantizes the weight tensor in-shader and performs floating point accumulation. The second one is needed to provide an alternative path for executing quantized models if the target GPU does not support the int8 dot product extension. These new ops are tested via the custom op testing + benchmarking framework introduced in the previous diff. Differential Revision: [D81323424](https://our.internmc.facebook.com/intern/diff/D81323424/)
1 parent bf22ab7 commit c1ea7e9

39 files changed

+3166
-21
lines changed

.github/workflows/pull.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,7 +929,9 @@ jobs:
929929
CMAKE_ARGS="-DEXECUTORCH_BUILD_VULKAN=ON" \
930930
.ci/scripts/setup-linux.sh --build-tool "cmake"
931931
932+
# Custom operator tests
932933
PYTHON_EXECUTABLE=python bash backends/vulkan/test/custom_ops/build_and_run.sh add
934+
./cmake-out/backends/vulkan/test/custom_ops/q8csw_linear
933935
934936
nxp-build-test:
935937
name: nxp-build-test

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,10 @@ class ComputeGraph final {
308308
return idx == kDummyValueRef ? true : values_.at(idx).isNone();
309309
}
310310

311+
// Negation of val_is_none(idx): returns false when idx is kDummyValueRef or
// refers to a None value, and true otherwise.
inline bool val_is_not_none(const ValueRef idx) {
312+
return !val_is_none(idx);
313+
}
314+
311315
inline TypeTag get_val_type(const ValueRef idx) {
312316
return values_.at(idx).type();
313317
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#ifndef COMMON_GLSLH
10+
#define COMMON_GLSLH
11+
12+
#define mul_2(x) ((x) << 1)
13+
#define mul_4(x) ((x) << 2)
14+
#define mul_8(x) ((x) << 3)
15+
16+
#define div_2(x) ((x) >> 1)
17+
#define div_4(x) ((x) >> 2)
18+
#define div_8(x) ((x) >> 3)
19+
20+
#define div_up_2(x) (((x) + 1) >> 1)
21+
#define div_up_4(x) (((x) + 3) >> 2)
22+
#define div_up_8(x) (((x) + 7) >> 3)
23+
24+
// Rounds x up to the next multiple of 2. The argument is parenthesized so that
// compound expressions (e.g. `a | b`) are evaluated before the addition.
#define align_up_2(x) (((x) + 1) & -2)
25+
// Rounds x up to the next multiple of 4. The argument is parenthesized so that
// compound expressions (e.g. `a | b`) are evaluated before the addition.
#define align_up_4(x) (((x) + 3) & -4)
26+
// Rounds x up to the next multiple of 8. The argument is parenthesized so that
// compound expressions (e.g. `a | b`) are evaluated before the addition.
#define align_up_8(x) (((x) + 7) & -8)
27+
28+
#define mod_2(x) ((x) & 1)
29+
#define mod_4(x) ((x) & 3)
30+
#define mod_8(x) ((x) & 7)
31+
32+
// Thin wrapper around an ivec4 that holds the four components of a tensor
// index.
struct TensorIndex4D {
33+
ivec4 data;
34+
};
35+
36+
#ifdef DEBUG_MODE
37+
38+
#extension GL_EXT_debug_printf : require
39+
40+
// Debug-only helper (compiled when DEBUG_MODE is defined): prints the x, y, z
// and w components of index.data through debugPrintfEXT.
void printTensorIndex4D(const TensorIndex4D index) {
40+
debugPrintfEXT(
41+
"tensor_idx: %d, %d, %d, %d\\n",
42+
index.data.x,
43+
index.data.y,
44+
index.data.z,
45+
index.data.w);
46+
}
48+
49+
#endif // DEBUG_MODE
50+
51+
#endif // COMMON_GLSLH
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
/*
10+
* Defines common functions and structs to be used across matrix multiplication
11+
* operators.
12+
*/
13+
14+
#ifndef LINEAR_COMMON_GLSLH
15+
#define LINEAR_COMMON_GLSLH
16+
17+
#include "common.glslh"
18+
19+
// Sign-extends an 8-bit value stored in the low byte of val. When bit 7 is set,
// every bit above the low 8 is set to 1; otherwise val is returned unchanged.
// NOTE(review): bits above bit 7 are not cleared in the non-negative case, so
// callers presumably pass a value already masked with 0xFF - confirm.
int sign_extend_8bit(const int val) {
20+
if ((val & 0x80) != 0) {
21+
return val | (~0xFF);
22+
}
23+
return val;
24+
}
25+
26+
// Returns byte i of packed as a sign-extended int, where byte 0 is the least
// significant 8 bits. The shift is applied before the 0xFF mask (>> binds
// tighter than &), and the masked byte is then passed to sign_extend_8bit.
int extract_8bit_from_packed_int_le(const int packed, const int i) {
27+
// account for little endian
28+
int byte = sign_extend_8bit(packed >> (8 * i) & 0xFF);
29+
return byte;
30+
}
31+
32+
#endif // LINEAR_COMMON_GLSLH
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#ifndef LINEAR_FP_BIAS_LOAD_GLSLH
10+
#define LINEAR_FP_BIAS_LOAD_GLSLH
11+
12+
#include "linear_fp_per_out_channel_params.glslh"
13+
14+
// Returns element n4 of t_bias. t_bias is not declared in this header and must
// be provided by the including shader; no bounds check is performed here.
VEC4_T load_bias_x4(const int n4) {
15+
return t_bias[n4];
16+
}
17+
18+
// Fills bias.data[0 .. TILE_N4) with load_bias_x4(n4_start + n4).
// The TILE_N4 == 1 branch is the single-element form of the same loop.
// No bounds checks are applied to n4_start + n4.
void load_bias_tile(out FPPerOutChannelParams bias, const int n4_start) {
19+
#if TILE_N4 == 1
20+
bias.data[0] = load_bias_x4(n4_start);
21+
22+
#else
23+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
24+
bias.data[n4] = load_bias_x4(n4_start + n4);
25+
}
26+
27+
#endif
28+
}
29+
30+
#endif // LINEAR_FP_BIAS_LOAD_GLSLH
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#ifndef LINEAR_FP_INPUT_TILE_GLSLH
10+
#define LINEAR_FP_INPUT_TILE_GLSLH
11+
12+
/*
13+
* Defines the FPInputTile struct, which is used to represent a tile of the
14+
* input matrix of a matrix multiplication operation.
15+
*
16+
* Settings:
17+
* - TILE_M: number of rows in the tile
18+
* - TILE_K4: number of (groups of 4) columns in the tile
19+
*/
20+
21+
#extension GL_EXT_control_flow_attributes : require
22+
23+
// Tile of the input matrix: TILE_M rows by TILE_K4 VEC4_T elements per row.
struct FPInputTile {
24+
VEC4_T data[TILE_M][TILE_K4];
25+
};
26+
27+
#ifdef DEBUG_MODE
28+
29+
// Debug-only helper (compiled when DEBUG_MODE is defined): prints a header and
// then the x, y, z, w components of every element of in_tile.data, iterating
// rows (m) in the outer loop and k4 in the inner loop.
void printFPInputTile(const FPInputTile in_tile) {
30+
debugPrintfEXT("input_tile: \\n");
31+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
32+
[[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) {
33+
debugPrintfEXT(
34+
" %f, %f, %f, %f, \\n",
35+
in_tile.data[m][k4].x,
36+
in_tile.data[m][k4].y,
37+
in_tile.data[m][k4].z,
38+
in_tile.data[m][k4].w);
39+
}
40+
}
41+
}
42+
43+
#endif // DEBUG_MODE
44+
45+
#endif // LINEAR_FP_INPUT_TILE_GLSLH
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
/*
10+
* Defines functions to load a FPInputTile from input buffer/texture.
11+
*
12+
* Requires:
13+
* - t_input to be declared in the shader layout (input buffer/texture)
14+
*
15+
* Settings:
16+
* - INPUT_BUFFER to indicate input resource is a buffer, otherwise texture is
17+
* assumed.
18+
*/
19+
20+
#ifndef LINEAR_FP_INPUT_TILE_LOAD_GLSLH
21+
#define LINEAR_FP_INPUT_TILE_LOAD_GLSLH
22+
23+
#extension GL_EXT_control_flow_attributes : require
24+
25+
#include "linear_fp_input_tile.glslh"
26+
27+
#ifdef INPUT_BUFFER
28+
29+
// Buffer variant (INPUT_BUFFER defined): reads one VEC4_T from t_input at the
// linear offset m * ntexels_k + k4, i.e. ntexels_k is used as the row stride.
VEC4_T load_input_x4(const int k4, const int m, const int ntexels_k) {
30+
return t_input[(m * ntexels_k) + k4];
31+
}
32+
33+
#else
34+
35+
// Texture variant (INPUT_BUFFER not defined): fetches the texel at
// (k4, m, 0), mip level 0, from t_input. ntexels_k is unused here; the
// parameter list matches the buffer variant.
VEC4_T load_input_x4(const int k4, const int m, const int ntexels_k) {
36+
return texelFetch(t_input, ivec3(k4, m, 0), 0);
37+
}
38+
39+
#endif // INPUT_BUFFER
40+
41+
// Loads a full TILE_M x TILE_K4 tile starting at (m_start, k4_start) with no
// bounds checks. Every row m_start + m and every column k4_start + k4 of the
// tile is read unconditionally, so this is to be used only if
// (M - m_start >= TILE_M) && (K4 - k4_start >= TILE_K4).
// K4 is forwarded to load_input_x4 as ntexels_k; M is not read in the body.
42+
void load_input_tile_no_checks(
43+
out FPInputTile in_tile,
44+
const int k4_start,
45+
const int m_start,
46+
const int K4,
47+
const int M) {
48+
#if TILE_K4 == 1
49+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
50+
in_tile.data[m][0] = load_input_x4(k4_start, m_start + m, K4);
51+
}
52+
53+
#else
54+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
55+
[[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) {
56+
in_tile.data[m][k4] = load_input_x4(k4_start + k4, m_start + m, K4);
57+
}
58+
}
59+
#endif
60+
}
61+
62+
// To be used if near tensor boundaries.
// Loads a TILE_M x TILE_K4 tile starting at (m_start, k4_start); elements that
// fail the bounds test are set to VEC4_T(0.0) instead of being read.
// In the TILE_K4 == 1 branch only the row bound (m_start + m < M) is tested;
// the general branch tests both the row bound and k4_start + k4 < K4.
// NOTE(review): the TILE_K4 == 1 branch presumably relies on the caller
// guaranteeing k4_start < K4 - confirm against call sites.
63+
void load_input_tile_with_checks(
64+
out FPInputTile in_tile,
65+
const int k4_start,
66+
const int m_start,
67+
const int K4,
68+
const int M) {
69+
#if TILE_K4 == 1
70+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
71+
if (m_start + m < M) {
72+
in_tile.data[m][0] = load_input_x4(k4_start, m_start + m, K4);
73+
} else {
74+
in_tile.data[m][0] = VEC4_T(0.0);
75+
}
76+
}
77+
78+
#else
79+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
80+
[[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) {
81+
if (m_start + m < M && k4_start + k4 < K4) {
82+
in_tile.data[m][k4] = load_input_x4(k4_start + k4, m_start + m, K4);
83+
} else {
84+
in_tile.data[m][k4] = VEC4_T(0.0);
85+
}
86+
}
87+
}
88+
#endif
89+
}
90+
91+
#endif // LINEAR_FP_INPUT_TILE_LOAD_GLSLH
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
/*
10+
* Defines the FPOutTile struct, which is used to represent a tile of the output
11+
* matrix of a matrix multiplication operation.
12+
*
13+
* Settings:
14+
* - TILE_M: number of rows in the output tile
15+
* - TILE_N4: number of (groups of 4) columns in the output tile
16+
*/
17+
18+
#ifndef LINEAR_FP_OUTPUT_TILE_GLSLH
19+
#define LINEAR_FP_OUTPUT_TILE_GLSLH
20+
21+
#extension GL_EXT_control_flow_attributes : require
22+
23+
// Tile of the output matrix: TILE_M rows by TILE_N4 VEC4_T elements per row.
struct FPOutTile {
24+
VEC4_T data[TILE_M][TILE_N4];
25+
};
26+
27+
// Sets every element of out_tile.data to VEC4_T(0).
// The TILE_N4 == 1 branch is the single-column form of the general loop.
void initialize(out FPOutTile out_tile) {
28+
#if TILE_N4 == 1
29+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
30+
out_tile.data[m][0] = VEC4_T(0);
31+
}
32+
33+
#else
34+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
35+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
36+
out_tile.data[m][n4] = VEC4_T(0);
37+
}
38+
}
39+
#endif
40+
}
41+
42+
#ifdef DEBUG_MODE
43+
44+
// Debug-only helper (compiled when DEBUG_MODE is defined): prints a header,
// then for each row m prints the x, y, z, w components of every element in
// that row followed by a line break.
void printFPOutTile(const FPOutTile tile) {
45+
debugPrintfEXT("output_tile: \\n");
46+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
47+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
48+
debugPrintfEXT(
49+
" %f, %f, %f, %f,",
50+
tile.data[m][n4].x,
51+
tile.data[m][n4].y,
52+
tile.data[m][n4].z,
53+
tile.data[m][n4].w);
54+
}
55+
debugPrintfEXT("\\n");
56+
}
57+
}
58+
59+
#endif // DEBUG_MODE
60+
61+
#endif // LINEAR_FP_OUTPUT_TILE_GLSLH

0 commit comments

Comments
 (0)