Skip to content

Commit b2c08c9

Browse files
authored
metal : mark FA blocks (ggml-org#16372)
* metal : better unroll in the FA kernels
* metal : index FA blocks
* tests : restore [no ci]
* metal : prevent division by zero in FA kernels
* metal : fix -INF detection logic
1 parent 7fdd16b commit b2c08c9

File tree

7 files changed

+324
-65
lines changed

7 files changed

+324
-65
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -959,7 +959,53 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
959959
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
960960
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
961961
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
962-
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24);
962+
//ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
963+
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
964+
965+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
966+
967+
ggml_metal_cv_free(cv);
968+
969+
return res;
970+
}
971+
972+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
973+
ggml_metal_library_t lib,
974+
const struct ggml_tensor * op,
975+
int32_t nqptg,
976+
int32_t ncpsg) {
977+
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
978+
GGML_UNUSED(op);
979+
980+
char base[256];
981+
char name[256];
982+
983+
snprintf(base, 256, "kernel_%s",
984+
"flash_attn_ext_blk");
985+
986+
snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
987+
base,
988+
nqptg,
989+
ncpsg);
990+
991+
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
992+
if (res) {
993+
return res;
994+
}
995+
996+
ggml_metal_cv_t cv = ggml_metal_cv_init();
997+
998+
//ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
999+
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
1000+
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
1001+
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
1002+
1003+
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
1004+
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
1005+
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
1006+
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
1007+
ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
1008+
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
9631009

9641010
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
9651011

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
141141
bool has_mask,
142142
int32_t ncpsg);
143143

144+
// Returns the pipeline for the "kernel_flash_attn_ext_blk" kernel.
// The pipeline is looked up by a name that encodes nqptg and ncpsg and is
// compiled when no pipeline with that name is found in the library.
//   lib   - library used to look up / compile the pipeline
//   op    - must be a GGML_OP_FLASH_ATTN_EXT tensor (asserted by the implementation)
//   nqptg - queries per threadgroup
//   ncpsg - cache values per simdgroup
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
145+
ggml_metal_library_t lib,
146+
const struct ggml_tensor * op,
147+
int32_t nqptg,
148+
int32_t ncpsg);
149+
144150
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
145151
ggml_metal_library_t lib,
146152
const struct ggml_tensor * op,

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,19 @@
7070

7171
// function constants offsets
7272
#define FC_FLASH_ATTN_EXT_PAD 100
73-
#define FC_FLASH_ATTN_EXT 200
74-
#define FC_FLASH_ATTN_EXT_VEC 300
75-
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 400
76-
#define FC_MUL_MV 500
77-
#define FC_MUL_MM 600
73+
#define FC_FLASH_ATTN_EXT_BLK 200
74+
#define FC_FLASH_ATTN_EXT 300
75+
#define FC_FLASH_ATTN_EXT_VEC 400
76+
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
77+
#define FC_MUL_MV 600
78+
#define FC_MUL_MM 700
79+
80+
// op-specific constants
81+
// flash attention, non-vector (half8x8) kernel:
//   NQPTG - queries per threadgroup
//   NCPSG - cache values per simdgroup
#define OP_FLASH_ATTN_EXT_NQPTG 8
82+
#define OP_FLASH_ATTN_EXT_NCPSG 64
83+
84+
// flash attention, vector (half4x4) kernel: same two parameters
#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1
85+
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
7886

7987
// kernel argument structs
8088
//
@@ -263,6 +271,17 @@ typedef struct {
263271
uint64_t nb33;
264272
} ggml_metal_kargs_flash_attn_ext_pad;
265273

274+
// kernel arguments for the flash_attn_ext_blk kernel
// the host side fills these from the locals ne01, ne30..ne33 and nb31..nb33
// (presumably ne01 from op->src[0] and the ne3x/nb3x values from the mask op->src[3] - confirm against the caller)
typedef struct {
275+
int32_t ne01;
276+
int32_t ne30;
277+
int32_t ne31;
278+
int32_t ne32;
279+
int32_t ne33;
280+
uint64_t nb31;
281+
uint64_t nb32;
282+
uint64_t nb33;
283+
} ggml_metal_kargs_flash_attn_ext_blk;
284+
266285
typedef struct {
267286
int32_t ne01;
268287
int32_t ne02;

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 99 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1918,19 +1918,19 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
19181918
const bool has_mask = op->src[3] != nullptr;
19191919

19201920
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
1921-
const bool has_kvpad = ne11 % 32 != 0;
1921+
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
19221922

19231923
if (has_kvpad) {
1924-
res += 32*(
1924+
res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
19251925
nb11*ne12*ne13 +
19261926
nb21*ne22*ne23 +
19271927
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
19281928
}
19291929
} else {
1930-
const bool has_kvpad = ne11 % 64 != 0;
1930+
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
19311931

19321932
if (has_kvpad) {
1933-
res += 64*(
1933+
res += OP_FLASH_ATTN_EXT_NCPSG*(
19341934
nb11*ne12*ne13 +
19351935
nb21*ne22*ne23 +
19361936
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
@@ -1940,6 +1940,44 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
19401940
return res;
19411941
}
19421942

1943+
size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
1944+
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1945+
1946+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1947+
//GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1948+
//GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1949+
//GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1950+
//GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1951+
//GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1952+
GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
1953+
GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
1954+
1955+
size_t res = 0;
1956+
1957+
const bool has_mask = op->src[3] != nullptr;
1958+
1959+
if (!has_mask) {
1960+
return res;
1961+
}
1962+
1963+
const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
1964+
1965+
// this optimization is not useful for the vector kernels
1966+
if (is_vec) {
1967+
return res;
1968+
}
1969+
1970+
const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
1971+
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
1972+
1973+
const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
1974+
const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
1975+
1976+
res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
1977+
1978+
return res;
1979+
}
1980+
19431981
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
19441982
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
19451983

@@ -2034,18 +2072,23 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
20342072
ggml_metal_buffer_id bid_pad = bid_dst;
20352073
bid_pad.offs += ggml_nbytes(op);
20362074

2037-
ggml_metal_buffer_id bid_tmp = bid_pad;
2038-
bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
2075+
ggml_metal_buffer_id bid_blk = bid_pad;
2076+
bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
2077+
2078+
ggml_metal_buffer_id bid_tmp = bid_blk;
2079+
bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);
20392080

20402081
if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
20412082
// half8x8 kernel
2042-
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
2043-
const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
2083+
const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
2084+
const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
20442085

20452086
GGML_ASSERT(nqptg <= 32);
20462087
GGML_ASSERT(nqptg % 8 == 0);
20472088
GGML_ASSERT(ncpsg % 32 == 0);
20482089

2090+
bool need_sync = false;
2091+
20492092
const bool has_kvpad = ne11 % ncpsg != 0;
20502093

20512094
if (has_kvpad) {
@@ -2083,11 +2126,46 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
20832126

20842127
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
20852128

2086-
ggml_metal_op_concurrency_reset(ctx);
2129+
need_sync = true;
20872130
} else {
20882131
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
20892132
}
20902133

2134+
if (has_mask) {
2135+
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
2136+
2137+
ggml_metal_kargs_flash_attn_ext_blk args0 = {
2138+
/*.ne01 =*/ ne01,
2139+
/*.ne30 =*/ ne30,
2140+
/*.ne31 =*/ ne31,
2141+
/*.ne32 =*/ ne32,
2142+
/*.ne33 =*/ ne33,
2143+
/*.nb31 =*/ nb31,
2144+
/*.nb32 =*/ nb32,
2145+
/*.nb33 =*/ nb33,
2146+
};
2147+
2148+
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
2149+
2150+
ggml_metal_encoder_set_pipeline(enc, pipeline0);
2151+
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2152+
ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
2153+
ggml_metal_encoder_set_buffer (enc, bid_blk, 2);
2154+
2155+
const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
2156+
const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
2157+
2158+
ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
2159+
2160+
need_sync = true;
2161+
} else {
2162+
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
2163+
}
2164+
2165+
if (need_sync) {
2166+
ggml_metal_op_concurrency_reset(ctx);
2167+
}
2168+
20912169
const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
20922170

20932171
// 2*(2*ncpsg)
@@ -2164,22 +2242,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
21642242
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
21652243
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
21662244
ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
2167-
ggml_metal_encoder_set_buffer (enc, bid_dst, 7);
2245+
ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
2246+
ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
21682247

21692248
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
21702249

21712250
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03, 32, nsg, 1);
21722251
#undef FATTN_SMEM
21732252
} else {
21742253
// half4x4 kernel
2175-
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
2176-
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
2177-
const int64_t nkpsg = 1*ncpsg;
2254+
const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
2255+
const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
2256+
const int nkpsg = 1*ncpsg;
21782257

21792258
GGML_ASSERT(nqptg <= 32);
21802259
GGML_ASSERT(nqptg % 1 == 0);
21812260
GGML_ASSERT(ncpsg % 32 == 0);
21822261

2262+
bool need_sync = false;
2263+
21832264
const bool has_kvpad = ne11 % ncpsg != 0;
21842265

21852266
if (has_kvpad) {
@@ -2217,11 +2298,15 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
22172298

22182299
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
22192300

2220-
ggml_metal_op_concurrency_reset(ctx);
2301+
need_sync = true;
22212302
} else {
22222303
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
22232304
}
22242305

2306+
if (need_sync) {
2307+
ggml_metal_op_concurrency_reset(ctx);
2308+
}
2309+
22252310
// ne00 + 2*ncpsg*(nsg)
22262311
// for each query, we load it as f16 in shared memory (ne00)
22272312
// and store the soft_max values and the mask

ggml/src/ggml-metal/ggml-metal-ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);
4040
bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);
4141

4242
size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op);
43+
size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op);
4344
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
4445

4546
int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);

ggml/src/ggml-metal/ggml-metal.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
194194
case GGML_OP_FLASH_ATTN_EXT:
195195
{
196196
res += ggml_metal_op_flash_attn_ext_extra_pad(tensor);
197+
res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
197198
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
198199
} break;
199200
default:

0 commit comments

Comments
 (0)