Skip to content

Commit d576add

Browse files
GregoryComer authored and facebook-github-bot committed
Modify depthwise int8 conv2d to reduce register/memory pressure (#16054)
Summary: Modify the Vulkan int8 depthwise convolution shader to reduce register pressure by not loading the full input window upfront. On Mali, it seems to spill to main memory with significant performance impact. This is a relatively naive implementation that largely just loads as needed. It shows significant speedup on Mali-G720, though it regresses Adreno performance by 10-20%. There is likely a lot of room for additional optimization here. In particular, optimizing cycle count by reducing bounds checks, looking at improvements to read coalescing, and using spec constants for the major params. Differential Revision: D88183224
1 parent 5daa6e3 commit d576add

File tree

5 files changed

+159
-160
lines changed

5 files changed

+159
-160
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8_utils.glslh

Lines changed: 42 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,6 @@
1111

1212
#extension GL_EXT_control_flow_attributes : require
1313

14-
struct InputWindow1D {
15-
vec4[MAX_WINDOW_WIDTH] data;
16-
int len;
17-
};
18-
19-
InputWindow1D initial_input_window() {
20-
InputWindow1D input_window;
21-
for (int i = 0; i < MAX_WINDOW_WIDTH; ++i) {
22-
input_window.data[i] = vec4(0);
23-
}
24-
input_window.len = 0;
25-
return input_window;
26-
}
27-
2814
vec4 dequantize(const int packed_texel, const float scale, const int zp) {
2915
return vec4(unpack_int8x4(packed_texel) - zp) * scale;
3016
}
@@ -49,109 +35,10 @@ bool in_bounds(
4935
return true;
5036
}
5137

52-
InputWindow1D load_input_window(
53-
const int w_start,
54-
const int w_end,
55-
const int h,
56-
const int c4,
57-
const Conv2dBlockExtents block_extents,
58-
const float input_scale,
59-
const int input_zp,
60-
const ivec4 input_zps) {
61-
InputWindow1D input_window = initial_input_window();
62-
63-
const int block_w_start = div_4(w_start);
64-
const int block_w_end = div_4(w_end);
65-
66-
int window_i = 0;
67-
for (int block_w = block_w_start; block_w <= block_w_end; ++block_w) {
68-
ivec4 input_block = input_zps;
69-
70-
if (in_bounds(block_w, h, c4, block_extents)) {
71-
#ifdef PACKED_INT8_INPUT_BUFFER
72-
const int buffer_idx =
73-
h * block_extents.data_xz + block_w * block_extents.data.z + c4;
74-
input_block = t_packed_int8_input[buffer_idx];
75-
#else
76-
input_block = texelFetch(t_packed_int8_input, ivec3(block_w, h, c4), 0);
77-
#endif
78-
}
79-
80-
const int loaded_w_start = mul_4(block_w);
81-
for (int row = 0; row < 4; ++row) {
82-
if (loaded_w_start + row >= w_start && loaded_w_start + row <= w_end) {
83-
input_window.data[window_i++] =
84-
dequantize(input_block[row], input_scale, input_zp);
85-
}
86-
}
87-
}
88-
input_window.len = window_i;
89-
return input_window;
90-
}
91-
92-
struct WeightRow {
93-
vec4[MAX_KERNEL_WIDTH] data;
94-
int len;
95-
};
96-
97-
WeightRow initial_weight_row() {
98-
WeightRow weight_row;
99-
for (int i = 0; i < MAX_KERNEL_WIDTH; ++i) {
100-
weight_row.data[i] = vec4(0);
101-
}
102-
weight_row.len = 0;
103-
return weight_row;
104-
}
105-
106-
WeightRow load_weight_row(
107-
const int oc4,
108-
const int ky,
109-
const int OC4,
110-
const int Kw,
111-
const int Kw4,
112-
const vec4 weight_scales) {
113-
WeightRow weight_row = initial_weight_row();
114-
115-
int k4 = ky * Kw4;
116-
int row_idx = 0;
117-
for (int w = 0; w < Kw; w += 4) {
118-
#ifdef WEIGHT_BUFFER
119-
const ivec4 weight_block = t_packed_int8_weight[k4 * OC4 + oc4];
120-
#else
121-
const ivec4 weight_block = texelFetch(
122-
t_packed_int8_weight, ivec2(oc4, k4), 0);
123-
#endif
124-
125-
for (int row = 0; row < 4; ++row) {
126-
if (w + row < Kw) {
127-
weight_row.data[row_idx++] = dequantize(weight_block[row], weight_scales);
128-
}
129-
}
130-
k4++;
131-
}
132-
weight_row.len = row_idx;
133-
return weight_row;
134-
}
135-
13638
struct FPOutBlock {
13739
vec4[4] data;
13840
};
13941

140-
void perform_conv1d(
141-
inout FPOutBlock out_block,
142-
const InputWindow1D input_window,
143-
const WeightRow weight_row) {
144-
for (int out_w = 0; out_w < 4; ++out_w) {
145-
[[unroll]] for (int kx = 0; kx < weight_row.len; ++kx) {
146-
const int in_w = out_w * conv2d_params.stride.x;
147-
out_block.data[out_w] = fma(
148-
input_window.data[in_w + kx],
149-
weight_row.data[kx],
150-
out_block.data[out_w]);
151-
}
152-
}
153-
}
154-
15542
ivec4 quantize(
15643
const vec4 texel, const float inv_scale, const int zp) {
15744
vec4 quantized = round(texel * inv_scale) + zp;
@@ -168,6 +55,48 @@ ivec4 quantize_and_pack(
16855
return packed_block;
16956
}
17057

58+
// Load a 4xint8 block of weights - channel c through c+3 (c = oc4*4).
59+
// Equivalent to unpacked_weights[ky][kx][c:c+4].
60+
int load_weight(
61+
int kx, // w coordinate to load
62+
int ky, // h coordinate to load
63+
int oc4, // channel block to load
64+
int kw4, // kernel width / 4 (rounded up)
65+
int OC4 // out channels / 4 (rounded up)
66+
) {
67+
68+
// Find the packed block index.
69+
int kx4 = kx / 4; // W block
70+
// Index into the packed weights for a 4W4C block.
71+
int linear_idx = ((ky * kw4 + kx4) * OC4 + oc4) * 4;
72+
int block_x_offset = kx % 4;
73+
return t_packed_int8_weight[linear_idx + block_x_offset];
74+
}
75+
76+
// Load a 4xint8 block of inputs - channel c through c+3 (c = oc4*4) at
77+
// the given spatial location.
78+
int load_input(
79+
int x, // w coordinate
80+
int y, // h coordinate
81+
int oc4, // channel block
82+
int OC4, // out channels / 4 (rounded up)
83+
Conv2dBlockExtents block_extents
84+
) {
85+
int block_w = x / 4;
86+
87+
if (in_bounds(block_w, y, oc4, block_extents) && x >= 0) {
88+
#ifdef PACKED_INT8_INPUT_BUFFER
89+
const int buffer_idx =
90+
(y * block_extents.data_xz + block_w * block_extents.data.z + oc4) * 4 + (x % 4);
91+
return t_packed_int8_input[buffer_idx];
92+
#else
93+
#error Unimplemented
94+
#endif
95+
} else {
96+
return pack_into_int32(ivec4(input_zp));
97+
}
98+
}
99+
171100
#ifdef DEBUG_MODE
172101

173102
void printInputWindow1D(const InputWindow1D input_window) {

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.glsl

Lines changed: 105 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ layout(std430) buffer;
2727

2828
#include "conv2d_common.glslh"
2929

30-
${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=False)}
31-
${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=False)}
32-
${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)}
30+
${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=True)}
31+
${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=True)}
32+
${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=True)}
3333
${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)}
3434
${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}
3535
${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)}
@@ -62,60 +62,126 @@ void main() {
6262
return;
6363
}
6464

65+
const int out_h = out_block_idx.data.y;
6566
const int out_w = mul_4(out_block_idx.data.x);
66-
const int w_start =
67-
(out_w * conv2d_params.stride.x) - conv2d_params.padding.x;
68-
const int w_end = ((out_w + 3) * conv2d_params.stride.x) -
69-
conv2d_params.padding.x +
70-
(conv2d_params.kernel_size.x - 1) * conv2d_params.dilation.x;
7167

7268
Conv2dBlockExtents in_block_extents = make_block_extents(input_sizes);
7369

74-
const ivec4 input_zps = ivec4(pack_into_int32(ivec4(input_zp)));
75-
const vec4 weight_scales = vec4(t_weight_scales[out_block_idx.data.z]);
76-
7770
const int Kw4 = div_up_4(conv2d_params.kernel_size.x);
7871

79-
FPOutBlock out_block;
72+
// Compute 4 channels for 4 output elements.
73+
ivec4 acc0 = ivec4(0);
74+
ivec4 acc1 = ivec4(0);
75+
ivec4 acc2 = ivec4(0);
76+
ivec4 acc3 = ivec4(0);
77+
8078
for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) {
81-
const int out_h = out_block_idx.data.y;
8279
const int h = out_h * conv2d_params.stride.y - conv2d_params.padding.y +
8380
ky * conv2d_params.dilation.y;
8481

85-
InputWindow1D input_window = load_input_window(
86-
w_start,
87-
w_end,
88-
h,
89-
out_block_idx.data.z,
90-
in_block_extents,
91-
input_scale,
92-
input_zp,
93-
input_zps);
94-
95-
WeightRow weight_row = load_weight_row(
96-
out_block_idx.data.z,
97-
ky,
98-
out_block_extents.data.z,
99-
conv2d_params.kernel_size.x,
100-
Kw4,
101-
weight_scales);
102-
103-
perform_conv1d(out_block, input_window, weight_row);
82+
for (int kx = 0; kx < conv2d_params.kernel_size.x; kx++) {
83+
const int w = out_w * conv2d_params.stride.x - conv2d_params.padding.x +
84+
kx * conv2d_params.dilation.x;
85+
86+
// Load and unpack weights.
87+
const int packed_weight_4c = load_weight(
88+
kx,
89+
ky,
90+
out_block_idx.data.z,
91+
Kw4,
92+
out_block_extents.data.z
93+
);
94+
95+
const ivec4 weight_4c = unpack_int8x4(packed_weight_4c);
96+
97+
// Load and unpack inputs.
98+
int packed_input0 = load_input(
99+
w,
100+
h,
101+
out_block_idx.data.z,
102+
out_block_extents.data.z,
103+
in_block_extents);
104+
105+
// Compute weight * input for all 4 accumulators.
106+
ivec4 input0 = unpack_int8x4(packed_input0);
107+
acc0 += weight_4c * input0;
108+
109+
int packed_input1 = load_input(
110+
w + conv2d_params.stride.x,
111+
h,
112+
out_block_idx.data.z,
113+
out_block_extents.data.z,
114+
in_block_extents);
115+
116+
ivec4 input1 = unpack_int8x4(packed_input1);
117+
acc1 += weight_4c * input1;
118+
119+
int packed_input2 = load_input(
120+
w + conv2d_params.stride.x * 2,
121+
h,
122+
out_block_idx.data.z,
123+
out_block_extents.data.z,
124+
in_block_extents);
125+
126+
ivec4 input2 = unpack_int8x4(packed_input2);
127+
acc2 += weight_4c * input2;
128+
129+
int packed_input3 = load_input(
130+
w + conv2d_params.stride.x * 3,
131+
h,
132+
out_block_idx.data.z,
133+
out_block_extents.data.z,
134+
in_block_extents);
135+
136+
ivec4 input3 = unpack_int8x4(packed_input3);
137+
acc3 += weight_4c * input3;
138+
}
104139
}
105140

141+
// Apply input zero point as weight_sum * input_zp.
142+
vec4 weight_sums = vec4(t_weight_sums[out_block_idx.data.z]);
143+
const vec4 weight_scales = vec4(t_weight_scales[out_block_idx.data.z]);
144+
145+
vec4 facc0 = vec4(acc0);
146+
facc0 -= weight_sums * input_zp;
147+
facc0 *= weight_scales * input_scale;
148+
149+
vec4 facc1 = vec4(acc1);
150+
facc1 -= weight_sums * input_zp;
151+
facc1 *= weight_scales * input_scale;
152+
153+
vec4 facc2 = vec4(acc2);
154+
facc2 -= weight_sums * input_zp;
155+
facc2 *= weight_scales * input_scale;
156+
157+
vec4 facc3 = vec4(acc3);
158+
facc3 -= weight_sums * input_zp;
159+
facc3 *= weight_scales * input_scale;
160+
106161
if (apply_bias > 0) {
107162
const vec4 bias = vec4(t_bias[out_block_idx.data.z]);
108-
for (int row = 0; row < 4; row++) {
109-
out_block.data[row] += bias;
110-
}
163+
facc0 += bias;
164+
facc1 += bias;
165+
facc2 += bias;
166+
facc3 += bias;
111167
}
112168

113-
const ivec4 packed_out_block = quantize_and_pack(
114-
out_block, output_inv_scale, output_zp);
169+
const ivec4 quantized_out0 = clamp(ivec4(round(facc0 * output_inv_scale) + output_zp), -128, 127);
170+
const ivec4 quantized_out1 = clamp(ivec4(round(facc1 * output_inv_scale) + output_zp), -128, 127);
171+
const ivec4 quantized_out2 = clamp(ivec4(round(facc2 * output_inv_scale) + output_zp), -128, 127);
172+
const ivec4 quantized_out3 = clamp(ivec4(round(facc3 * output_inv_scale) + output_zp), -128, 127);
173+
174+
const int packed_out_subblock0 = pack_into_int32(quantized_out0);
175+
const int packed_out_subblock1 = pack_into_int32(quantized_out1);
176+
const int packed_out_subblock2 = pack_into_int32(quantized_out2);
177+
const int packed_out_subblock3 = pack_into_int32(quantized_out3);
115178

116179
#ifdef PACKED_INT8_OUTPUT_BUFFER
117-
t_packed_int8_output[tid] = packed_out_block;
180+
t_packed_int8_output[tid * 4] = packed_out_subblock0;
181+
t_packed_int8_output[tid * 4 + 1] = packed_out_subblock1;
182+
t_packed_int8_output[tid * 4 + 2] = packed_out_subblock2;
183+
t_packed_int8_output[tid * 4 + 3] = packed_out_subblock3;
118184
#else
119-
imageStore(t_packed_int8_output, out_block_idx.data, packed_out_block);
185+
#error Unimplemented
120186
#endif
121187
}

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ conv2d_dw_q8ta_q8csw_q8to:
1313
combination:
1414
parameter_names: [IO_STORAGE, WEIGHT_STORAGE]
1515
combos:
16-
- parameter_values: [buffer, texture2d]
16+
- parameter_values: [buffer, buffer]
1717
DTYPE:
1818
- VALUE: float
1919
shader_variants:

backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ ValueRef prepack_quantized_conv2d_dw_weight(
527527

528528
std::vector<int64_t> packed_weight_sizes{output_height, output_width};
529529

530-
utils::StorageType storage_type = utils::kTexture2D;
530+
utils::StorageType storage_type = utils::kBuffer;
531531
uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim();
532532
if (output_width > max_extent * 4 || output_height > max_extent) {
533533
storage_type = utils::kBuffer;

0 commit comments

Comments (0)