
Commit cbea390

cont : fix FA vec kernel

ggml-ci

1 parent c648f1f

2 files changed: +13 −13 lines

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 1 addition & 1 deletion
@@ -4179,7 +4179,7 @@ static void ggml_metal_encode_node(
                 //   ne00*(nsg)
                 // each simdgroup has a full f16 head vector in shared mem to accumulate results
                 //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 2*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))

                 int64_t nsgmax = 2;
                 while (true) {
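
Annotation: the only change here doubles the per-simdgroup scratch term from 2*ncpsg to 4*ncpsg, matching the kernel-side switch of the mask scratch from half to float in the .metal diff below. As a sanity check, here is a minimal standalone C sketch of the sizing arithmetic; the concrete values (nqptg = 1, ne00 = ne20 = 128, ncpsg = 32) are illustrative assumptions, not taken from the commit. The inner expression counts half elements, and the sizeof(float)/2 factor converts that count to bytes:

#include <stdio.h>

// same rounding helper as ggml's GGML_PAD: round x up to a multiple of n
#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

// threadgroup memory per FA vec dispatch, in bytes (the updated 4*ncpsg variant)
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))

int main(void) {
    const int nqptg = 1;   // queries per threadgroup (assumed)
    const int ne00  = 128; // K head size (assumed)
    const int ne20  = 128; // V head size (assumed)
    const int ncpsg = 32;  // cache values processed per simdgroup (assumed)

    for (int nsg = 1; nsg <= 4; nsg *= 2) {
        printf("nsg = %d -> %zu bytes of threadgroup memory\n", nsg, FATTN_SMEM(nsg));
    }
    return 0;
}

With these assumed values the macro yields 768, 1280 and 2304 bytes for nsg = 1, 2 and 4, well within typical Apple-GPU threadgroup-memory limits.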

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 12 additions & 12 deletions
@@ -3646,16 +3646,16 @@ kernel void kernel_flash_attn_ext_vec(
     constexpr short DV4 = DV/4;
     constexpr short NW  = N_SIMDWIDTH;
     constexpr short NL  = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
-    constexpr short SH  = 2*C;   // shared memory per simdgroup
+    constexpr short SH  = 4*C;   // shared memory per simdgroup

     const short T = DK + nsg*SH; // shared memory size per query in (half)

-  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                  0*DK); // holds the query data
-    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                  0*DK); // same as above but in q4_t
-    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 + sgitg*SH       + Q*DK); // scratch buffer for attention
-    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 + sgitg*SH       + Q*DK); // same as above but in s4_t
-    threadgroup half  * sm  = (threadgroup half  *) (shmem_f16 + sgitg*SH + C   + Q*DK); // scratch buffer for mask
-    threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + sgitg*DV       + Q*T);  // scratch buffer for the results
+  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                  0*DK); // holds the query data
+    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                  0*DK); // same as above but in q4_t
+    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 + sgitg*SH       + Q*DK); // scratch buffer for attention
+    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 + sgitg*SH       + Q*DK); // same as above but in s4_t
+    threadgroup float * sm  = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
+    threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + sgitg*DV       + Q*T);  // scratch buffer for the results

     // store the result for all queries in local memory (the O matrix from the paper)
     o4_t lo[DV4/NL];
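
Annotation: SH doubles from 2*C to 4*C halves because the mask scratch sm is now a float view of the shared memory, so within each simdgroup's slice it must begin at 2*C halves (after the attention-score region) rather than C. A small C sketch of the resulting layout, counted in half units exactly as the kernel does; Q = 1, DK = DV = 128, C = 32 and nsg = 2 are assumptions for illustration:

#include <stdio.h>

int main(void) {
    // illustrative values, not taken from the commit
    const int Q   = 1;   // queries per threadgroup
    const int DK  = 128; // K head size
    const int DV  = 128; // V head size
    const int C   = 32;  // cache values per simdgroup
    const int nsg = 2;   // number of simdgroups

    const int SH = 4*C;          // per-simdgroup scratch, in half units
    const int T  = DK + nsg*SH;  // shared memory per query, in half units

    for (int sgitg = 0; sgitg < nsg; ++sgitg) {
        printf("simdgroup %d:\n", sgitg);
        printf("  sq4 (query)     : half offset %d\n", 0*DK);
        printf("  ss  (attention) : half offset %d\n", sgitg*SH       + Q*DK);
        printf("  sm  (mask, f32) : half offset %d\n", sgitg*SH + 2*C + Q*DK);
        printf("  sr4 (results)   : half offset %d\n", sgitg*DV       + Q*T);
    }
    return 0;
}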
@@ -3836,7 +3836,7 @@ kernel void kernel_flash_attn_ext_vec(
                     v4_t mv;
                     deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);

-                    lo[ii/NL] += dot((float4) mv, (float4) ms);
+                    lo[ii/NL] += o4_t(float4(mv)*float4(ms));
                 }
             }
         }
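
Annotation: the accumulator lo[ii/NL] is a 4-wide vector, so dot(), which collapses the four products into one scalar, would broadcast that sum into all four lanes on +=; the updated line keeps the products component-wise. A plain-C sketch of the two behaviors, with float4 modeled as a 4-element array and made-up values, purely for illustration:

#include <stdio.h>

int main(void) {
    const float mv[4] = {1, 2, 3, 4}; // dequantized V values
    const float ms[4] = {1, 1, 1, 1}; // attention weights

    float lo_dot[4] = {0}, lo_mul[4] = {0};

    // dot(): the four products collapse to one scalar, which a vector +=
    // then adds to every lane of the accumulator
    float d = 0;
    for (int i = 0; i < 4; ++i) d += mv[i]*ms[i];
    for (int i = 0; i < 4; ++i) lo_dot[i] += d;           // 10 10 10 10

    // component-wise multiply: each lane accumulates its own product
    for (int i = 0; i < 4; ++i) lo_mul[i] += mv[i]*ms[i]; // 1 2 3 4

    printf("dot: %g %g %g %g\n", lo_dot[0], lo_dot[1], lo_dot[2], lo_dot[3]);
    printf("mul: %g %g %g %g\n", lo_mul[0], lo_mul[1], lo_mul[2], lo_mul[3]);
    return 0;
}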
@@ -3907,11 +3907,11 @@ kernel void kernel_flash_attn_ext_vec(
         // parallel reduce
         for (short r = nsg/2; r > 0; r >>= 1) {
             if (sgitg < r) {
-                const float S0 = ss[       0];
-                const float S1 = ss[r*SH + 0];
+                const float S0 = ss[           0];
+                const float S1 = ss[r*(SH/2) + 0];

-                const float M0 = ss[       1];
-                const float M1 = ss[r*SH + 1];
+                const float M0 = ss[           1];
+                const float M1 = ss[r*(SH/2) + 1];

                 const float M = max(M0, M1);
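Annotation: SH is counted in half elements, but ss is a float view of the same threadgroup memory, so one simdgroup's slice spans SH/2 floats; the reduction therefore has to step r*(SH/2) in float units to reach simdgroup r's scratch. A minimal C check that the two unit systems describe the same slice width (sizes illustrative):

#include <stdio.h>
#include <stdint.h>

int main(void) {
    const int C  = 32;  // cache values per simdgroup (assumed)
    const int SH = 4*C; // per-simdgroup scratch, counted in half (2-byte) units

    const int stride_half  = SH;   // slice width as seen through a half pointer
    const int stride_float = SH/2; // the same slice as seen through a float pointer

    printf("slice size: %d bytes (half view), %d bytes (float view)\n",
           stride_half  * (int)sizeof(uint16_t),
           stride_float * (int)sizeof(float)); // both print 256
    return 0;
}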
