Skip to content

Commit ffa0590

Browse files
SavicStefan and Stefan Savic authored
vulkan: Add ACC_TYPE_VEC2 implementation (ggml-org#16203)
Signed-off-by: Stefan Savic <[email protected]>
Co-authored-by: Stefan Savic <[email protected]>
1 parent 120bf70 commit ffa0590

File tree

1 file changed

+30
-20
lines changed

1 file changed

+30
-20
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 30 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -313,12 +313,12 @@ void main() {
313313
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
314314
}
315315
#else
316-
ACC_TYPE sums[WMITER * TM * WNITER * TN];
316+
ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
317317
FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
318-
FLOAT_TYPE_VEC2 cache_b[TN];
318+
FLOAT_TYPE_VEC2 cache_b;
319319

320-
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
321-
sums[i] = ACC_TYPE(0.0f);
320+
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
321+
sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
322322
}
323323
#endif
324324

@@ -360,20 +360,22 @@ void main() {
360360
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
361361
}
362362
}
363-
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
364-
[[unroll]] for (uint j = 0; j < TN; j++) {
365-
cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
366-
}
367363

368-
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
369-
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
370-
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
371-
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
372-
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx]));
364+
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
365+
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
366+
cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
367+
368+
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
369+
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
370+
// [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
371+
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
372+
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
373+
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
373374
}
374375
}
375376
}
376377
}
378+
377379
}
378380
#endif
379381

@@ -388,8 +390,9 @@ void main() {
388390
}
389391
}
390392
#else
391-
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
392-
sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
393+
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
394+
sums[i].x = clamp(sums[i].x, -ACC_TYPE_MAX, ACC_TYPE_MAX);
395+
sums[i].y = clamp(sums[i].y, -ACC_TYPE_MAX, ACC_TYPE_MAX);
393396
}
394397
#endif
395398
#endif
@@ -463,14 +466,21 @@ void main() {
463466

464467
const u16vec2 row_idx = row_ids[row_i - ic * BN];
465468
#endif // MUL_MAT_ID
466-
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
469+
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
470+
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
467471
#ifdef MUL_MAT_ID
468-
if (dr_warp + cr < p.M) {
469-
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
472+
if (dr_warp + 2 * cr < p.M) {
473+
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
474+
}
475+
if (dr_warp + 2 * cr + 1 < p.M) {
476+
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
470477
}
471478
#else
472-
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
473-
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
479+
if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
480+
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
481+
}
482+
if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
483+
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
474484
}
475485
#endif // MUL_MAT_ID
476486
}

0 commit comments

Comments
 (0)