Skip to content

Commit 5d0d2d2

Browse files
committed
metal : pad K, V and Mask when needed
1 parent d8359f5 commit 5d0d2d2

File tree

8 files changed

+420
-42
lines changed

8 files changed

+420
-42
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -930,13 +930,58 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
930930
return res;
931931
}
932932

933+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
934+
ggml_metal_library_t lib,
935+
const struct ggml_tensor * op,
936+
bool has_mask,
937+
int32_t ncpsg) {
938+
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
939+
GGML_UNUSED(op);
940+
941+
char base[256];
942+
char name[256];
943+
944+
snprintf(base, 256, "kernel_%s",
945+
"flash_attn_ext_pad");
946+
947+
snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
948+
base,
949+
has_mask,
950+
ncpsg);
951+
952+
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
953+
if (res) {
954+
return res;
955+
}
956+
957+
ggml_metal_cv_t cv = ggml_metal_cv_init();
958+
959+
ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
960+
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
961+
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
962+
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
963+
964+
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
965+
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
966+
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
967+
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
968+
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24);
969+
970+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
971+
972+
ggml_metal_cv_free(cv);
973+
974+
return res;
975+
}
976+
933977
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
934978
ggml_metal_library_t lib,
935979
const ggml_tensor * op,
936980
bool has_mask,
937981
bool has_sinks,
938982
bool has_bias,
939983
bool has_scap,
984+
bool has_kvpad,
940985
int32_t nsg) {
941986
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
942987

@@ -955,12 +1000,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
9551000
dk,
9561001
dv);
9571002

958-
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
1003+
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d",
9591004
base,
9601005
has_mask,
9611006
has_sinks,
9621007
has_bias,
9631008
has_scap,
1009+
has_kvpad,
9641010
ns10,
9651011
ns20,
9661012
nsg);
@@ -976,6 +1022,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
9761022
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
9771023
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
9781024
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
1025+
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
9791026

9801027
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
9811028
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
@@ -995,6 +1042,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
9951042
bool has_sinks,
9961043
bool has_bias,
9971044
bool has_scap,
1045+
bool has_kvpad,
9981046
int32_t nsg,
9991047
int32_t nwg) {
10001048
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1014,12 +1062,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
10141062
dk,
10151063
dv);
10161064

1017-
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
1065+
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
10181066
base,
10191067
has_mask,
10201068
has_sinks,
10211069
has_bias,
10221070
has_scap,
1071+
has_kvpad,
10231072
ns10,
10241073
ns20,
10251074
nsg, nwg);
@@ -1035,6 +1084,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
10351084
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
10361085
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
10371086
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
1087+
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
10381088

10391089
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
10401090
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_me
135135
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
136136
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
137137

138+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
139+
ggml_metal_library_t lib,
140+
const struct ggml_tensor * op,
141+
bool has_mask,
142+
int32_t ncpsg);
143+
138144
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
139145
ggml_metal_library_t lib,
140146
const struct ggml_tensor * op,
141147
bool has_mask,
142148
bool has_sinks,
143149
bool has_bias,
144150
bool has_scap,
151+
bool has_kvpad,
145152
int32_t nsg);
146153

147154
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
@@ -151,6 +158,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
151158
bool has_sinks,
152159
bool has_bias,
153160
bool has_scap,
161+
bool has_kvpad,
154162
int32_t nsg,
155163
int32_t nwg);
156164

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,12 @@
7272
#define N_SG_IQ4_XS 2
7373

7474
// function constants offsets
75-
#define FC_FLASH_ATTN_EXT 100
76-
#define FC_FLASH_ATTN_EXT_VEC 200
77-
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
78-
#define FC_MUL_MV 400
79-
#define FC_MUL_MM 500
75+
#define FC_FLASH_ATTN_EXT_PAD 100
76+
#define FC_FLASH_ATTN_EXT 200
77+
#define FC_FLASH_ATTN_EXT_VEC 300
78+
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 400
79+
#define FC_MUL_MV 500
80+
#define FC_MUL_MM 600
8081

8182
// kernel argument structs
8283
//
@@ -246,6 +247,24 @@ typedef struct {
246247
int32_t sect_3;
247248
} ggml_metal_kargs_rope;
248249

250+
typedef struct {
251+
int32_t ne11;
252+
int32_t ne_12_2; // assume K and V are same shape
253+
int32_t ne_12_3;
254+
uint64_t nb11;
255+
uint64_t nb12;
256+
uint64_t nb13;
257+
uint64_t nb21;
258+
uint64_t nb22;
259+
uint64_t nb23;
260+
int32_t ne31;
261+
int32_t ne32;
262+
int32_t ne33;
263+
uint64_t nb31;
264+
uint64_t nb32;
265+
uint64_t nb33;
266+
} ggml_metal_kargs_flash_attn_ext_pad;
267+
249268
typedef struct {
250269
int32_t ne01;
251270
int32_t ne02;
@@ -264,6 +283,7 @@ typedef struct {
264283
uint64_t nb21;
265284
uint64_t nb22;
266285
uint64_t nb23;
286+
int32_t ne31;
267287
int32_t ne32;
268288
int32_t ne33;
269289
uint64_t nb31;
@@ -298,6 +318,7 @@ typedef struct {
298318
uint64_t nb21;
299319
uint64_t nb22;
300320
uint64_t nb23;
321+
int32_t ne31;
301322
int32_t ne32;
302323
int32_t ne33;
303324
uint64_t nb31;

0 commit comments

Comments
 (0)