Skip to content

Commit 9c12ef7

Browse files
committed
vulkan: optimizations for direct convolution
- Empirically choose a better tile size. Reducing BS_K/BS_NPQ helps fill the GPU. The new size should be amenable to using coopmat, too.
- Fix shmem bank conflicts. 16B padding should work with coopmat.
- Some explicit loop unrolling.
- Skip math/stores work for parts of the tile that are OOB.
- Apply fastdiv opt.
- Disable shuffles for NV.
1 parent 8ad7b3e commit 9c12ef7

File tree

2 files changed

+83
-38
lines changed

2 files changed

+83
-38
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -908,8 +908,22 @@ struct vk_op_conv2d_push_constants {
908908
uint32_t nb1;
909909
uint32_t nb2;
910910
uint32_t nb3;
911+
912+
// init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH
913+
uint32_t KWmp; uint32_t KWL;
914+
uint32_t KWKHmp; uint32_t KWKHL;
915+
uint32_t OWmp; uint32_t OWL;
916+
uint32_t OWOHmp; uint32_t OWOHL;
911917
};
912918

919+
// Specialization of init_pushconst_fastdiv for the conv2d push constants.
// Reads the KW, KH, OW and OH fields already stored in p and fills the four
// (mp, L) pairs next to them by calling init_fastdiv_values once per divisor.
// The conv2d shader passes these pairs to its fastdiv() helper in place of
// integer divisions by KW, KW*KH, OW and OW*OH.
template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
920+
// Compute magic values to divide by KW, KW*KH, OW, OW*OH
921+
init_fastdiv_values(p.KW, p.KWmp, p.KWL);
922+
init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL);
923+
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
924+
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
925+
}
926+
913927
struct vk_op_conv2d_dw_push_constants {
914928
uint32_t ne;
915929
uint32_t batches;
@@ -3052,17 +3066,28 @@ static void ggml_vk_load_shaders(vk_device& device) {
30523066
uint32_t conv2d_BS_K = 128;
30533067
uint32_t conv2d_BS_CRS = 16;
30543068
uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
3069+
uint32_t conv2d_BS_NPQ = 128;
3070+
uint32_t conv2d_TS_K = 8;
3071+
uint32_t conv2d_SHMEM_PAD = 4;
3072+
3073+
if (device->vendor_id == VK_VENDOR_ID_NVIDIA) {
3074+
conv2d_BS_K = 64;
3075+
conv2d_BS_CRS = 32;
3076+
conv2d_BS_NPQ = 32;
3077+
conv2d_TS_K = 4;
3078+
}
3079+
30553080
if (device->subgroup_shuffle &&
3056-
device->vendor_id != VK_VENDOR_ID_INTEL) { // Do not enable collectives on Intel, see PR 14316
3081+
device->vendor_id != VK_VENDOR_ID_INTEL && // Do not enable collectives on Intel, see PR 14316.
3082+
device->vendor_id != VK_VENDOR_ID_NVIDIA) { // Collectives no faster on NVIDIA.
30573083
use_collectives = 1;
30583084
conv2d_BS_CRS = std::min(
30593085
device->subgroup_size,
3060-
conv2d_BS_CRS); // CRS block size should be capped at sugroup size for correctness when shuffle is used.
3086+
conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
30613087
}
3062-
uint32_t conv2d_BS_NPQ = 128;
3063-
uint32_t conv2d_TS_K = 8;
3088+
30643089
uint32_t conv2d_shmem_req =
3065-
(conv2d_BS_K * (conv2d_BS_CRS + 1) + conv2d_BS_CRS * (conv2d_BS_NPQ + 1)) * sizeof(float);
3090+
(conv2d_BS_K * (conv2d_BS_CRS + conv2d_SHMEM_PAD) + conv2d_BS_CRS * (conv2d_BS_NPQ + conv2d_SHMEM_PAD)) * sizeof(float);
30663091
if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
30673092
conv2d_BS_CRS = 8;
30683093
if (use_collectives) {

ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
#version 450
22

3+
#extension GL_EXT_control_flow_attributes : enable
4+
35
#ifdef USE_COLLECTIVES
46
# extension GL_KHR_shader_subgroup_shuffle : enable
57
#endif
68

79
#include "types.comp"
810

911
// Make spec constant
10-
#define SHMEM_PAD 0
12+
#define SHMEM_PAD 4
1113

1214
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
1315
layout(binding = 0) readonly buffer A {
@@ -56,6 +58,12 @@ layout(push_constant) uniform parameter {
5658
uint32_t nb1;
5759
uint32_t nb2;
5860
uint32_t nb3;
61+
62+
// fastdiv helper values
63+
uint32_t KWmp; uint32_t KWL;
64+
uint32_t KWKHmp; uint32_t KWKHL;
65+
uint32_t OWmp; uint32_t OWL;
66+
uint32_t OWOHmp; uint32_t OWOHL;
5967
}
6068

6169
p;
@@ -131,6 +139,14 @@ uint32_t Br = tid / BS_NPQ;
131139
uint32_t Bc = tid % BS_NPQ;
132140
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
133141

142+
// see init_fastdiv_values in ggml-vulkan.cpp
143+
// Replacement for an integer division of n by a divisor chosen on the host.
// mp and L are the precomputed pair for that divisor (see init_fastdiv_values
// in ggml-vulkan.cpp); the result is (mulhi(n, mp) + n) >> L.
// NOTE(review): msbs + n is a plain uint addition, so this presumably relies on
// n being small enough that the sum cannot wrap -- confirm against the host code.
uint fastdiv(uint n, uint mp, uint L) {
144+
uint msbs, lsbs;
145+
// msbs = mulhi(n, mp); lsbs receives the low half of the product and is not used afterwards
146+
umulExtended(n, mp, msbs, lsbs);
147+
return (msbs + n) >> L;
148+
}
149+
134150
void main() {
135151
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
136152
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
@@ -151,9 +167,9 @@ void main() {
151167
uint32_t cached_KW_idx;
152168
if (use_collectives == 1) {
153169
cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
154-
cached_Cin_idx = cached_CRS_idx / (p.KW * p.KH);
170+
cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
155171
uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
156-
cached_KH_idx = cached_CRS_remainder / p.KW;
172+
cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
157173
cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;
158174

159175
CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
@@ -162,16 +178,16 @@ void main() {
162178
KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
163179
} else {
164180
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
165-
Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
181+
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
166182
uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
167-
KH_idx_a = CRS_remainder / p.KW;
183+
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
168184
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
169185
}
170186
#else
171187
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
172-
Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
188+
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
173189
CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
174-
KH_idx_a = CRS_remainder / p.KW;
190+
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
175191
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
176192
#endif
177193

@@ -188,13 +204,13 @@ void main() {
188204
Ash[B_ly * Ash_stride + B_lx] = val;
189205
}
190206
/* Load input to B_block: (BS_CRS x BS_NPQ) */
191-
for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
207+
[[unroll]] for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
192208
uint32_t B_ly = r_offset + Br; /* Row index of B block */
193209
uint32_t B_lx = Bc;
194210
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
195-
uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
211+
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
196212
uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
197-
uint32_t OH_idx = NPQ_remainder / p.OW;
213+
uint32_t OH_idx = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW;
198214
uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW;
199215

200216
uint32_t CRS_idx_b;
@@ -209,16 +225,16 @@ void main() {
209225
KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
210226
} else {
211227
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
212-
Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
228+
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
213229
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
214-
KH_idx_b = CRS_remainder / p.KW;
230+
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
215231
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
216232
}
217233
#else
218234
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
219-
Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
235+
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
220236
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
221-
KH_idx_b = CRS_remainder / p.KW;
237+
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
222238
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
223239
#endif
224240

@@ -233,32 +249,36 @@ void main() {
233249
Bsh[B_ly * Bsh_stride + B_lx] = val;
234250
}
235251
barrier();
236-
for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
237-
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
238-
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
239-
}
240-
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
241-
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
242-
}
243-
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
252+
if (T_y * TS_K < K) {
253+
[[unroll]] for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
254+
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
255+
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
256+
}
244257
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
245-
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
258+
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
259+
}
260+
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
261+
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
262+
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
263+
}
246264
}
247265
}
248266
}
249267
barrier();
250268
}
251269
/* Save C* */
252-
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
253-
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
254-
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
255-
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
256-
uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
257-
uint32_t OH_idx = (NPQ_idx - N_idx * p.OH * p.OW) / p.OW;
258-
uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
259-
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
260-
if (K_idx < K && NPQ_idx < NPQ) {
261-
dst_data[dst_idx] = regC[T_ly][T_lx];
270+
if (T_y * TS_K < K) {
271+
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
272+
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
273+
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
274+
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
275+
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
276+
uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
277+
uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
278+
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
279+
if (K_idx < K && NPQ_idx < NPQ) {
280+
dst_data[dst_idx] = regC[T_ly][T_lx];
281+
}
262282
}
263283
}
264284
}

0 commit comments

Comments
 (0)