Skip to content

Commit 2b20016

Browse files
pytorchbot and ssjia authored
[ET-VK] Statically quantized convolutions (#14668)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #14647 by @SS-JIA ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/332/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/332/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/332/orig Differential Revision: [D83437827](https://our.internmc.facebook.com/intern/diff/D83437827/) @diff-train-skip-merge Co-authored-by: ssjia <[email protected]>
1 parent 84f0c7d commit 2b20016

40 files changed

+4277
-121
lines changed

backends/vulkan/runtime/graph/ops/glsl/common.glslh

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ int extract_8bit_from_packed_int_le(const int packed, const int i) {
4646
return byte;
4747
}
4848

49+
/*
 * Expands one packed 32-bit word into an ivec4. Component i holds the value
 * returned by extract_8bit_from_packed_int_le(packed, i) for i = 0..3.
 */
ivec4 unpack_int8x4(const int packed) {
  ivec4 unpacked;
  unpacked.x = extract_8bit_from_packed_int_le(packed, 0);
  unpacked.y = extract_8bit_from_packed_int_le(packed, 1);
  unpacked.z = extract_8bit_from_packed_int_le(packed, 2);
  unpacked.w = extract_8bit_from_packed_int_le(packed, 3);
  return unpacked;
}
56+
4957
int pack_4xqint_into_int32(
5058
const int val0,
5159
const int val1,
@@ -57,6 +65,13 @@ int pack_4xqint_into_int32(
5765
return packed;
5866
}
5967

68+
/*
 * Packs the low 8 bits of each component of quant_vals into a single int.
 * Component 0 lands in bits 0-7, component 1 in bits 8-15, component 2 in
 * bits 16-23 and component 3 in bits 24-31.
 */
int pack_into_int32(const ivec4 quant_vals) {
  // Mask every lane once, then shift each lane into its byte position.
  const ivec4 low_bytes = quant_vals & 0xFF;
  return low_bytes.x | (low_bytes.y << 8) | (low_bytes.z << 16) |
      (low_bytes.w << 24);
}
74+
6075
#ifdef DEBUG_MODE
6176

6277
#extension GL_EXT_debug_printf : require

backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,18 @@ Conv2dBlockExtents make_block_extents(const ivec4 tensor_sizes) {
6161
return block_sizes;
6262
}
6363

64+
/*
 * Converts a flat index into a Conv2dBlockIndex. The z component varies
 * fastest, then x, then y; the split uses the z and x extents taken from
 * block_extents.
 */
Conv2dBlockIndex linear_idx_to_block_idx(
    const int idx, const Conv2dBlockExtents block_extents) {
  const int z_extent = block_extents.data.z;
  const int x_extent = block_extents.data.x;

  // Index that remains once the fastest-moving (z) dimension is peeled off.
  const int xy_linear = idx / z_extent;

  Conv2dBlockIndex block_idx;
  block_idx.data.x = xy_linear % x_extent;
  block_idx.data.y = xy_linear / x_extent;
  block_idx.data.z = idx % z_extent;
  return block_idx;
}
75+
6476
bool block_idx_out_of_bounds(
6577
const Conv2dBlockIndex block_idx,
6678
const Conv2dBlockExtents block_extents) {
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#ifndef CONV2D_DW_Q8_UTILS_GLSLH
10+
#define CONV2D_DW_Q8_UTILS_GLSLH
11+
12+
#extension GL_EXT_control_flow_attributes : require
13+
14+
// Fixed-capacity run of dequantized input texels. `len` records how many
// leading entries of `data` were filled in by load_input_window().
struct InputWindow1D {
  vec4 data[MAX_WINDOW_WIDTH];
  int len;
};
18+
19+
/*
 * Returns an empty InputWindow1D: len is 0 and every one of the
 * MAX_WINDOW_WIDTH slots is zero-filled.
 */
InputWindow1D initial_input_window() {
  InputWindow1D window;
  window.len = 0;
  for (int slot = 0; slot < MAX_WINDOW_WIDTH; slot++) {
    window.data[slot] = vec4(0.0);
  }
  return window;
}
27+
28+
/*
 * Dequantizes four packed 8-bit values with one shared scale and zero point:
 * each unpacked value has zp subtracted and is then multiplied by scale.
 */
vec4 dequantize(const int packed_texel, const float scale, const int zp) {
  const ivec4 centered = unpack_int8x4(packed_texel) - zp;
  return scale * vec4(centered);
}
31+
32+
/*
 * Dequantizes four packed 8-bit values with a separate scale per component
 * and no zero-point offset.
 */
vec4 dequantize(const int packed_texel, const vec4 scales) {
  const vec4 unpacked = vec4(unpack_int8x4(packed_texel));
  return unpacked * scales;
}
35+
36+
/*
 * Returns true when (block_w, block_h, block_c4) is component-wise >= 0 and
 * strictly less than block_extents.data.
 */
bool in_bounds(
    const int block_w,
    const int block_h,
    const int block_c4,
    const Conv2dBlockExtents block_extents) {
  const ivec3 coord = ivec3(block_w, block_h, block_c4);
  const bool non_negative = all(greaterThanEqual(coord, ivec3(0)));
  const bool below_extent = all(lessThan(coord, block_extents.data));
  return non_negative && below_extent;
}
51+
52+
/*
 * Loads and dequantizes the input texels covering width positions
 * [w_start, w_end] (inclusive) at row h and channel block c4.
 *
 * The input is read one block at a time; each block is an ivec4 whose four
 * components are the packed texels of four consecutive width positions.
 * Blocks that fall outside block_extents are replaced by input_zps, so those
 * positions dequantize to zero when input_zps carries input_zp in every byte.
 *
 * At most MAX_WINDOW_WIDTH texels are stored. Positions beyond that capacity
 * are dropped rather than written past the end of InputWindow1D::data, which
 * would be undefined behaviour.
 *
 * Returns the populated window; its len is the number of texels stored.
 */
InputWindow1D load_input_window(
    const int w_start,
    const int w_end,
    const int h,
    const int c4,
    const Conv2dBlockExtents block_extents,
    const float input_scale,
    const int input_zp,
    const ivec4 input_zps) {
  InputWindow1D input_window = initial_input_window();

  // NOTE(review): w_start is negative whenever padding applies; this relies on
  // div_4 rounding toward negative infinity so that the first block covers
  // w_start - confirm against the definition of div_4.
  const int block_w_start = div_4(w_start);
  const int block_w_end = div_4(w_end);

  int window_i = 0;
  for (int block_w = block_w_start; block_w <= block_w_end; ++block_w) {
    // Padding value used when the block lies outside the input tensor.
    ivec4 input_block = input_zps;

    if (in_bounds(block_w, h, c4, block_extents)) {
#ifdef PACKED_INT8_INPUT_BUFFER
      const int buffer_idx =
          h * block_extents.data_xz + block_w * block_extents.data.z + c4;
      input_block = t_packed_int8_input[buffer_idx];
#else
      input_block = texelFetch(t_packed_int8_input, ivec3(block_w, h, c4), 0);
#endif
    }

    const int loaded_w_start = mul_4(block_w);
    for (int row = 0; row < 4; ++row) {
      const int w = loaded_w_start + row;
      // Keep only positions inside the requested range, and never write past
      // the fixed capacity of the window array.
      if (w >= w_start && w <= w_end && window_i < MAX_WINDOW_WIDTH) {
        input_window.data[window_i++] =
            dequantize(input_block[row], input_scale, input_zp);
      }
    }
  }
  input_window.len = window_i;
  return input_window;
}
91+
92+
// Fixed-capacity row of dequantized kernel weights. `len` records how many
// leading entries of `data` were filled in by load_weight_row().
struct WeightRow {
  vec4 data[MAX_KERNEL_WIDTH];
  int len;
};
96+
97+
/*
 * Returns an empty WeightRow: len is 0 and every one of the MAX_KERNEL_WIDTH
 * slots is zero-filled.
 */
WeightRow initial_weight_row() {
  WeightRow row;
  row.len = 0;
  for (int slot = 0; slot < MAX_KERNEL_WIDTH; slot++) {
    row.data[slot] = vec4(0.0);
  }
  return row;
}
105+
106+
/*
 * Loads and dequantizes one kernel row (fixed ky) of weights for channel
 * block oc4.
 *
 * Weights are read one ivec4 block per four kernel columns, starting at block
 * row ky * Kw4. Each block component is one packed texel, dequantized with
 * the per-channel weight_scales. With WEIGHT_BUFFER defined the block is
 * read from a buffer at index k4 * OC4 + oc4; otherwise it is fetched from a
 * texture at (oc4, k4).
 *
 * At most MAX_KERNEL_WIDTH taps are stored. Taps beyond that capacity are
 * dropped rather than written past the end of WeightRow::data, which would
 * be undefined behaviour.
 *
 * Returns the populated row; its len is the number of taps stored.
 */
WeightRow load_weight_row(
    const int oc4,
    const int ky,
    const int OC4,
    const int Kw,
    const int Kw4,
    const vec4 weight_scales) {
  WeightRow weight_row = initial_weight_row();

  int k4 = ky * Kw4;
  int row_idx = 0;
  for (int w = 0; w < Kw; w += 4) {
#ifdef WEIGHT_BUFFER
    const ivec4 weight_block = t_packed_int8_weight[k4 * OC4 + oc4];
#else
    const ivec4 weight_block = texelFetch(
        t_packed_int8_weight, ivec2(oc4, k4), 0);
#endif

    for (int row = 0; row < 4; ++row) {
      // The last block may be only partially used when Kw is not a multiple
      // of 4; the capacity check keeps row_idx inside the fixed-size array.
      if (w + row < Kw && row_idx < MAX_KERNEL_WIDTH) {
        weight_row.data[row_idx++] =
            dequantize(weight_block[row], weight_scales);
      }
    }
    k4++;
  }
  weight_row.len = row_idx;
  return weight_row;
}
135+
136+
// Floating-point accumulator holding four output texels.
struct FPOutBlock {
  vec4 data[4];
};
139+
140+
/*
 * Accumulates one kernel row into out_block.
 *
 * For each of the four output positions, every tap kx of weight_row is
 * multiplied with the matching input texel and added to out_block.data[out_w]
 * via fma. The matching input texel sits at
 *   out_w * stride.x + kx * dilation.x
 * relative to the start of input_window. The dilation term keeps this
 * consistent with callers that size the window as
 * (kernel_size.x - 1) * dilation.x; for dilation.x == 1 it reduces to the
 * contiguous in_w + kx indexing.
 *
 * Taps whose window index would fall outside the fixed MAX_WINDOW_WIDTH
 * capacity are skipped instead of reading past the end of the array.
 */
void perform_conv1d(
    inout FPOutBlock out_block,
    const InputWindow1D input_window,
    const WeightRow weight_row) {
  for (int out_w = 0; out_w < 4; ++out_w) {
    // Loop-invariant with respect to kx, so computed once per output column.
    const int in_w = out_w * conv2d_params.stride.x;
    [[unroll]] for (int kx = 0; kx < weight_row.len; ++kx) {
      const int tap_w = in_w + kx * conv2d_params.dilation.x;
      if (tap_w < MAX_WINDOW_WIDTH) {
        out_block.data[out_w] = fma(
            input_window.data[tap_w],
            weight_row.data[kx],
            out_block.data[out_w]);
      }
    }
  }
}
154+
155+
/*
 * Quantizes a float texel: multiplies by inv_scale, rounds, adds zp, converts
 * to integers and clamps each component to [-128, 127].
 */
ivec4 quantize(
    const vec4 texel, const float inv_scale, const int zp) {
  const vec4 shifted = round(texel * inv_scale) + zp;
  const ivec4 as_int = ivec4(shifted);
  return min(max(as_int, ivec4(-128)), ivec4(127));
}
160+
161+
/*
 * Quantizes each of the four texels in out_block with quantize() and packs
 * each result into one int with pack_into_int32(). Component i of the
 * returned ivec4 corresponds to out_block.data[i].
 */
ivec4 quantize_and_pack(
    FPOutBlock out_block, const float inv_scale, const int zp) {
  return ivec4(
      pack_into_int32(quantize(out_block.data[0], inv_scale, zp)),
      pack_into_int32(quantize(out_block.data[1], inv_scale, zp)),
      pack_into_int32(quantize(out_block.data[2], inv_scale, zp)),
      pack_into_int32(quantize(out_block.data[3], inv_scale, zp)));
}
170+
171+
#ifdef DEBUG_MODE
172+
173+
// Debug helper: prints len followed by the stored texels, never more than
// MAX_WINDOW_WIDTH of them.
void printInputWindow1D(const InputWindow1D input_window) {
  debugPrintfEXT("InputWindow1D contents (len = %d): \\n", input_window.len);
  const int num_printed = min(input_window.len, MAX_WINDOW_WIDTH);
  for (int i = 0; i < num_printed; i++) {
    const vec4 texel = input_window.data[i];
    debugPrintfEXT(
        " [%d]: (%.3f, %.3f, %.3f, %.3f) \\n",
        i,
        texel.x,
        texel.y,
        texel.z,
        texel.w);
  }
}
185+
186+
// Debug helper: prints len followed by the stored taps, never more than
// MAX_KERNEL_WIDTH of them.
void printWeightRow(const WeightRow weight_row) {
  debugPrintfEXT("WeightRow contents (len = %d): \\n", weight_row.len);
  const int num_printed = min(weight_row.len, MAX_KERNEL_WIDTH);
  for (int i = 0; i < num_printed; i++) {
    const vec4 tap = weight_row.data[i];
    debugPrintfEXT(
        " [%d]: (%.3f, %.3f, %.3f, %.3f) \\n",
        i,
        tap.x,
        tap.y,
        tap.z,
        tap.w);
  }
}
198+
199+
// Debug helper: prints the four accumulator texels of out_block.
void printFPOutBlock(const FPOutBlock out_block) {
  debugPrintfEXT("FPOutBlock contents: \\n");
  for (int row = 0; row < 4; row++) {
    const vec4 texel = out_block.data[row];
    debugPrintfEXT(
        " [%d]: (%.3f, %.3f, %.3f, %.3f) \\n",
        row,
        texel.x,
        texel.y,
        texel.z,
        texel.w);
  }
}
211+
212+
#endif // DEBUG_MODE
213+
214+
#endif // CONV2D_DW_Q8_UTILS_GLSLH
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
#define VEC4_T ${texel_load_type(DTYPE, "buffer")}
13+
#define T ${texel_load_component_type(DTYPE, "buffer")}
14+
15+
$if IO_STORAGE == "buffer":
16+
#define PACKED_INT8_OUTPUT_BUFFER
17+
#define PACKED_INT8_INPUT_BUFFER
18+
$if WEIGHT_STORAGE == "buffer":
19+
#define WEIGHT_BUFFER
20+
21+
// Capacity, in texels, of the InputWindow1D::data array declared in
// conv2d_dw_q8_utils.glslh.
#define MAX_WINDOW_WIDTH 12
22+
// Capacity, in taps, of the WeightRow::data array declared in
// conv2d_dw_q8_utils.glslh.
#define MAX_KERNEL_WIDTH 5
23+
24+
${define_required_extensions(DTYPE)}
25+
26+
layout(std430) buffer;
27+
28+
#include "conv2d_common.glslh"
29+
30+
${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=False)}
31+
${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=False)}
32+
${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)}
33+
${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)}
34+
${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}
35+
${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)}
36+
37+
${layout_declare_ubo(B, "ivec4", "output_sizes")}
38+
${layout_declare_ubo(B, "ivec4", "input_sizes")}
39+
${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}
40+
41+
// Quantization parameters supplied as push constants.
layout(push_constant) uniform restrict Block {
42+
// Scale applied when dequantizing input texels in load_input_window().
float input_scale;
43+
// Zero point subtracted from input values before scaling; main() also packs
// it into the padding value used for out-of-bounds input blocks.
int input_zp;
44+
// Multiplier applied to the float result before rounding in quantize().
float output_inv_scale;
45+
// Zero point added after rounding in quantize().
int output_zp;
46+
};
47+
48+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
49+
50+
${layout_declare_spec_const(C, "int", "apply_bias", "1")}
51+
52+
#include "conv2d_dw_q8_utils.glslh"
53+
54+
/*
 * Entry point. Each invocation computes one output block, identified by the
 * linear id gl_GlobalInvocationID.x: four consecutive output width positions
 * for one output row and one channel block.
 *
 * For every kernel row the matching input window and weight row are loaded
 * and dequantized, accumulated in floating point by perform_conv1d(), then
 * optionally biased, re-quantized, packed and written to the output.
 */
void main() {
  const int tid = int(gl_GlobalInvocationID.x);
  Conv2dBlockExtents out_block_extents = make_block_extents(output_sizes);

  Conv2dBlockIndex out_block_idx = linear_idx_to_block_idx(
      tid, out_block_extents);

  if (block_idx_out_of_bounds(out_block_idx, out_block_extents)) {
    return;
  }

  // Inclusive range of input width positions needed by the four output
  // columns of this block.
  const int out_w = mul_4(out_block_idx.data.x);
  const int w_start =
      (out_w * conv2d_params.stride.x) - conv2d_params.padding.x;
  const int w_end = ((out_w + 3) * conv2d_params.stride.x) -
      conv2d_params.padding.x +
      (conv2d_params.kernel_size.x - 1) * conv2d_params.dilation.x;

  Conv2dBlockExtents in_block_extents = make_block_extents(input_sizes);

  // Padding block: the zero point replicated into every byte of every
  // component, so padded positions dequantize to zero.
  const ivec4 input_zps = ivec4(pack_into_int32(ivec4(input_zp)));
  const vec4 weight_scales = vec4(t_weight_scales[out_block_idx.data.z]);

  const int Kw4 = div_up_4(conv2d_params.kernel_size.x);

  // The accumulator is read by perform_conv1d() as the fma addend, so it must
  // start from zero; an uninitialized local has an undefined value.
  FPOutBlock out_block;
  for (int row = 0; row < 4; row++) {
    out_block.data[row] = vec4(0);
  }

  // Independent of ky, so computed once outside the kernel-row loop.
  const int out_h = out_block_idx.data.y;
  for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) {
    const int h = out_h * conv2d_params.stride.y - conv2d_params.padding.y +
        ky * conv2d_params.dilation.y;

    InputWindow1D input_window = load_input_window(
        w_start,
        w_end,
        h,
        out_block_idx.data.z,
        in_block_extents,
        input_scale,
        input_zp,
        input_zps);

    WeightRow weight_row = load_weight_row(
        out_block_idx.data.z,
        ky,
        out_block_extents.data.z,
        conv2d_params.kernel_size.x,
        Kw4,
        weight_scales);

    perform_conv1d(out_block, input_window, weight_row);
  }

  if (apply_bias > 0) {
    const vec4 bias = vec4(t_bias[out_block_idx.data.z]);
    for (int row = 0; row < 4; row++) {
      out_block.data[row] += bias;
    }
  }

  const ivec4 packed_out_block = quantize_and_pack(
      out_block, output_inv_scale, output_zp);

#ifdef PACKED_INT8_OUTPUT_BUFFER
  t_packed_int8_output[tid] = packed_out_block;
#else
  imageStore(t_packed_int8_output, out_block_idx.data, packed_out_block);
#endif
}

0 commit comments

Comments
 (0)