Commit b265324

[ET-VK] Optimize conv2d s1p0 (#14187)
This change improves the execution of the pointwise conv2d s1p0 shader. It does so through a more GEMM-like implementation and more explicit loop unrolling. cc @SS-JIA @manuelcandales @cbilgin
1 parent 421539e commit b265324
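
For readers skimming the diff below, the "GEMM-like" formulation works out to this: each output texel packs 4 output channels, and with a 1x1 kernel it is accumulated as four dot products between a packed input texel and the columns of a 4x4 weight block, one block per group of 4 input channels. What follows is a minimal CPU-side sketch of that math in plain C++; the container layouts and names (Texel, pointwise_output_texel) are illustrative assumptions, not code from this commit.

#include <array>
#include <cstddef>
#include <cstdio>
#include <vector>

// One texel = 4 packed channels.
using Texel = std::array<float, 4>;

float dot4(const Texel& a, const Texel& b) {
  return a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3];
}

// input:   one packed texel per group of 4 input channels
// weights: for each group, 4 texels; weights[g][i][j] is the weight connecting
//          input channel (4*g + i) to output channel j
// bias:    one texel holding the 4 output-channel biases
Texel pointwise_output_texel(const std::vector<Texel>& input,
                             const std::vector<std::array<Texel, 4>>& weights,
                             const Texel& bias) {
  Texel out = bias;  // start from the bias, as the shader does
  for (std::size_t g = 0; g < input.size(); ++g) {
    const Texel& in = input[g];
    const std::array<Texel, 4>& w = weights[g];
    for (int j = 0; j < 4; ++j) {
      // Column j of the 4x4 weight block pairs with the 4 packed input channels.
      out[j] += dot4(in, Texel{w[0][j], w[1][j], w[2][j], w[3][j]});
    }
  }
  return out;
}

int main() {
  // One packed input texel and an identity 4x4 weight block, so the output
  // channels simply copy the input channels.
  std::vector<Texel> input = {Texel{1.f, 2.f, 3.f, 4.f}};
  std::array<Texel, 4> block = {Texel{1.f, 0.f, 0.f, 0.f},
                                Texel{0.f, 1.f, 0.f, 0.f},
                                Texel{0.f, 0.f, 1.f, 0.f},
                                Texel{0.f, 0.f, 0.f, 1.f}};
  std::vector<std::array<Texel, 4>> weights = {block};
  const Texel out =
      pointwise_output_texel(input, weights, Texel{0.f, 0.f, 0.f, 0.f});
  std::printf("%.1f %.1f %.1f %.1f\n", out[0], out[1], out[2], out[3]);  // 1.0 2.0 3.0 4.0
  return 0;
}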

File tree

3 files changed: 84 additions, 109 deletions

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl

Lines changed: 80 additions & 107 deletions
@@ -12,10 +12,12 @@
 
 #define PRECISION ${PRECISION}
 
-#define VEC4_T ${texel_type(DTYPE)}
+$if DTYPE == "half":
+  #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+  #define VEC4_T f16vec4
+$else:
+  #define VEC4_T ${texel_type(DTYPE)}
 
-#define TILE_SIZE_X uint16_t(${TILE_SIZE_X})
-#define TILE_SIZE_Y uint16_t(${TILE_SIZE_Y})
 
 #define op(X, A, B) ${OPERATOR}
 

@@ -50,119 +52,90 @@ ${layout_declare_spec_const(C, "int", "ngroups", "1")}
  * size is only 1x1, making it easier to re-use loaded texels from t_kernel.
  */
 void main() {
-  const int out_limits_scaled[2] =
-      {(out_limits.x + (TILE_SIZE_X - 1)) / TILE_SIZE_X,
-       (out_limits.y + (TILE_SIZE_Y - 1)) / TILE_SIZE_Y};
 
-  const uint16_t div_by_x = uint16_t(gl_GlobalInvocationID.x / out_limits_scaled[0]);
-  const uint16_t out_pos_xy[2] = {uint16_t(gl_GlobalInvocationID.x % out_limits_scaled[0]), div_by_x};
-  const int out_pos_z = int(gl_GlobalInvocationID.y);
+  int inputAndOutputWidth = out_limits.x;
+  int inputAndOutputHeight = out_limits.y;
+  int outputChannel = out_limits.z*4;
 
-  // If the top left position is out of bounds, then this invocation will have
-  // no work to do.
-  if (out_pos_xy[1] >= out_limits_scaled[1] || out_pos_z >= out_limits.z) {
+  // Divided by 4 because the input channels are packed
+  int inputChannel = in_group_size/4;
+
+  int threadHW = int(gl_GlobalInvocationID.x);
+  int threadOutChannel = int(gl_GlobalInvocationID.y);
+
+  int xIdx = threadHW % inputAndOutputWidth;
+  int yIdx = threadHW / inputAndOutputWidth;
+
+  if (threadHW >= inputAndOutputWidth * inputAndOutputHeight && threadOutChannel >= outputChannel) {
     return;
   }
 
-  // Output position for TILE_SIZE = 2
-  // +--------+--------+
-  // | pos[0] | pos[1] |
-  // +--------+--------+
-  // | pos[2] | pos[3] |
-  // +--------+--------+
-  uint16_t pos[TILE_SIZE_X * TILE_SIZE_Y * 2];
-  for (uint16_t y = uint16_t(0), i = uint16_t(0); y < TILE_SIZE_Y; ++y) {
-    for (uint16_t x = uint16_t(0); x < TILE_SIZE_X; ++x) {
-      pos[i * 2] = out_pos_xy[0] * TILE_SIZE_X + x;
-      pos[i * 2 + 1] = out_pos_xy[1] * TILE_SIZE_Y + y;
-      i++;
-    }
-  }
+  VEC4_T outputTexel = VEC4_T(texelFetch(t_bias, ivec2(threadOutChannel, 0), 0));
 
-  // Final output array where each element is a tensor value.
-  // Tuple of consecutive 4 elements represents a single output texel.
-  float sum[TILE_SIZE_X * TILE_SIZE_Y * 4];
+  VEC4_T inputVec;
+  VEC4_T weight1OutputChannelPacked;
+  VEC4_T weight2OutputChannelPacked;
+  VEC4_T weight3OutputChannelPacked;
+  VEC4_T weight4OutputChannelPacked;
 
-  // Initialize the output array with the bias value
-  for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i++) {
-    sum[i] = 0;
-  }
+  // By unrolling the loop in sets of 4, this significantly reduces the number of branching instructions
+  // and enables the compiler to rearrange instructions for more efficient memory retrieval and compute
+  for (int inputC = 0; inputC < inputChannel; inputC += 1) {
 
-  int z4 = 0;
-  // Since the kernel is 1x1, we only have to loop over the depth dimension.
-  for (int z = 0; z < in_group_size; z += 4, ++z4) {
-    // During prepacking, the weight tensor has been permuted so that the
-    // channel (IC) dim is along the x-axis, and the batch (OC) dim is along
-    // the z-axis.
-    float kernel_values[4 * 4]; // 4 channels, 4 elements per channel
-
-    // Load kernel values from texels to array
-    [[unroll]] for (int i = 0; i < 4; ++i) {
-      const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, out_pos_z), 0);
-      kernel_values[i * 4 + 0] = k_tex.x;
-      kernel_values[i * 4 + 1] = k_tex.y;
-      kernel_values[i * 4 + 2] = k_tex.z;
-      kernel_values[i * 4 + 3] = k_tex.w;
-    }
-
-    for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
-      const vec4 in_tex = texelFetch(t_in, ivec3(pos[i * 2], pos[i * 2 + 1], z4), 0);
-      // Load the input texel into an array
-      float tex_values[4];
-      tex_values[0] = in_tex.x;
-      tex_values[1] = in_tex.y;
-      tex_values[2] = in_tex.z;
-      tex_values[3] = in_tex.w;
-
-      // For 2x2 tile size algorithm works as follows.
-      // To explain the calculations below, the contents of one in_tex and the
-      // group of 4 texels loaded from t_kernel are shown:
-      //
-      //   in_tex               t_kernel
-      //    -x->                 ---x--->
-      //   +---+     +----+----+----+----+
-      // ^ | w |   ^ | D0 | D1 | D2 | D3 |
-      // | +---+   | +----+----+----+----+
-      // | | z |   | | C0 | C1 | C2 | C3 |
-      // z +---+   z +----+----+----+----+
-      // | | y |   | | B0 | B2 | B2 | B3 |
-      // | +---+   | +----+----+----+----+
-      //   | x |     | A0 | A1 | A2 | A3 |
-      //   +---+     +----+----+----+----+
-      //
-      // In the t_kernel graphic, cells sharing the same letter are from
-      // the same batch/output channel index, and the number denotes a unique
-      // channel index. To calculate the output texel, the following
-      // calculation is performed:
-      //
-      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
-      //  | x | | D0 |   | y | | D1 |   | z | | D2 |   | w | | D3 |
-      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
-      //  | x | | C0 |   | y | | C1 |   | z | | C2 |   | w | | C3 |
-      //  +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+
-      //  | x | | B0 |   | y | | B1 |   | z | | B2 |   | w | | B3 |
-      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
-      //  | x | | A0 |   | y | | A1 |   | z | | A2 |   | w | | A3 |
-      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
-      //
-      // which is what is expressed in the following calculations. This is done
-      // for each output position.
-      for (int j = 0; j < 4; ++j) {
-        sum[i * 4 + j] = tex_values[0] * kernel_values[0 + j] + sum[i * 4 + j];
-        sum[i * 4 + j] = tex_values[1] * kernel_values[4 + j] + sum[i * 4 + j];
-        sum[i * 4 + j] = tex_values[2] * kernel_values[8 + j] + sum[i * 4 + j];
-        sum[i * 4 + j] = tex_values[3] * kernel_values[12 + j] + sum[i * 4 + j];
-      }
-    }
-  }
+    inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0));
+
+    weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0));
+    weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0));
+    weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0));
+    weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0));
+
+    outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0]));
+    outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1]));
+    outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2]));
+    outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3]));
+
+    inputC += 1;
+
+    inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0));
 
-  const vec4 bias = texelFetch(t_bias, ivec2(out_pos_z, 0), 0);
+    weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0));
+    weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0));
+    weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0));
+    weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0));
 
-  for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
-    const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos_z);
-    if (all(lessThan(pos_l.xy, out_limits.xy))) {
-      const vec4 out_sum = vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]);
-      imageStore(t_out, pos_l, op(out_sum + bias, out_min, out_max));
-    }
+    outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0]));
+    outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1]));
+    outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2]));
+    outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3]));
+
+    inputC += 1;
+
+    inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0));
+
+    weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0));
+    weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0));
+    weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0));
+    weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0));
+
+    outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0]));
+    outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1]));
+    outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2]));
+    outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3]));
+
+    inputC += 1;
+
+    inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0));
+
+    weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0));
+    weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0));
+    weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0));
+    weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0));
+
+    outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0]));
+    outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1]));
+    outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2]));
+    outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3]));
   }
+
+  imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(vec4(outputTexel), out_min, out_max));
 }
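
A reading note on the unrolled loop above: the loop header advances inputC by 1, and the three inline inputC += 1 statements advance it further, so each iteration consumes four packed input texels (16 input channels). The sketch below mirrors just that control flow in C++; accumulate_texel is a hypothetical stand-in for the shader's fetch-and-dot block, not a function in the repository.

#include <cstdio>

// Hypothetical stand-in for one fetch-and-dot step of the shader: fetch
// packed input texel `c` and the four matching weight texels, then add the
// four dot products into the running output texel.
static void accumulate_texel(int c) {
  std::printf("consume packed input texel %d\n", c);
}

// Same control flow as the shader's channel loop: together with the
// loop-header increment, the three in-body `c += 1` steps mean each
// iteration consumes four packed input texels.
static void channel_loop(int inputChannelTexels) {
  for (int c = 0; c < inputChannelTexels; c += 1) {
    accumulate_texel(c);
    c += 1;
    accumulate_texel(c);
    c += 1;
    accumulate_texel(c);
    c += 1;
    accumulate_texel(c);
  }
}

int main() {
  channel_loop(8);  // e.g. in_group_size == 32 -> 8 packed input texels
  return 0;
}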

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml

Lines changed: 0 additions & 2 deletions
@@ -9,8 +9,6 @@ conv2d_pw_s1p0:
     OPERATOR: X
     NDIM: 3
     DTYPE: float
-    TILE_SIZE_X: 1
-    TILE_SIZE_Y: 4
   generate_variant_forall:
     DTYPE:
       - VALUE: half

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 4 additions & 0 deletions
@@ -365,6 +365,10 @@ utils::uvec3 conv2d_global_wg_size(
 
   if (method == Conv2dMethod::Depthwise || method == Conv2dMethod::Pointwise) {
     wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1};
+
+    if (shader.kernel_name.find("s1p0") != std::string::npos) {
+      wg_size[0] *= 4;
+    }
   }
 
   return wg_size;
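
The extra factor of 4 lines up with the shader change: the old pointwise kernel wrote a 1x4 tile of output positions per invocation (TILE_SIZE_X: 1, TILE_SIZE_Y: 4, removed from the yaml above), while the new one writes a single output texel, so the flattened spatial axis needs roughly four times as many invocations. Below is a small C++ sketch of that sizing arithmetic, assuming the incoming wg_size is {width, height, output-channel texels}; the actual pre-existing sizing logic lives elsewhere in Convolution.cpp and is not shown in this diff.

#include <array>
#include <cstdio>

int main() {
  // Illustrative output extents, assuming {width, height, output-channel texels};
  // the real values come from the graph's output tensor.
  std::array<unsigned, 3> wg_size = {64u, 48u, 16u};

  // Flattening applied for depthwise / pointwise convs in conv2d_global_wg_size.
  std::array<unsigned, 3> flat = {wg_size[0] * wg_size[1], wg_size[2], 1u};

  // For the s1p0 shader, each invocation now writes a single output texel
  // rather than a 1x4 tile, so the flattened spatial axis is scaled by 4.
  const bool is_s1p0_shader = true;  // stands in for the kernel-name check
  if (is_s1p0_shader) {
    flat[0] *= 4u;
  }

  std::printf("global wg size: {%u, %u, %u}\n", flat[0], flat[1], flat[2]);
  return 0;
}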
