Skip to content

Commit 0a319bb

Browse files
authored
metal : add support for non-padded FA KV (ggml-org#16148)
* metal : pad K, V and Mask when needed
* cont : simplify
* cuda : add TODO about KV padding requirement
* metal : add comments
* metal : remove mask padding requirement
1 parent 1d6092f commit 0a319bb

File tree

9 files changed

+460
-72
lines changed

9 files changed

+460
-72
lines changed

ggml/src/ggml-cuda/fattn.cu

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,12 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
208208

209209
const int cc = ggml_cuda_info().devices[device].cc;
210210

211+
// TODO: temporary until support is extended
212+
// https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206
213+
if (K->ne[1] % FATTN_KQ_STRIDE != 0) {
214+
return BEST_FATTN_KERNEL_NONE;
215+
}
216+
211217
switch (K->ne[0]) {
212218
case 64:
213219
case 128:

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -924,13 +924,58 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
924924
return res;
925925
}
926926

927+
// Returns the pipeline for the "kernel_flash_attn_ext_pad" kernel, specialized on
// has_mask and ncpsg.
//
//   lib      - library the pipeline is looked up in and, if missing, compiled into
//   op       - must be a GGML_OP_FLASH_ATTN_EXT node; only referenced by the assert
//   has_mask - passed to the kernel as bool function constant FC_FLASH_ATTN_EXT_PAD + 0
//   ncpsg    - passed to the kernel as int32 function constant FC_FLASH_ATTN_EXT_PAD + 24
//              NOTE(review): presumably the number of KV cache items handled per simdgroup - confirm
//
// The pipeline name encodes both specialization values, so each (has_mask, ncpsg)
// combination maps to its own pipeline. If a lookup by that name succeeds the existing
// pipeline is returned; otherwise the result of ggml_metal_library_compile_pipeline() is returned.
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
928+
ggml_metal_library_t lib,
929+
const struct ggml_tensor * op,
930+
bool has_mask,
931+
int32_t ncpsg) {
932+
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
933+
GGML_UNUSED(op); // op is not used beyond the assert above
934+
935+
char base[256]; // kernel function name
936+
char name[256]; // base name plus the specialization values; used as the lookup key below
937+
938+
snprintf(base, 256, "kernel_%s",
939+
"flash_attn_ext_pad");
940+
941+
snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
942+
base,
943+
has_mask,
944+
ncpsg);
945+
946+
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
947+
if (res) { // a pipeline with this exact name already exists - reuse it
948+
return res;
949+
}
950+
951+
ggml_metal_cv_t cv = ggml_metal_cv_init(); // constant values for the specialization; freed below
952+
953+
ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
954+
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
955+
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
956+
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
957+
958+
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
959+
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
960+
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
961+
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
962+
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24);
963+
964+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
965+
966+
ggml_metal_cv_free(cv); // cv is released here, right after the compile call
967+
968+
return res;
969+
}
970+
927971
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
928972
ggml_metal_library_t lib,
929973
const ggml_tensor * op,
930974
bool has_mask,
931975
bool has_sinks,
932976
bool has_bias,
933977
bool has_scap,
978+
bool has_kvpad,
934979
int32_t nsg) {
935980
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
936981

@@ -943,18 +988,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
943988
const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
944989
const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
945990

991+
// do bounds checks for the mask?
992+
const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
993+
946994
snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
947995
"flash_attn_ext",
948996
ggml_type_name(op->src[1]->type),
949997
dk,
950998
dv);
951999

952-
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
1000+
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
9531001
base,
9541002
has_mask,
9551003
has_sinks,
9561004
has_bias,
9571005
has_scap,
1006+
has_kvpad,
1007+
bc_mask,
9581008
ns10,
9591009
ns20,
9601010
nsg);
@@ -970,6 +1020,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
9701020
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
9711021
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
9721022
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
1023+
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
1024+
1025+
ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
9731026

9741027
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
9751028
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
@@ -989,6 +1042,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
9891042
bool has_sinks,
9901043
bool has_bias,
9911044
bool has_scap,
1045+
bool has_kvpad,
9921046
int32_t nsg,
9931047
int32_t nwg) {
9941048
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1008,12 +1062,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
10081062
dk,
10091063
dv);
10101064

1011-
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
1065+
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
10121066
base,
10131067
has_mask,
10141068
has_sinks,
10151069
has_bias,
10161070
has_scap,
1071+
has_kvpad,
10171072
ns10,
10181073
ns20,
10191074
nsg, nwg);
@@ -1029,6 +1084,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
10291084
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
10301085
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
10311086
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
1087+
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
10321088

10331089
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
10341090
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_me
135135
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
136136
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
137137

138+
// Get the pipeline for the flash-attention padding kernel ("kernel_flash_attn_ext_pad").
// op must be a GGML_OP_FLASH_ATTN_EXT node. has_mask and ncpsg select the kernel
// specialization; both are encoded in the pipeline name and passed as function constants.
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
139+
ggml_metal_library_t lib,
140+
const struct ggml_tensor * op,
141+
bool has_mask,
142+
int32_t ncpsg);
143+
138144
// Get the pipeline for the flash-attention kernel matching op (a GGML_OP_FLASH_ATTN_EXT node).
// The has_* flags and nsg select the kernel specialization and are encoded in the pipeline name.
// NOTE(review): has_kvpad presumably indicates that K/V were padded by the
// flash_attn_ext_pad kernel beforehand - confirm against the caller.
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
139145
ggml_metal_library_t lib,
140146
const struct ggml_tensor * op,
141147
bool has_mask,
142148
bool has_sinks,
143149
bool has_bias,
144150
bool has_scap,
151+
bool has_kvpad,
145152
int32_t nsg);
146153

147154
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
@@ -151,6 +158,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
151158
bool has_sinks,
152159
bool has_bias,
153160
bool has_scap,
161+
bool has_kvpad,
154162
int32_t nsg,
155163
int32_t nwg);
156164

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,12 @@
6969
#define N_SG_IQ4_XS 2
7070

7171
// function constants offsets
72-
#define FC_FLASH_ATTN_EXT 100
73-
#define FC_FLASH_ATTN_EXT_VEC 200
74-
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
75-
#define FC_MUL_MV 400
76-
#define FC_MUL_MM 500
72+
#define FC_FLASH_ATTN_EXT_PAD 100
73+
#define FC_FLASH_ATTN_EXT 200
74+
#define FC_FLASH_ATTN_EXT_VEC 300
75+
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 400
76+
#define FC_MUL_MV 500
77+
#define FC_MUL_MM 600
7778

7879
// kernel argument structs
7980
//
@@ -244,6 +245,24 @@ typedef struct {
244245
int32_t sect_3;
245246
} ggml_metal_kargs_rope;
246247

248+
// Kernel arguments for the flash_attn_ext_pad kernel.
// NOTE(review): field names appear to follow the ne<src><dim> / nb<src><dim> pattern used by the
// other kargs structs, i.e. element counts (ne) and byte strides (nb) of src1 (K), src2 (V) and
// src3 (mask) - confirm against the host code that fills this struct.
typedef struct {
249+
int32_t ne11;
250+
int32_t ne_12_2; // assume K and V are same shape
251+
int32_t ne_12_3; // NOTE(review): presumably shared between K and V like ne_12_2 - confirm
252+
uint64_t nb11;
253+
uint64_t nb12;
254+
uint64_t nb13;
255+
uint64_t nb21;
256+
uint64_t nb22;
257+
uint64_t nb23;
258+
int32_t ne31;
259+
int32_t ne32;
260+
int32_t ne33;
261+
uint64_t nb31;
262+
uint64_t nb32;
263+
uint64_t nb33;
264+
} ggml_metal_kargs_flash_attn_ext_pad;
265+
247266
typedef struct {
248267
int32_t ne01;
249268
int32_t ne02;
@@ -262,6 +281,7 @@ typedef struct {
262281
uint64_t nb21;
263282
uint64_t nb22;
264283
uint64_t nb23;
284+
int32_t ne31;
265285
int32_t ne32;
266286
int32_t ne33;
267287
uint64_t nb31;
@@ -296,6 +316,7 @@ typedef struct {
296316
uint64_t nb21;
297317
uint64_t nb22;
298318
uint64_t nb23;
319+
int32_t ne31;
299320
int32_t ne32;
300321
int32_t ne33;
301322
uint64_t nb31;

0 commit comments

Comments
 (0)