Skip to content

Commit f3651f3

Browse files
author
ssjia
committed
[ET-VK] Statically quantized convolutions
Pull Request resolved: #14647 ## Changes This diff adds implementations for quantized convolution under the following quantization conditions: * activations statically quantized to 8-bit with per-tensor scale and zero point * weights quantized to 8-bit with per-channel scales * outputs statically quantized to 8-bit with per-tensor scale and zero point. Three different implementations are added; one is selected based on the input conditions. The first is a direct convolution shader which uses the quantized int8 input directly. The second is an im2col variant, which computes the convolution via a GEMM-like algorithm by first applying an im2col transformation on the input tensor. Finally, a specialized implementation is added for depthwise convolutions. ghstack-source-id: 312809805 @exported-using-ghexport Differential Revision: [D83437827](https://our.internmc.facebook.com/intern/diff/D83437827/)
1 parent 049c9fc commit f3651f3

40 files changed

+4277
-121
lines changed

backends/vulkan/runtime/graph/ops/glsl/common.glslh

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ int extract_8bit_from_packed_int_le(const int packed, const int i) {
4646
return byte;
4747
}
4848

49+
/*
 * Expands one 32-bit word into an ivec4 by extracting each of its four 8-bit
 * lanes. Component i of the result is
 * extract_8bit_from_packed_int_le(packed, i).
 */
ivec4 unpack_int8x4(const int packed) {
  ivec4 lanes;
  lanes.x = extract_8bit_from_packed_int_le(packed, 0);
  lanes.y = extract_8bit_from_packed_int_le(packed, 1);
  lanes.z = extract_8bit_from_packed_int_le(packed, 2);
  lanes.w = extract_8bit_from_packed_int_le(packed, 3);
  return lanes;
}
56+
4957
int pack_4xqint_into_int32(
5058
const int val0,
5159
const int val1,
@@ -57,6 +65,13 @@ int pack_4xqint_into_int32(
5765
return packed;
5866
}
5967

68+
/*
 * Packs the four components of quant_vals into a single 32-bit int. Only the
 * low 8 bits of each component are kept; component i lands in bits
 * [8 * i, 8 * i + 7] of the result.
 */
int pack_into_int32(const ivec4 quant_vals) {
  const int byte0 = quant_vals.x & 0xFF;
  const int byte1 = quant_vals.y & 0xFF;
  const int byte2 = quant_vals.z & 0xFF;
  const int byte3 = quant_vals.w & 0xFF;

  return byte0 | (byte1 << 8) | (byte2 << 16) | (byte3 << 24);
}
74+
6075
#ifdef DEBUG_MODE
6176

6277
#extension GL_EXT_debug_printf : require

backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,18 @@ Conv2dBlockExtents make_block_extents(const ivec4 tensor_sizes) {
6161
return block_sizes;
6262
}
6363

64+
/*
 * Converts a flat index into a 3D block index. The z component varies
 * fastest, then x, then y:
 *   idx = (y * extents.x + x) * extents.z + z
 * No bounds check is performed here; y may exceed extents.y for large idx.
 */
Conv2dBlockIndex linear_idx_to_block_idx(
    const int idx, const Conv2dBlockExtents block_extents) {
  const int z_extent = block_extents.data.z;
  const int x_extent = block_extents.data.x;

  // Index over the (x, y) plane once the z component has been peeled off.
  const int xy_idx = idx / z_extent;

  Conv2dBlockIndex result;
  result.data.x = xy_idx % x_extent;
  result.data.y = xy_idx / x_extent;
  result.data.z = idx % z_extent;
  return result;
}
75+
6476
bool block_idx_out_of_bounds(
6577
const Conv2dBlockIndex block_idx,
6678
const Conv2dBlockExtents block_extents) {
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#ifndef CONV2D_DW_Q8_UTILS_GLSLH
10+
#define CONV2D_DW_Q8_UTILS_GLSLH
11+
12+
#extension GL_EXT_control_flow_attributes : require
13+
14+
/*
 * Fixed-capacity run of dequantized input texels along the width axis.
 * `len` holds the number of leading entries of `data` that were filled.
 */
struct InputWindow1D {
  vec4 data[MAX_WINDOW_WIDTH];
  int len;
};
18+
19+
/*
 * Returns an empty InputWindow1D: every texel zeroed and len set to 0.
 */
InputWindow1D initial_input_window() {
  InputWindow1D window;
  window.len = 0;
  for (int slot = 0; slot < MAX_WINDOW_WIDTH; slot++) {
    window.data[slot] = vec4(0);
  }
  return window;
}
27+
28+
/*
 * Dequantizes four packed 8-bit values with a shared scale and zero point:
 *   result[i] = (q[i] - zp) * scale
 */
vec4 dequantize(const int packed_texel, const float scale, const int zp) {
  const ivec4 centered = unpack_int8x4(packed_texel) - zp;
  return vec4(centered) * scale;
}
31+
32+
/*
 * Dequantizes four packed 8-bit values with one scale per lane and no zero
 * point:
 *   result[i] = q[i] * scales[i]
 */
vec4 dequantize(const int packed_texel, const vec4 scales) {
  const vec4 unpacked = vec4(unpack_int8x4(packed_texel));
  return unpacked * scales;
}
35+
36+
/*
 * Returns true when (block_w, block_h, block_c4) lies inside the half-open
 * box [0, block_extents.data) on every axis, false otherwise.
 */
bool in_bounds(
    const int block_w,
    const int block_h,
    const int block_c4,
    const Conv2dBlockExtents block_extents) {
  const ivec3 idx = ivec3(block_w, block_h, block_c4);

  const bool below_origin = any(lessThan(idx, ivec3(0)));
  const bool past_extents = any(greaterThanEqual(idx, block_extents.data));

  return !(below_origin || past_extents);
}
51+
52+
/*
 * Loads and dequantizes the input texels covering width positions
 * [w_start, w_end] (inclusive) at row `h` and channel block `c4`.
 *
 * Input is fetched one block at a time; each block is an ivec4 whose
 * component `row` is the packed texel for width position 4 * block_w + row.
 * Blocks that fall outside `block_extents` are replaced by `input_zps`, so
 * the corresponding texels dequantize to zero when `input_zps` holds the
 * packed `input_zp`.
 *
 * Window entry i corresponds to width position w_start + i. At most
 * MAX_WINDOW_WIDTH entries are stored; positions beyond that capacity are
 * dropped rather than written past the end of the fixed-size array.
 *
 * Returns the filled window, with `len` set to the number of stored entries.
 */
InputWindow1D load_input_window(
    const int w_start,
    const int w_end,
    const int h,
    const int c4,
    const Conv2dBlockExtents block_extents,
    const float input_scale,
    const int input_zp,
    const ivec4 input_zps) {
  InputWindow1D input_window = initial_input_window();

  const int block_w_start = div_4(w_start);
  const int block_w_end = div_4(w_end);

  int window_i = 0;
  for (int block_w = block_w_start; block_w <= block_w_end; ++block_w) {
    // Default to the packed zero point so padded positions dequantize to 0.
    ivec4 input_block = input_zps;

    if (in_bounds(block_w, h, c4, block_extents)) {
#ifdef PACKED_INT8_INPUT_BUFFER
      const int buffer_idx =
          h * block_extents.data_xz + block_w * block_extents.data.z + c4;
      input_block = t_packed_int8_input[buffer_idx];
#else
      input_block = texelFetch(t_packed_int8_input, ivec3(block_w, h, c4), 0);
#endif
    }

    const int loaded_w_start = mul_4(block_w);
    for (int row = 0; row < 4; ++row) {
      const int w = loaded_w_start + row;
      // The capacity check keeps the write inside data[MAX_WINDOW_WIDTH].
      if (w >= w_start && w <= w_end && window_i < MAX_WINDOW_WIDTH) {
        input_window.data[window_i++] =
            dequantize(input_block[row], input_scale, input_zp);
      }
    }
  }
  input_window.len = window_i;
  return input_window;
}
91+
92+
/*
 * Fixed-capacity row of dequantized kernel weights along the kernel width.
 * `len` holds the number of leading entries of `data` that were filled.
 */
struct WeightRow {
  vec4 data[MAX_KERNEL_WIDTH];
  int len;
};
96+
97+
/*
 * Returns an empty WeightRow: every texel zeroed and len set to 0.
 */
WeightRow initial_weight_row() {
  WeightRow row;
  row.len = 0;
  for (int slot = 0; slot < MAX_KERNEL_WIDTH; slot++) {
    row.data[slot] = vec4(0);
  }
  return row;
}
105+
106+
/*
 * Loads and dequantizes one kernel row (fixed `ky`) of weights for output
 * channel block `oc4`.
 *
 * Weights are fetched in groups of four kernel-width taps. Group index k4
 * starts at ky * Kw4 and advances by one per group; the group is read from
 * t_packed_int8_weight at [k4 * OC4 + oc4] (buffer) or (oc4, k4) (texture).
 * Component `row` of a group is the packed texel for tap w + row, and is
 * dequantized with `weight_scales`.
 *
 * At most MAX_KERNEL_WIDTH taps are stored; taps beyond that capacity are
 * dropped rather than written past the end of the fixed-size array.
 *
 * Returns the filled row, with `len` set to the number of stored taps.
 */
WeightRow load_weight_row(
    const int oc4,
    const int ky,
    const int OC4,
    const int Kw,
    const int Kw4,
    const vec4 weight_scales) {
  WeightRow weight_row = initial_weight_row();

  int k4 = ky * Kw4;
  int row_idx = 0;
  for (int w = 0; w < Kw; w += 4) {
#ifdef WEIGHT_BUFFER
    const ivec4 weight_block = t_packed_int8_weight[k4 * OC4 + oc4];
#else
    const ivec4 weight_block = texelFetch(
        t_packed_int8_weight, ivec2(oc4, k4), 0);
#endif

    for (int row = 0; row < 4; ++row) {
      // The capacity check keeps the write inside data[MAX_KERNEL_WIDTH].
      if (w + row < Kw && row_idx < MAX_KERNEL_WIDTH) {
        weight_row.data[row_idx++] =
            dequantize(weight_block[row], weight_scales);
      }
    }
    k4++;
  }
  weight_row.len = row_idx;
  return weight_row;
}
135+
136+
/*
 * Floating-point accumulator for one output block: four vec4 texels, one per
 * output width position handled by an invocation.
 */
struct FPOutBlock {
  vec4 data[4];
};
139+
140+
/*
 * Accumulates one kernel row into the four output texels of `out_block`.
 *
 * For output position out_w and kernel tap kx, the input texel sits at window
 * index
 *   out_w * stride.x + kx * dilation.x
 * because window entry i holds the input at width w_start + i. Stepping taps
 * by dilation.x keeps this indexing consistent with a window whose span was
 * sized using the dilation; with dilation.x == 1 it reduces to in_w + kx.
 *
 * Reads conv2d_params.stride.x and conv2d_params.dilation.x; `out_block` must
 * already hold the running sums (zeros for the first kernel row).
 */
void perform_conv1d(
    inout FPOutBlock out_block,
    const InputWindow1D input_window,
    const WeightRow weight_row) {
  for (int out_w = 0; out_w < 4; ++out_w) {
    // First window entry contributing to this output position.
    const int in_w = out_w * conv2d_params.stride.x;
    [[unroll]] for (int kx = 0; kx < weight_row.len; ++kx) {
      out_block.data[out_w] = fma(
          input_window.data[in_w + kx * conv2d_params.dilation.x],
          weight_row.data[kx],
          out_block.data[out_w]);
    }
  }
}
154+
155+
/*
 * Quantizes a float texel to the signed 8-bit range:
 *   q[i] = clamp(int(round(texel[i] * inv_scale) + zp), -128, 127)
 */
ivec4 quantize(const vec4 texel, const float inv_scale, const int zp) {
  const vec4 scaled = round(texel * inv_scale);
  const vec4 shifted = scaled + zp;
  const ivec4 as_int = ivec4(shifted);
  return clamp(as_int, -128, 127);
}
160+
161+
/*
 * Quantizes each of the four texels in `out_block` with quantize() and packs
 * each quantized texel into one int with pack_into_int32(). Component i of
 * the result corresponds to out_block.data[i].
 */
ivec4 quantize_and_pack(
    FPOutBlock out_block, const float inv_scale, const int zp) {
  return ivec4(
      pack_into_int32(quantize(out_block.data[0], inv_scale, zp)),
      pack_into_int32(quantize(out_block.data[1], inv_scale, zp)),
      pack_into_int32(quantize(out_block.data[2], inv_scale, zp)),
      pack_into_int32(quantize(out_block.data[3], inv_scale, zp)));
}
170+
171+
#ifdef DEBUG_MODE
172+
173+
/*
 * Debug helper: prints the window length followed by each stored texel.
 * The number of printed entries is capped at MAX_WINDOW_WIDTH.
 */
void printInputWindow1D(const InputWindow1D input_window) {
  debugPrintfEXT("InputWindow1D contents (len = %d): \\n", input_window.len);

  const int num_entries = min(input_window.len, MAX_WINDOW_WIDTH);
  for (int entry = 0; entry < num_entries; ++entry) {
    const vec4 texel = input_window.data[entry];
    debugPrintfEXT(
        " [%d]: (%.3f, %.3f, %.3f, %.3f) \\n",
        entry,
        texel.x,
        texel.y,
        texel.z,
        texel.w);
  }
}
185+
186+
/*
 * Debug helper: prints the row length followed by each stored weight texel.
 * The number of printed entries is capped at MAX_KERNEL_WIDTH.
 */
void printWeightRow(const WeightRow weight_row) {
  debugPrintfEXT("WeightRow contents (len = %d): \\n", weight_row.len);

  const int num_entries = min(weight_row.len, MAX_KERNEL_WIDTH);
  for (int entry = 0; entry < num_entries; ++entry) {
    const vec4 texel = weight_row.data[entry];
    debugPrintfEXT(
        " [%d]: (%.3f, %.3f, %.3f, %.3f) \\n",
        entry,
        texel.x,
        texel.y,
        texel.z,
        texel.w);
  }
}
198+
199+
/*
 * Debug helper: prints all four texels of an FPOutBlock.
 */
void printFPOutBlock(const FPOutBlock out_block) {
  debugPrintfEXT("FPOutBlock contents: \\n");

  for (int entry = 0; entry < 4; ++entry) {
    const vec4 texel = out_block.data[entry];
    debugPrintfEXT(
        " [%d]: (%.3f, %.3f, %.3f, %.3f) \\n",
        entry,
        texel.x,
        texel.y,
        texel.z,
        texel.w);
  }
}
211+
212+
#endif // DEBUG_MODE
213+
214+
#endif // CONV2D_DW_Q8_UTILS_GLSLH
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
#define VEC4_T ${texel_load_type(DTYPE, "buffer")}
13+
#define T ${texel_load_component_type(DTYPE, "buffer")}
14+
15+
$if IO_STORAGE == "buffer":
16+
#define PACKED_INT8_OUTPUT_BUFFER
17+
#define PACKED_INT8_INPUT_BUFFER
18+
$if WEIGHT_STORAGE == "buffer":
19+
#define WEIGHT_BUFFER
20+
21+
#define MAX_WINDOW_WIDTH 12
22+
#define MAX_KERNEL_WIDTH 5
23+
24+
${define_required_extensions(DTYPE)}
25+
26+
layout(std430) buffer;
27+
28+
#include "conv2d_common.glslh"
29+
30+
${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=False)}
31+
${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=False)}
32+
${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)}
33+
${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)}
34+
${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}
35+
${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)}
36+
37+
${layout_declare_ubo(B, "ivec4", "output_sizes")}
38+
${layout_declare_ubo(B, "ivec4", "input_sizes")}
39+
${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}
40+
41+
layout(push_constant) uniform restrict Block {
42+
float input_scale;
43+
int input_zp;
44+
float output_inv_scale;
45+
int output_zp;
46+
};
47+
48+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
49+
50+
${layout_declare_spec_const(C, "int", "apply_bias", "1")}
51+
52+
#include "conv2d_dw_q8_utils.glslh"
53+
54+
/*
 * Entry point. Each invocation computes one output block (four consecutive
 * output width positions at one row and one channel block), identified by the
 * flat invocation id:
 *   1. Map the invocation id to an output block index; exit if out of bounds.
 *   2. For every kernel row, load the needed input window and weight row and
 *      accumulate them into a floating-point output block.
 *   3. Optionally add the bias, then requantize, pack and store the block.
 */
void main() {
  const int tid = int(gl_GlobalInvocationID.x);
  Conv2dBlockExtents out_block_extents = make_block_extents(output_sizes);

  Conv2dBlockIndex out_block_idx = linear_idx_to_block_idx(
      tid, out_block_extents);

  if (block_idx_out_of_bounds(out_block_idx, out_block_extents)) {
    return;
  }

  // Inclusive range of input width positions read by the four outputs of
  // this block: first tap of the first output through last tap of the last.
  const int out_w = mul_4(out_block_idx.data.x);
  const int w_start =
      (out_w * conv2d_params.stride.x) - conv2d_params.padding.x;
  const int w_end = ((out_w + 3) * conv2d_params.stride.x) -
      conv2d_params.padding.x +
      (conv2d_params.kernel_size.x - 1) * conv2d_params.dilation.x;

  Conv2dBlockExtents in_block_extents = make_block_extents(input_sizes);

  // Packed zero point, used in place of input that falls in the padding.
  const ivec4 input_zps = ivec4(pack_into_int32(ivec4(input_zp)));
  const vec4 weight_scales = vec4(t_weight_scales[out_block_idx.data.z]);

  const int Kw4 = div_up_4(conv2d_params.kernel_size.x);

  // The accumulator is read by fma() in perform_conv1d, so it must start at
  // zero; an uninitialized local has an undefined value.
  FPOutBlock out_block;
  for (int row = 0; row < 4; row++) {
    out_block.data[row] = vec4(0);
  }

  const int out_h = out_block_idx.data.y;
  for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) {
    const int h = out_h * conv2d_params.stride.y - conv2d_params.padding.y +
        ky * conv2d_params.dilation.y;

    InputWindow1D input_window = load_input_window(
        w_start,
        w_end,
        h,
        out_block_idx.data.z,
        in_block_extents,
        input_scale,
        input_zp,
        input_zps);

    WeightRow weight_row = load_weight_row(
        out_block_idx.data.z,
        ky,
        out_block_extents.data.z,
        conv2d_params.kernel_size.x,
        Kw4,
        weight_scales);

    perform_conv1d(out_block, input_window, weight_row);
  }

  if (apply_bias > 0) {
    const vec4 bias = vec4(t_bias[out_block_idx.data.z]);
    for (int row = 0; row < 4; row++) {
      out_block.data[row] += bias;
    }
  }

  const ivec4 packed_out_block = quantize_and_pack(
      out_block, output_inv_scale, output_zp);

#ifdef PACKED_INT8_OUTPUT_BUFFER
  t_packed_int8_output[tid] = packed_out_block;
#else
  imageStore(t_packed_int8_output, out_block_idx.data, packed_out_block);
#endif
}

0 commit comments

Comments
 (0)