Skip to content

Commit a0a6278

Browse files
authored
Modify depthwise int8 conv2d to reduce register/memory pressure
Differential Revision: D88183224 Pull Request resolved: #16054
1 parent b6c1837 commit a0a6278

File tree

2 files changed

+97
-150
lines changed

2 files changed

+97
-150
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8_utils.glslh

Lines changed: 44 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,6 @@
1111

1212
#extension GL_EXT_control_flow_attributes : require
1313

14-
struct InputWindow1D {
15-
vec4[MAX_WINDOW_WIDTH] data;
16-
int len;
17-
};
18-
19-
InputWindow1D initial_input_window() {
20-
InputWindow1D input_window;
21-
for (int i = 0; i < MAX_WINDOW_WIDTH; ++i) {
22-
input_window.data[i] = vec4(0);
23-
}
24-
input_window.len = 0;
25-
return input_window;
26-
}
27-
2814
// Converts one packed int8x4 word to floating point: unpacks the four int8
// lanes, subtracts the zero point `zp` from each lane, and multiplies the
// result by `scale`.
vec4 dequantize(const int packed_texel, const float scale, const int zp) {
  const ivec4 centered = unpack_int8x4(packed_texel) - zp;
  return vec4(centered) * scale;
}
@@ -49,109 +35,10 @@ bool in_bounds(
4935
return true;
5036
}
5137

52-
InputWindow1D load_input_window(
53-
const int w_start,
54-
const int w_end,
55-
const int h,
56-
const int c4,
57-
const Conv2dBlockExtents block_extents,
58-
const float input_scale,
59-
const int input_zp,
60-
const ivec4 input_zps) {
61-
InputWindow1D input_window = initial_input_window();
62-
63-
const int block_w_start = div_4(w_start);
64-
const int block_w_end = div_4(w_end);
65-
66-
int window_i = 0;
67-
for (int block_w = block_w_start; block_w <= block_w_end; ++block_w) {
68-
ivec4 input_block = input_zps;
69-
70-
if (in_bounds(block_w, h, c4, block_extents)) {
71-
#ifdef PACKED_INT8_INPUT_BUFFER
72-
const int buffer_idx =
73-
h * block_extents.data_xz + block_w * block_extents.data.z + c4;
74-
input_block = t_packed_int8_input[buffer_idx];
75-
#else
76-
input_block = texelFetch(t_packed_int8_input, ivec3(block_w, h, c4), 0);
77-
#endif
78-
}
79-
80-
const int loaded_w_start = mul_4(block_w);
81-
for (int row = 0; row < 4; ++row) {
82-
if (loaded_w_start + row >= w_start && loaded_w_start + row <= w_end) {
83-
input_window.data[window_i++] =
84-
dequantize(input_block[row], input_scale, input_zp);
85-
}
86-
}
87-
}
88-
input_window.len = window_i;
89-
return input_window;
90-
}
91-
92-
struct WeightRow {
93-
vec4[MAX_KERNEL_WIDTH] data;
94-
int len;
95-
};
96-
97-
WeightRow initial_weight_row() {
98-
WeightRow weight_row;
99-
for (int i = 0; i < MAX_KERNEL_WIDTH; ++i) {
100-
weight_row.data[i] = vec4(0);
101-
}
102-
weight_row.len = 0;
103-
return weight_row;
104-
}
105-
106-
WeightRow load_weight_row(
107-
const int oc4,
108-
const int ky,
109-
const int OC4,
110-
const int Kw,
111-
const int Kw4,
112-
const vec4 weight_scales) {
113-
WeightRow weight_row = initial_weight_row();
114-
115-
int k4 = ky * Kw4;
116-
int row_idx = 0;
117-
for (int w = 0; w < Kw; w += 4) {
118-
#ifdef WEIGHT_BUFFER
119-
const ivec4 weight_block = t_packed_int8_weight[k4 * OC4 + oc4];
120-
#else
121-
const ivec4 weight_block = texelFetch(
122-
t_packed_int8_weight, ivec2(oc4, k4), 0);
123-
#endif
124-
125-
for (int row = 0; row < 4; ++row) {
126-
if (w + row < Kw) {
127-
weight_row.data[row_idx++] = dequantize(weight_block[row], weight_scales);
128-
}
129-
}
130-
k4++;
131-
}
132-
weight_row.len = row_idx;
133-
return weight_row;
134-
}
135-
13638
struct FPOutBlock {
13739
vec4[4] data;
13840
};
13941

140-
void perform_conv1d(
141-
inout FPOutBlock out_block,
142-
const InputWindow1D input_window,
143-
const WeightRow weight_row) {
144-
for (int out_w = 0; out_w < 4; ++out_w) {
145-
[[unroll]] for (int kx = 0; kx < weight_row.len; ++kx) {
146-
const int in_w = out_w * conv2d_params.stride.x;
147-
out_block.data[out_w] = fma(
148-
input_window.data[in_w + kx],
149-
weight_row.data[kx],
150-
out_block.data[out_w]);
151-
}
152-
}
153-
}
154-
15542
ivec4 quantize(
15643
const vec4 texel, const float inv_scale, const int zp) {
15744
vec4 quantized = round(texel * inv_scale) + zp;
@@ -168,6 +55,50 @@ ivec4 quantize_and_pack(
16855
return packed_block;
16956
}
17057

58+
// Fetch one packed word of weights: the four int8 values stored for channel
// block `oc4` at kernel position (kh, kw). Equivalent to
// unpacked_weights[kh][kw][c:c+4].
int load_weight_1w4c(
    int kw, // kernel w coordinate
    int kh, // kernel h coordinate
    int oc4, // channel block
    int KW4, // kernel width / 4 (rounded up)
    int OC4 // out channels count / 4 (rounded up)
) {
  // Weights are packed as 4W4C tiles: split kw into the tile column and the
  // position of the requested word inside that tile.
  const int tile_col = kw / 4;
  const int lane = kw % 4;
  const int tile_row = kh * KW4 + tile_col;

#ifdef WEIGHT_BUFFER
  // Every tile occupies four consecutive ints in the buffer.
  return t_packed_int8_weight[(tile_row * OC4 + oc4) * 4 + lane];
#else
  // One texel per tile; the component index selects the word within the tile.
  return texelFetch(t_packed_int8_weight, ivec2(oc4, tile_row), 0)[lane];
#endif
}
77+
78+
// Load a 4xint8 block of inputs - channels c through c+3 (c = oc4*4) at the
// given spatial location. Equivalent to unpacked_input[0][c:c+4][h][w].
//
// Locations that fall outside the input (negative w, or a block rejected by
// in_bounds) return the input zero point replicated into all four int8 lanes,
// so padded positions read as "zero" in the quantized domain.
//
// Parameters:
//   w, h          - spatial coordinates of the requested element
//   oc4           - channel block
//   OC4           - out channels / 4 (rounded up); not read by this function,
//                   kept so existing call sites remain valid
//   block_extents - extents of the packed input, used for the bounds check
//                   and for buffer addressing
int load_input_1w4c(
    int w, // w coordinate
    int h, // h coordinate
    int oc4, // channel block
    int OC4, // out channels / 4 (rounded up)
    Conv2dBlockExtents block_extents
) {
  // Handle negative w before any division so that the block index and the
  // lane are only ever derived from a non-negative coordinate.
  if (w < 0) {
    return pack_into_int32(ivec4(input_zp));
  }

  const int block_w = w / 4;
  const int lane = w % 4;

  if (!in_bounds(block_w, h, oc4, block_extents)) {
    return pack_into_int32(ivec4(input_zp));
  }

#ifdef PACKED_INT8_INPUT_BUFFER
  // Each block occupies four consecutive ints; `lane` picks the word for w.
  const int buffer_idx =
      (h * block_extents.data_xz + block_w * block_extents.data.z + oc4) * 4 +
      lane;
  return t_packed_int8_input[buffer_idx];
#else
  // Texture storage: one ivec4 texel per block, addressed as (block_w, h, oc4);
  // the component index selects the word for w within the block.
  return texelFetch(t_packed_int8_input, ivec3(block_w, h, oc4), 0)[lane];
#endif
}
101+
171102
#ifdef DEBUG_MODE
172103

173104
void printInputWindow1D(const InputWindow1D input_window) {

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.glsl

Lines changed: 53 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ layout(std430) buffer;
2828
#include "conv2d_common.glslh"
2929

3030
${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=False)}
31-
${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=False)}
32-
${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)}
31+
${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=True)}
32+
${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=True)}
3333
${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)}
3434
${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}
3535
${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)}
@@ -62,60 +62,76 @@ void main() {
6262
return;
6363
}
6464

65+
const int out_h = out_block_idx.data.y;
6566
const int out_w = mul_4(out_block_idx.data.x);
66-
const int w_start =
67-
(out_w * conv2d_params.stride.x) - conv2d_params.padding.x;
68-
const int w_end = ((out_w + 3) * conv2d_params.stride.x) -
69-
conv2d_params.padding.x +
70-
(conv2d_params.kernel_size.x - 1) * conv2d_params.dilation.x;
7167

7268
Conv2dBlockExtents in_block_extents = make_block_extents(input_sizes);
7369

74-
const ivec4 input_zps = ivec4(pack_into_int32(ivec4(input_zp)));
75-
const vec4 weight_scales = vec4(t_weight_scales[out_block_idx.data.z]);
76-
7770
const int Kw4 = div_up_4(conv2d_params.kernel_size.x);
7871

79-
FPOutBlock out_block;
72+
// Compute 4 channels for 4 output elements.
73+
ivec4 acc[4];
74+
[[unroll]] for (int i = 0; i < 4; ++i) {
75+
acc[i] = ivec4(0);
76+
}
77+
8078
for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) {
81-
const int out_h = out_block_idx.data.y;
8279
const int h = out_h * conv2d_params.stride.y - conv2d_params.padding.y +
8380
ky * conv2d_params.dilation.y;
8481

85-
InputWindow1D input_window = load_input_window(
86-
w_start,
87-
w_end,
88-
h,
89-
out_block_idx.data.z,
90-
in_block_extents,
91-
input_scale,
92-
input_zp,
93-
input_zps);
94-
95-
WeightRow weight_row = load_weight_row(
96-
out_block_idx.data.z,
97-
ky,
98-
out_block_extents.data.z,
99-
conv2d_params.kernel_size.x,
100-
Kw4,
101-
weight_scales);
102-
103-
perform_conv1d(out_block, input_window, weight_row);
82+
for (int kx = 0; kx < conv2d_params.kernel_size.x; kx++) {
83+
const int w = out_w * conv2d_params.stride.x - conv2d_params.padding.x +
84+
kx * conv2d_params.dilation.x;
85+
86+
// Load and unpack weights.
87+
const int packed_weight_4c = load_weight_1w4c(
88+
kx,
89+
ky,
90+
out_block_idx.data.z,
91+
Kw4,
92+
out_block_extents.data.z
93+
);
94+
95+
const ivec4 weight_4c = unpack_int8x4(packed_weight_4c);
96+
97+
[[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
98+
ivec4 input_texel = unpack_int8x4(load_input_1w4c(
99+
w + conv2d_params.stride.x * subtile_w,
100+
h,
101+
out_block_idx.data.z,
102+
out_block_extents.data.z,
103+
in_block_extents));
104+
acc[subtile_w] += weight_4c * input_texel;
105+
}
106+
}
107+
}
108+
109+
// Apply input zero point as weight_sum * input_zp.
110+
vec4 weight_sums = vec4(t_weight_sums[out_block_idx.data.z]);
111+
const vec4 weight_scales = vec4(t_weight_scales[out_block_idx.data.z]);
112+
113+
vec4 facc[4];
114+
[[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
115+
facc[subtile_w] = vec4(acc[subtile_w]);
116+
facc[subtile_w] -= weight_sums * input_zp;
117+
facc[subtile_w] *= weight_scales * input_scale;
104118
}
105119

106120
if (apply_bias > 0) {
107121
const vec4 bias = vec4(t_bias[out_block_idx.data.z]);
108-
for (int row = 0; row < 4; row++) {
109-
out_block.data[row] += bias;
122+
[[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
123+
facc[subtile_w] += bias;
110124
}
111125
}
112126

113-
const ivec4 packed_out_block = quantize_and_pack(
114-
out_block, output_inv_scale, output_zp);
127+
ivec4 packed_out;
128+
[[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
129+
packed_out[subtile_w] = pack_into_int32(quantize(facc[subtile_w], output_inv_scale, output_zp));
130+
}
115131

116132
#ifdef PACKED_INT8_OUTPUT_BUFFER
117-
t_packed_int8_output[tid] = packed_out_block;
133+
t_packed_int8_output[tid] = packed_out;
118134
#else
119-
imageStore(t_packed_int8_output, out_block_idx.data, packed_out_block);
135+
imageStore(t_packed_int8_output, out_block_idx.data, packed_out);
120136
#endif
121137
}

0 commit comments

Comments
 (0)