Skip to content

Commit defeeb3

Browse files
committed
metal : pad K, V and Mask when needed
1 parent bf6f3b3 commit defeeb3

File tree

8 files changed

+420
-42
lines changed

8 files changed

+420
-42
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 52 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -918,13 +918,58 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
918918
return res;
919919
}
920920

921+
// Returns the pipeline for the "kernel_flash_attn_ext_pad" kernel, specialized on
// `has_mask` and `ncpsg`.
//
// The pipeline is first looked up in `lib` by its specialized name; only when that
// lookup yields nothing is a new pipeline compiled with the matching function constants.
//
//   lib      - library that is queried and, on a miss, asked to compile the pipeline
//   op       - only checked by the assert below (must be GGML_OP_FLASH_ATTN_EXT)
//   has_mask - encoded in the pipeline name and passed as function constant +0
//   ncpsg    - encoded in the pipeline name and passed as function constant +24
//              NOTE(review): exact meaning of `ncpsg` is not visible here - confirm against the kernel
//
// Returns whatever the lookup or the compile call returns.
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
922+
ggml_metal_library_t lib,
923+
const struct ggml_tensor * op,
924+
bool has_mask,
925+
int32_t ncpsg) {
926+
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
927+
// `op` has no other use in this function
GGML_UNUSED(op);
928+
929+
char base[256];
930+
char name[256];
931+
// base: kernel function name, passed to the compile call below
932+
snprintf(base, 256, "kernel_%s",
933+
"flash_attn_ext_pad");
934+
// name: base plus the specialization values; used as the lookup key
935+
snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
936+
base,
937+
has_mask,
938+
ncpsg);
939+
// reuse an existing pipeline with this name if the library already has one
940+
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
941+
if (res) {
942+
return res;
943+
}
944+
945+
ggml_metal_cv_t cv = ggml_metal_cv_init();
946+
// only offsets +0 (has_mask) and +24 (ncpsg) are set; the commented-out lines
// below are disabled calls for the remaining offsets
947+
ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
948+
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
949+
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
950+
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
951+
952+
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
953+
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
954+
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
955+
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
956+
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24);
957+
958+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
959+
// the constant-value set is released here, right after the compile call
960+
ggml_metal_cv_free(cv);
961+
962+
return res;
963+
}
964+
921965
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
922966
ggml_metal_library_t lib,
923967
const ggml_tensor * op,
924968
bool has_mask,
925969
bool has_sinks,
926970
bool has_bias,
927971
bool has_scap,
972+
bool has_kvpad,
928973
int32_t nsg) {
929974
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
930975

@@ -943,12 +988,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
943988
dk,
944989
dv);
945990

946-
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
991+
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d",
947992
base,
948993
has_mask,
949994
has_sinks,
950995
has_bias,
951996
has_scap,
997+
has_kvpad,
952998
ns10,
953999
ns20,
9541000
nsg);
@@ -964,6 +1010,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
9641010
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
9651011
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
9661012
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
1013+
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
9671014

9681015
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
9691016
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
@@ -983,6 +1030,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
9831030
bool has_sinks,
9841031
bool has_bias,
9851032
bool has_scap,
1033+
bool has_kvpad,
9861034
int32_t nsg,
9871035
int32_t nwg) {
9881036
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1002,12 +1050,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
10021050
dk,
10031051
dv);
10041052

1005-
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
1053+
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
10061054
base,
10071055
has_mask,
10081056
has_sinks,
10091057
has_bias,
10101058
has_scap,
1059+
has_kvpad,
10111060
ns10,
10121061
ns20,
10131062
nsg, nwg);
@@ -1023,6 +1072,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
10231072
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
10241073
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
10251074
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
1075+
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
10261076

10271077
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
10281078
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -135,13 +135,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_me
135135
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
136136
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
137137

138+
// Pipeline getter for the flash_attn_ext_pad kernel. `op` is expected to be a
// GGML_OP_FLASH_ATTN_EXT node; the returned pipeline is specialized on `has_mask` and `ncpsg`.
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
139+
ggml_metal_library_t lib,
140+
const struct ggml_tensor * op,
141+
bool has_mask,
142+
int32_t ncpsg);
143+
138144
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
139145
ggml_metal_library_t lib,
140146
const struct ggml_tensor * op,
141147
bool has_mask,
142148
bool has_sinks,
143149
bool has_bias,
144150
bool has_scap,
151+
bool has_kvpad,
145152
int32_t nsg);
146153

147154
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
@@ -151,6 +158,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
151158
bool has_sinks,
152159
bool has_bias,
153160
bool has_scap,
161+
bool has_kvpad,
154162
int32_t nsg,
155163
int32_t nwg);
156164

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 26 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -69,11 +69,12 @@
6969
#define N_SG_IQ4_XS 2
7070

7171
// function constants offsets
72-
#define FC_FLASH_ATTN_EXT 100
73-
#define FC_FLASH_ATTN_EXT_VEC 200
74-
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
75-
#define FC_MUL_MV 400
76-
#define FC_MUL_MM 500
72+
#define FC_FLASH_ATTN_EXT_PAD 100
73+
#define FC_FLASH_ATTN_EXT 200
74+
#define FC_FLASH_ATTN_EXT_VEC 300
75+
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 400
76+
#define FC_MUL_MV 500
77+
#define FC_MUL_MM 600
7778

7879
// kernel argument structs
7980
//
@@ -243,6 +244,24 @@ typedef struct {
243244
int32_t sect_3;
244245
} ggml_metal_kargs_rope;
245246

247+
// Argument struct for the flash_attn_ext_pad kernel (per its name).
// NOTE(review): field names appear to follow the ne<src><dim> / nb<src><dim> pattern,
// presumably element counts and byte strides of sources 1, 2 and 3 (K, V and mask) -
// confirm against the code that fills this struct and the kernel that reads it.
typedef struct {
248+
int32_t ne11;
249+
int32_t ne_12_2; // assume K and V are same shape
250+
int32_t ne_12_3;
251+
uint64_t nb11;
252+
uint64_t nb12;
253+
uint64_t nb13;
254+
uint64_t nb21;
255+
uint64_t nb22;
256+
uint64_t nb23;
257+
int32_t ne31;
258+
int32_t ne32;
259+
int32_t ne33;
260+
uint64_t nb31;
261+
uint64_t nb32;
262+
uint64_t nb33;
263+
} ggml_metal_kargs_flash_attn_ext_pad;
264+
246265
typedef struct {
247266
int32_t ne01;
248267
int32_t ne02;
@@ -261,6 +280,7 @@ typedef struct {
261280
uint64_t nb21;
262281
uint64_t nb22;
263282
uint64_t nb23;
283+
int32_t ne31;
264284
int32_t ne32;
265285
int32_t ne33;
266286
uint64_t nb31;
@@ -295,6 +315,7 @@ typedef struct {
295315
uint64_t nb21;
296316
uint64_t nb22;
297317
uint64_t nb23;
318+
int32_t ne31;
298319
int32_t ne32;
299320
int32_t ne33;
300321
uint64_t nb31;

0 commit comments

Comments
 (0)