Skip to content

Commit 94accca

Browse files
committed
vec move mask to shmem
1 parent 3b96250 commit 94accca

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

ggml/src/ggml-metal.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3297,7 +3297,7 @@ static void ggml_metal_encode_node(
32973297
// ne00*(nsg)
32983298
// each simdgroup has a full f16 head vector in shared mem to accumulate results
32993299
//
3300-
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
3300+
// shared memory size in bytes as a function of the number of simdgroups (nsg):
// counts nqptg*(ne00 + 4*ncpsg*nsg) + ne00*nsg elements of sizeof(float)/2 bytes each and passes the total to GGML_PAD with 16
// NOTE(review): the 4*ncpsg*nsg term presumably has to match the 2*nsg*SH (SH = 2*C) scratch layout of kernel_flash_attn_ext_vec - confirm
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 4*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
33013301

33023302
int64_t nsgmax = 2;
33033303

ggml/src/ggml-metal.metal

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2844,7 +2844,7 @@ kernel void kernel_flash_attn_ext(
28442844
const short D8 = D/8;
28452845
const short D16 = D/16;
28462846
const short NW = N_SIMDWIDTH;
2847-
const short SH = (2*C + Q); // shared memory per simdgroup in (half)
2847+
const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
28482848

28492849
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
28502850
const short T = D + 2*TS; // shared memory size per query in (half)
@@ -3353,16 +3353,17 @@ kernel void kernel_flash_attn_ext_vec(
33533353
const short D16 = D/16;
33543354
const short NW = N_SIMDWIDTH;
33553355
const short NW4 = NW/4;
3356-
const short SH = C; // shared memory per simdgroup in (half)
3356+
const short SH = 2*C; // shared memory per simdgroup (presumably in s_t units, as for SH in kernel_flash_attn_ext; T below uses 2*nsg*SH for the size in half - TODO confirm)
33573357

33583358
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
33593359

3360-
//threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
3361-
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in half4
3362-
threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in half4x4
3363-
threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention
3364-
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + 2*sgitg*SH + Q*D); // same as above but in half4
3365-
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
3360+
//threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
3361+
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
3362+
threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in q4x4_t
3363+
threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention
3364+
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + 2*sgitg*SH + Q*D); // same as above but in s4_t
3365+
threadgroup half * sm = (threadgroup half *) (shared + 2*sgitg*SH + SH + Q*D); // scratch buffer for mask
3366+
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
33663367

33673368
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
33683369
o4x4_t lo[D16/NW4];
@@ -3412,8 +3413,10 @@ kernel void kernel_flash_attn_ext_vec(
34123413
mq[ii/NW4] = sq4x4[ii + tx];
34133414
}
34143415

3416+
const bool has_mask = mask != q;
3417+
34153418
// pointer to the mask
3416-
device const half * mp = (device const half *) (mask + iq1*nb31);
3419+
device const half * pm = (device const half *) (mask + iq1*nb31);
34173420

34183421
half slope = 1.0f;
34193422

@@ -3435,6 +3438,10 @@ kernel void kernel_flash_attn_ext_vec(
34353438
break;
34363439
}
34373440

3441+
if (has_mask) {
3442+
sm[tiisg] = pm[ic + tiisg];
3443+
}
3444+
34383445
// Q*K^T
34393446
{
34403447
// each simdgroup processes 1 query and 4 keys
@@ -3476,7 +3483,7 @@ kernel void kernel_flash_attn_ext_vec(
34763483
mqk = logit_softcap*precise::tanh(mqk);
34773484
}
34783485

3479-
mqk += (s_t) ((mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f);
3486+
mqk += sm[4*cc + ty]*slope;
34803487

34813488
ss[4*cc + ty] = mqk;
34823489
}

0 commit comments

Comments
 (0)