Skip to content

Commit cd21e1d

Browse files
GregoryComerfacebook-github-bot
authored andcommitted
Modify depthwise int8 conv2d to reduce register/memory pressure
Summary: Modify the Vulkan int8 depthwise convolution shader to reduce register pressure by not loading the full input window upfront. On Mali, it seems to spill to main memory with significant performance impact. This is a relatively naive implementation that largely just loads as needed. It shows significant speedup on Mali-G720, though it regresses Adreno performance by 10-20%. There is likely a lot of room for additional optimization here. In particular, optimizing cycle count by reducing bounds checks, looking at improvements to read coalescing, and using spec constants for the major params. Differential Revision: D88183224
1 parent 33ec615 commit cd21e1d

File tree

5 files changed

+169
-162
lines changed

5 files changed

+169
-162
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8_utils.glslh

Lines changed: 0 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,6 @@
1111

1212
#extension GL_EXT_control_flow_attributes : require
1313

14-
struct InputWindow1D {
15-
vec4[MAX_WINDOW_WIDTH] data;
16-
int len;
17-
};
18-
19-
InputWindow1D initial_input_window() {
20-
InputWindow1D input_window;
21-
for (int i = 0; i < MAX_WINDOW_WIDTH; ++i) {
22-
input_window.data[i] = vec4(0);
23-
}
24-
input_window.len = 0;
25-
return input_window;
26-
}
27-
2814
vec4 dequantize(const int packed_texel, const float scale, const int zp) {
2915
return vec4(unpack_int8x4(packed_texel) - zp) * scale;
3016
}
@@ -49,109 +35,10 @@ bool in_bounds(
4935
return true;
5036
}
5137

52-
InputWindow1D load_input_window(
53-
const int w_start,
54-
const int w_end,
55-
const int h,
56-
const int c4,
57-
const Conv2dBlockExtents block_extents,
58-
const float input_scale,
59-
const int input_zp,
60-
const ivec4 input_zps) {
61-
InputWindow1D input_window = initial_input_window();
62-
63-
const int block_w_start = div_4(w_start);
64-
const int block_w_end = div_4(w_end);
65-
66-
int window_i = 0;
67-
for (int block_w = block_w_start; block_w <= block_w_end; ++block_w) {
68-
ivec4 input_block = input_zps;
69-
70-
if (in_bounds(block_w, h, c4, block_extents)) {
71-
#ifdef PACKED_INT8_INPUT_BUFFER
72-
const int buffer_idx =
73-
h * block_extents.data_xz + block_w * block_extents.data.z + c4;
74-
input_block = t_packed_int8_input[buffer_idx];
75-
#else
76-
input_block = texelFetch(t_packed_int8_input, ivec3(block_w, h, c4), 0);
77-
#endif
78-
}
79-
80-
const int loaded_w_start = mul_4(block_w);
81-
for (int row = 0; row < 4; ++row) {
82-
if (loaded_w_start + row >= w_start && loaded_w_start + row <= w_end) {
83-
input_window.data[window_i++] =
84-
dequantize(input_block[row], input_scale, input_zp);
85-
}
86-
}
87-
}
88-
input_window.len = window_i;
89-
return input_window;
90-
}
91-
92-
struct WeightRow {
93-
vec4[MAX_KERNEL_WIDTH] data;
94-
int len;
95-
};
96-
97-
WeightRow initial_weight_row() {
98-
WeightRow weight_row;
99-
for (int i = 0; i < MAX_KERNEL_WIDTH; ++i) {
100-
weight_row.data[i] = vec4(0);
101-
}
102-
weight_row.len = 0;
103-
return weight_row;
104-
}
105-
106-
WeightRow load_weight_row(
107-
const int oc4,
108-
const int ky,
109-
const int OC4,
110-
const int Kw,
111-
const int Kw4,
112-
const vec4 weight_scales) {
113-
WeightRow weight_row = initial_weight_row();
114-
115-
int k4 = ky * Kw4;
116-
int row_idx = 0;
117-
for (int w = 0; w < Kw; w += 4) {
118-
#ifdef WEIGHT_BUFFER
119-
const ivec4 weight_block = t_packed_int8_weight[k4 * OC4 + oc4];
120-
#else
121-
const ivec4 weight_block = texelFetch(
122-
t_packed_int8_weight, ivec2(oc4, k4), 0);
123-
#endif
124-
125-
for (int row = 0; row < 4; ++row) {
126-
if (w + row < Kw) {
127-
weight_row.data[row_idx++] = dequantize(weight_block[row], weight_scales);
128-
}
129-
}
130-
k4++;
131-
}
132-
weight_row.len = row_idx;
133-
return weight_row;
134-
}
135-
13638
struct FPOutBlock {
13739
vec4[4] data;
13840
};
13941

140-
void perform_conv1d(
141-
inout FPOutBlock out_block,
142-
const InputWindow1D input_window,
143-
const WeightRow weight_row) {
144-
for (int out_w = 0; out_w < 4; ++out_w) {
145-
[[unroll]] for (int kx = 0; kx < weight_row.len; ++kx) {
146-
const int in_w = out_w * conv2d_params.stride.x;
147-
out_block.data[out_w] = fma(
148-
input_window.data[in_w + kx],
149-
weight_row.data[kx],
150-
out_block.data[out_w]);
151-
}
152-
}
153-
}
154-
15542
ivec4 quantize(
15643
const vec4 texel, const float inv_scale, const int zp) {
15744
vec4 quantized = round(texel * inv_scale) + zp;

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.glsl

Lines changed: 157 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
#define VEC4_T ${texel_load_type(DTYPE, "buffer")}
1313
#define T ${texel_load_component_type(DTYPE, "buffer")}
1414

15+
// #define DEBUG_MODE 1
16+
17+
#extension GL_EXT_control_flow_attributes : require
18+
1519
$if IO_STORAGE == "buffer":
1620
#define PACKED_INT8_OUTPUT_BUFFER
1721
#define PACKED_INT8_INPUT_BUFFER
@@ -27,9 +31,9 @@ layout(std430) buffer;
2731

2832
#include "conv2d_common.glslh"
2933

30-
${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=False)}
31-
${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=False)}
32-
${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)}
34+
${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=True)}
35+
${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=True)}
36+
${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=True)}
3337
${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)}
3438
${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}
3539
${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)}
@@ -51,71 +55,183 @@ ${layout_declare_spec_const(C, "int", "apply_bias", "1")}
5155

5256
#include "conv2d_dw_q8_utils.glslh"
5357

58+
// Load a 4xint8 block of weights - channel c through c+3 (c = oc4/4).
59+
// Equivalent to unpacked_weights[ky][kx][c:c+4].
60+
int load_weight(
61+
int kx, // w coordinate to load
62+
int ky, // h coordinate to load
63+
int oc4, // channel block to load
64+
int kw4, // kernel width / 4 (rounded up)
65+
int OC4 // out channels / 4 (rounded up)
66+
) {
67+
68+
// Find the packed block index.
69+
int kx4 = kx / 4; // W block
70+
// Index into the packed weights for a 4W4C block.
71+
int linear_idx = ((ky * kw4 + kx4) * OC4 + oc4) * 4;
72+
int block_x_offset = kx % 4;
73+
return t_packed_int8_weight[linear_idx + block_x_offset];
74+
}
75+
76+
// Load a 4xint8 block of inputs - channel c through c+3 (c = ic4/4) at
77+
// the given spatial location.
78+
int load_input(
79+
int x, // w coordinate
80+
int y, // h coordinate
81+
int oc4, // channel block
82+
int OC4, // out channels / 4 (rounded up)
83+
Conv2dBlockExtents block_extents
84+
) {
85+
int block_w = x / 4;
86+
87+
if (in_bounds(block_w, y, oc4, block_extents) && x >= 0) {
88+
#ifdef PACKED_INT8_INPUT_BUFFER
89+
const int buffer_idx =
90+
(y * block_extents.data_xz + block_w * block_extents.data.z + oc4) * 4 + (x % 4);
91+
return t_packed_int8_input[buffer_idx];
92+
#else
93+
#error Unimplemented
94+
// input_block = texelFetch(t_packed_int8_input, ivec3(block_w, h, oc4), 0);
95+
#endif
96+
} else {
97+
return pack_into_int32(ivec4(input_zp));
98+
}
99+
}
100+
54101
void main() {
55102
const int tid = int(gl_GlobalInvocationID.x);
56103
Conv2dBlockExtents out_block_extents = make_block_extents(output_sizes);
104+
const int subtile = int(gl_GlobalInvocationID.y); // x offset within a block
57105

58106
Conv2dBlockIndex out_block_idx = linear_idx_to_block_idx(
59107
tid, out_block_extents);
60108

61-
if (block_idx_out_of_bounds(out_block_idx, out_block_extents)) {
109+
if (block_idx_out_of_bounds(out_block_idx, out_block_extents) || subtile >= 2) {
62110
return;
63111
}
64112

65-
const int out_w = mul_4(out_block_idx.data.x);
66-
const int w_start =
67-
(out_w * conv2d_params.stride.x) - conv2d_params.padding.x;
68-
const int w_end = ((out_w + 3) * conv2d_params.stride.x) -
69-
conv2d_params.padding.x +
70-
(conv2d_params.kernel_size.x - 1) * conv2d_params.dilation.x;
113+
const int out_h = out_block_idx.data.y;
114+
const int out_w = mul_4(out_block_idx.data.x) + subtile * 2;
71115

72-
Conv2dBlockExtents in_block_extents = make_block_extents(input_sizes);
116+
// debugPrintfEXT("tid: %d, subtile: %d\\nComputing output %d %d\\n", tid, subtile, out_h, out_w);
73117

74-
const ivec4 input_zps = ivec4(pack_into_int32(ivec4(input_zp)));
75-
const vec4 weight_scales = vec4(t_weight_scales[out_block_idx.data.z]);
118+
Conv2dBlockExtents in_block_extents = make_block_extents(input_sizes);
76119

77120
const int Kw4 = div_up_4(conv2d_params.kernel_size.x);
78121

79-
FPOutBlock out_block;
122+
// Compute 4 channels for 4 output elements.
123+
ivec4 acc0 = ivec4(0);
124+
ivec4 acc1 = ivec4(0);
125+
ivec4 acc2 = ivec4(0);
126+
ivec4 acc3 = ivec4(0);
127+
80128
for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) {
81-
const int out_h = out_block_idx.data.y;
82129
const int h = out_h * conv2d_params.stride.y - conv2d_params.padding.y +
83130
ky * conv2d_params.dilation.y;
84131

85-
InputWindow1D input_window = load_input_window(
86-
w_start,
87-
w_end,
88-
h,
89-
out_block_idx.data.z,
90-
in_block_extents,
91-
input_scale,
92-
input_zp,
93-
input_zps);
94-
95-
WeightRow weight_row = load_weight_row(
96-
out_block_idx.data.z,
97-
ky,
98-
out_block_extents.data.z,
99-
conv2d_params.kernel_size.x,
100-
Kw4,
101-
weight_scales);
102-
103-
perform_conv1d(out_block, input_window, weight_row);
132+
for (int kx = 0; kx < conv2d_params.kernel_size.x; kx++) {
133+
const int w = out_w * conv2d_params.stride.x - conv2d_params.padding.x +
134+
kx * conv2d_params.dilation.x;
135+
136+
// Load, unpack, and dequantize weights.
137+
const int packed_weight_4c = load_weight(
138+
kx,
139+
ky,
140+
out_block_idx.data.z,
141+
Kw4,
142+
out_block_extents.data.z
143+
);
144+
145+
const ivec4 weight_4c = unpack_int8x4(packed_weight_4c);
146+
147+
// Load, unpack, and dequantize inputs.
148+
int packed_input0 = load_input(
149+
w,
150+
h,
151+
out_block_idx.data.z,
152+
out_block_extents.data.z,
153+
in_block_extents);
154+
155+
ivec4 input0 = unpack_int8x4(packed_input0);
156+
acc0 += weight_4c * input0;
157+
158+
int packed_input1 = load_input(
159+
w + conv2d_params.stride.x,
160+
h,
161+
out_block_idx.data.z,
162+
out_block_extents.data.z,
163+
in_block_extents);
164+
165+
ivec4 input1 = unpack_int8x4(packed_input1);
166+
acc1 += weight_4c * input1;
167+
168+
int packed_input2 = load_input(
169+
w + conv2d_params.stride.x * 2,
170+
h,
171+
out_block_idx.data.z,
172+
out_block_extents.data.z,
173+
in_block_extents);
174+
175+
ivec4 input2 = unpack_int8x4(packed_input2);
176+
acc2 += weight_4c * input2;
177+
178+
int packed_input3 = load_input(
179+
w + conv2d_params.stride.x * 3,
180+
h,
181+
out_block_idx.data.z,
182+
out_block_extents.data.z,
183+
in_block_extents);
184+
185+
ivec4 input3 = unpack_int8x4(packed_input3);
186+
acc3 += weight_4c * input3;
187+
}
104188
}
105189

190+
// Apply input zero point as weight_sum * input_zp.
191+
vec4 weight_sums = vec4(t_weight_sums[out_block_idx.data.z]);
192+
const vec4 weight_scales = vec4(t_weight_scales[out_block_idx.data.z]);
193+
194+
vec4 facc0 = vec4(acc0);
195+
facc0 -= weight_sums * input_zp;
196+
facc0 *= weight_scales * input_scale;
197+
198+
vec4 facc1 = vec4(acc1);
199+
facc1 -= weight_sums * input_zp;
200+
facc1 *= weight_scales * input_scale;
201+
202+
vec4 facc2 = vec4(acc2);
203+
facc2 -= weight_sums * input_zp;
204+
facc2 *= weight_scales * input_scale;
205+
206+
vec4 facc3 = vec4(acc3);
207+
facc3 -= weight_sums * input_zp;
208+
facc3 *= weight_scales * input_scale;
209+
106210
if (apply_bias > 0) {
107211
const vec4 bias = vec4(t_bias[out_block_idx.data.z]);
108-
for (int row = 0; row < 4; row++) {
109-
out_block.data[row] += bias;
110-
}
212+
facc0 += bias;
213+
facc1 += bias;
214+
facc2 += bias;
215+
facc3 += bias;
111216
}
112217

113-
const ivec4 packed_out_block = quantize_and_pack(
114-
out_block, output_inv_scale, output_zp);
218+
const ivec4 quantized_out0 = clamp(ivec4(round(facc0 * output_inv_scale) + output_zp), -128, 127);
219+
const ivec4 quantized_out1 = clamp(ivec4(round(facc1 * output_inv_scale) + output_zp), -128, 127);
220+
const ivec4 quantized_out2 = clamp(ivec4(round(facc2 * output_inv_scale) + output_zp), -128, 127);
221+
const ivec4 quantized_out3 = clamp(ivec4(round(facc3 * output_inv_scale) + output_zp), -128, 127);
222+
223+
const int packed_out_subblock0 = pack_into_int32(quantized_out0);
224+
const int packed_out_subblock1 = pack_into_int32(quantized_out1);
225+
const int packed_out_subblock2 = pack_into_int32(quantized_out2);
226+
const int packed_out_subblock3 = pack_into_int32(quantized_out3);
115227

116228
#ifdef PACKED_INT8_OUTPUT_BUFFER
117-
t_packed_int8_output[tid] = packed_out_block;
229+
t_packed_int8_output[tid * 4] = packed_out_subblock0;
230+
t_packed_int8_output[tid * 4 + 1] = packed_out_subblock1;
231+
t_packed_int8_output[tid * 4 + 2] = packed_out_subblock2;
232+
t_packed_int8_output[tid * 4 + 3] = packed_out_subblock3;
118233
#else
119-
imageStore(t_packed_int8_output, out_block_idx.data, packed_out_block);
234+
// imageStore(t_packed_int8_output, out_block_idx.data, packed_out_block);
235+
#error Unimplemented
120236
#endif
121237
}

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ conv2d_dw_q8ta_q8csw_q8to:
1313
combination:
1414
parameter_names: [IO_STORAGE, WEIGHT_STORAGE]
1515
combos:
16-
- parameter_values: [buffer, texture2d]
16+
- parameter_values: [buffer, buffer]
1717
DTYPE:
1818
- VALUE: float
1919
shader_variants:

backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ ValueRef prepack_quantized_conv2d_dw_weight(
527527

528528
std::vector<int64_t> packed_weight_sizes{output_height, output_width};
529529

530-
utils::StorageType storage_type = utils::kTexture2D;
530+
utils::StorageType storage_type = utils::kBuffer;
531531
uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim();
532532
if (output_width > max_extent * 4 || output_height > max_extent) {
533533
storage_type = utils::kBuffer;

0 commit comments

Comments
 (0)