Skip to content

Commit 501f89c

Browse files
author
ssjia
committed
[ET-VK] Quantized Int8 Convolution + Linear
Title says it all! This PR adds implementations for int8 quantized convolution and linear layers. Convolution is implemented as matrix multiplication under the hood by using the im2col procedure. For both linear and convolution, two versions are implemented: 1. `q8ta_q8csw` variant which quantized the input tensor and then performs integer accumulation via the int8 dot product extension 2. `q8csw` variant which dequantized the weight tensor in-shader and performs floating point accumulation. The second one is needed to provide an alternative path for executing quantized models if the target GPU does not support int8 dot product extension. These new ops are tested via the custom op testing + benchmarking framework introduced in the previous diff. Differential Revision: [D81323424](https://our.internmc.facebook.com/intern/diff/D81323424/) [ghstack-poisoned]
1 parent 8bc6836 commit 501f89c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+5305
-0
lines changed
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
#define VEC4_T ${texel_load_type(DTYPE, OUTPUT_STORAGE)}
13+
#define T ${texel_load_component_type(DTYPE, OUTPUT_STORAGE)}
14+
15+
$if OUTPUT_STORAGE == "buffer":
16+
#define OUTPUT_BUFFER
17+
$if INPUT_STORAGE == "buffer":
18+
#define INPUT_BUFFER
19+
20+
#define TILE_M4 1
21+
#define TILE_N4 1
22+
#define TILE_K4 1
23+
24+
#define TILE_M 4
25+
#define TILE_N 4
26+
#define TILE_K 4
27+
28+
${define_required_extensions(DTYPE)}
29+
30+
layout(std430) buffer;
31+
32+
#include "conv2d_common.glslh"
33+
34+
${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)}
35+
${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)}
36+
37+
// Sizes of the convolution output image
38+
${layout_declare_ubo(B, "ivec4", "output_sizes")}
39+
// Sizes of the convolution input image
40+
${layout_declare_ubo(B, "ivec4", "input_sizes")}
41+
// Sizes of the im2col matrix of the convolution output
42+
${layout_declare_ubo(B, "ivec4", "matrix_sizes")}
43+
44+
${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}
45+
46+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
47+
48+
#include "conv2d_fp_im2col_block_store.glslh"
49+
50+
#ifdef INPUT_BUFFER
51+
52+
void load_matrix_tile(
53+
out FPOutTile tile,
54+
const int n4,
55+
const int m_start,
56+
const int N4) {
57+
[[unroll]] for (int m = 0; m < TILE_M; m++) {
58+
tile.data[m][0] = t_input[(m_start + m) * N4 + n4];
59+
}
60+
}
61+
62+
#else // INPUT_TEXTURE
63+
64+
void load_matrix_tile(
65+
out FPOutTile tile,
66+
const int n4,
67+
const int m_start,
68+
const int N4) {
69+
[[unroll]] for (int m = 0; m < TILE_M; m++) {
70+
tile.data[m][0] = texelFetch(
71+
t_input, ivec3(n4, m_start + m, 0), 0);
72+
}
73+
}
74+
75+
#endif // INPUT_BUFFER
76+
77+
void main() {
78+
// Each thread loads and writes a 4 wide x 4 high block of the matrix
79+
const int n4 = int(gl_GlobalInvocationID.x);
80+
const int m4 = int(gl_GlobalInvocationID.y);
81+
82+
const int n = mul_4(n4);
83+
const int m = mul_4(m4);
84+
85+
if (n >= matrix_sizes.x || m >= matrix_sizes.y) {
86+
return;
87+
}
88+
89+
FPOutTile tile;
90+
91+
const int N4 = div_4(matrix_sizes.x);
92+
load_matrix_tile(tile, n4, m, N4);
93+
write_im2col_tile_as_image(tile, n4, m);
94+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
col2im:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
OUTPUT_STORAGE: texture3d
11+
INPUT_STORAGE: buffer
12+
generate_variant_forall:
13+
DTYPE:
14+
- VALUE: half
15+
- VALUE: float
16+
shader_variants:
17+
- NAME: col2im_texture3d_buffer
18+
- NAME: col2im_texture3d_texture3d
19+
INPUT_STORAGE: texture3d
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#ifndef COMMON_GLSLH
10+
#define COMMON_GLSLH
11+
12+
#define align_up_4(x) ((x + 3) & -4)
13+
14+
#define div_up_4(x) (((x) + 3) >> 2)
15+
16+
#define mul_4(x) ((x) << 2)
17+
#define div_4(x) ((x) >> 2)
18+
19+
#define mod_4(x) ((x) & 3)
20+
21+
struct TensorIndex4D {
22+
ivec4 data;
23+
};
24+
25+
#ifdef DEBUG_MODE
26+
27+
#extension GL_EXT_debug_printf : require
28+
29+
void printTensorIndex4D(const TensorIndex4D index) {
30+
debugPrintfEXT(
31+
"tensor_idx: %d, %d, %d, %d\\n",
32+
index.data.x,
33+
index.data.y,
34+
index.data.z,
35+
index.data.w);
36+
}
37+
38+
#endif // DEBUG_MODE
39+
40+
#endif // COMMON_GLSLH
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#ifndef CONV2D_COMMON_GLSLH
10+
#define CONV2D_COMMON_GLSLH
11+
12+
#include "common.glslh"
13+
14+
struct Conv2DParams {
15+
ivec2 kernel_size;
16+
ivec2 stride;
17+
ivec2 padding;
18+
ivec2 dilation;
19+
int groups;
20+
int out_channels_per_group;
21+
int in_channels_per_group;
22+
int logical_K_per_group;
23+
int K_per_group;
24+
int K4_per_group;
25+
int logical_K;
26+
int K;
27+
int K4;
28+
};
29+
30+
#ifdef DEBUG_MODE
31+
32+
void printConv2DParams(const Conv2DParams params) {
33+
debugPrintfEXT("Conv2DParams: \\n");
34+
debugPrintfEXT(
35+
" kernel_size: %d, %d\\n", params.kernel_size.x, params.kernel_size.y);
36+
debugPrintfEXT(" stride: %d, %d\\n", params.stride.x, params.stride.y);
37+
debugPrintfEXT(" padding: %d, %d\\n", params.padding.x, params.padding.y);
38+
debugPrintfEXT(" dilation: %d, %d\\n", params.dilation.x, params.dilation.y);
39+
debugPrintfEXT(" groups: %d\\n", params.groups);
40+
debugPrintfEXT(
41+
" out_channels_per_group: %d\\n", params.out_channels_per_group);
42+
debugPrintfEXT(
43+
" in_channels_per_group: %d\\n", params.in_channels_per_group);
44+
debugPrintfEXT(" logical_K_per_group: %d\\n", params.logical_K_per_group);
45+
debugPrintfEXT(" K_per_group: %d\\n", params.K_per_group);
46+
debugPrintfEXT(" K4_per_group: %d\\n", params.K4_per_group);
47+
}
48+
49+
#endif // DEBUG_MODE
50+
51+
#endif // CONV2D_COMMON_GLSLH
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#ifndef CONV2D_FP_IM2COL_BLOCK
10+
#define CONV2D_FP_IM2COL_BLOCK
11+
12+
/*
13+
* Defines utilities to convert between (col, row) indices of an im2col matrix
14+
* and 4-dimension tensor indices of image tensors.
15+
*
16+
* Requires:
17+
* - output_sizes to be defined in the shader layout, corresponding to the sizes
18+
* of the output image of the convolution op.
19+
* - image_sizes to be defined in the shader layout, corresponding to the sizes
20+
* of the input image of the convolution op.
21+
* - conv2d_params to be defined in the shader layout
22+
*/
23+
24+
#extension GL_EXT_control_flow_attributes : require
25+
26+
#include "common.glslh"
27+
#include "conv2d_common.glslh"
28+
29+
struct Im2ColMatrixIdx {
30+
int row;
31+
int col;
32+
// Relevant for grouped convolution. This indicates the column index relative
33+
// to the first column in the group.
34+
int col_idx_in_group;
35+
int group_idx;
36+
};
37+
38+
void unwrap_m(out TensorIndex4D out_tidx_base, const int m) {
39+
out_tidx_base.data[3] = m / (output_sizes.y * output_sizes.x);
40+
out_tidx_base.data[1] = (m / output_sizes.x) % output_sizes.y;
41+
out_tidx_base.data[0] = m % output_sizes.x;
42+
43+
// Initialize channels to 0; assume it will be set later on
44+
out_tidx_base.data[2] = 0;
45+
}
46+
47+
void im2col_tidx_to_output_tidx(
48+
out TensorIndex4D output_tidx,
49+
const Im2ColMatrixIdx im2col_tidx) {
50+
unwrap_m(output_tidx, im2col_tidx.row);
51+
// Set channels
52+
output_tidx.data.z = im2col_tidx.col;
53+
}
54+
55+
/*
56+
* Converts im2col matrix position to corresponding 4D tensor index, accounting
57+
* for grouped convolutions. The conversion should ensure that all data within
58+
* the same group occupy a contiguous block in memory.
59+
*/
60+
void im2col_idx_to_input_tidx(
61+
out TensorIndex4D input_tidx,
62+
const Im2ColMatrixIdx im2col_idx) {
63+
TensorIndex4D output_tidx;
64+
unwrap_m(output_tidx, im2col_idx.row);
65+
66+
const int in_channels_per_group = conv2d_params.in_channels_per_group;
67+
// Determine the corresponding position within the convolution window based
68+
// on the col index (more specifically, the col index within the group)
69+
const int channel_within_group =
70+
im2col_idx.col_idx_in_group % in_channels_per_group;
71+
const int kernel_x = (im2col_idx.col_idx_in_group / in_channels_per_group) %
72+
conv2d_params.kernel_size.x;
73+
const int kernel_y = im2col_idx.col_idx_in_group /
74+
(in_channels_per_group * conv2d_params.kernel_size.x);
75+
76+
// Calculate the actual input channel index
77+
const int channel_idx =
78+
im2col_idx.group_idx * conv2d_params.in_channels_per_group +
79+
channel_within_group;
80+
81+
// Calculate corresponding input coordinates based on output position
82+
// associated with the row index.
83+
const int input_y = int(output_tidx.data.y * conv2d_params.stride.y) -
84+
int(conv2d_params.padding.y) + int(kernel_y * conv2d_params.dilation.y);
85+
const int input_x = int(output_tidx.data.x * conv2d_params.stride.x) -
86+
int(conv2d_params.padding.x) + int(kernel_x * conv2d_params.dilation.x);
87+
88+
input_tidx.data = ivec4(input_x, input_y, channel_idx, output_tidx.data.w);
89+
}
90+
91+
// 4x4 block of the im2col matrix
92+
struct FPIm2ColBlock {
93+
VEC4_T data[4];
94+
};
95+
96+
#endif // CONV2D_FP_IM2COL_BLOCK

0 commit comments

Comments
 (0)