Skip to content

Commit dec0c8d

Browse files
authored
Merge pull request #8 from ggml-org/master
merge from upstream
2 parents 3df2244 + 4e0388a commit dec0c8d

File tree

21 files changed

+1062
-547
lines changed

21 files changed

+1062
-547
lines changed

common/arg.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3859,7 +3859,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
38593859
[](common_params & params) {
38603860
params.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF";
38613861
params.model.hf_file = "bge-small-en-v1.5-q8_0.gguf";
3862-
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
38633862
params.embd_normalize = 2;
38643863
params.n_ctx = 512;
38653864
params.verbose_prompt = true;
@@ -3873,7 +3872,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
38733872
[](common_params & params) {
38743873
params.model.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF";
38753874
params.model.hf_file = "e5-small-v2-q8_0.gguf";
3876-
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
38773875
params.embd_normalize = 2;
38783876
params.n_ctx = 512;
38793877
params.verbose_prompt = true;
@@ -3887,7 +3885,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
38873885
[](common_params & params) {
38883886
params.model.hf_repo = "ggml-org/gte-small-Q8_0-GGUF";
38893887
params.model.hf_file = "gte-small-q8_0.gguf";
3890-
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
38913888
params.embd_normalize = 2;
38923889
params.n_ctx = 512;
38933890
params.verbose_prompt = true;

ggml/src/ggml-cpu/ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8135,7 +8135,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
81358135
}
81368136

81378137
// V /= S
8138-
const float S_inv = 1.0f/S;
8138+
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
81398139
ggml_vec_scale_f32(DV, VKQ32, S_inv);
81408140

81418141
// dst indices

ggml/src/ggml-cuda/fattn.cu

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,12 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
208208

209209
const int cc = ggml_cuda_info().devices[device].cc;
210210

211+
// TODO: temporary until support is extended
212+
// https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206
213+
if (K->ne[1] % FATTN_KQ_STRIDE != 0) {
214+
return BEST_FATTN_KERNEL_NONE;
215+
}
216+
211217
switch (K->ne[0]) {
212218
case 64:
213219
case 128:

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 72 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
338338
char base[256];
339339
char name[256];
340340

341-
snprintf(base, 256, "kernel_ssm_conv_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
341+
const char * suffix = "";
342+
343+
if (op->src[1]->ne[0] % 4 == 0) {
344+
suffix = "_4";
345+
}
346+
347+
snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
342348
snprintf(name, 256, "%s", base);
343349

344350
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
@@ -352,15 +358,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
352358
}
353359

354360
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
361+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
362+
355363
char base[256];
356364
char name[256];
357365

358-
if (op->src[3]->ne[0] == 1) {
359-
snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type));
360-
} else {
361-
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
362-
}
363-
snprintf(name, 256, "%s", base);
366+
const int nsg = (ne00 + 31)/32;
367+
368+
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
369+
snprintf(name, 256, "%s_nsg=%d", base, nsg);
364370

365371
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
366372
if (res) {
@@ -369,7 +375,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar
369375

370376
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
371377

372-
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
378+
ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg);
373379

374380
return res;
375381
}
@@ -918,13 +924,58 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
918924
return res;
919925
}
920926

927+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
928+
ggml_metal_library_t lib,
929+
const struct ggml_tensor * op,
930+
bool has_mask,
931+
int32_t ncpsg) {
932+
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
933+
GGML_UNUSED(op);
934+
935+
char base[256];
936+
char name[256];
937+
938+
snprintf(base, 256, "kernel_%s",
939+
"flash_attn_ext_pad");
940+
941+
snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
942+
base,
943+
has_mask,
944+
ncpsg);
945+
946+
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
947+
if (res) {
948+
return res;
949+
}
950+
951+
ggml_metal_cv_t cv = ggml_metal_cv_init();
952+
953+
ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
954+
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
955+
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
956+
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
957+
958+
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
959+
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
960+
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
961+
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
962+
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24);
963+
964+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
965+
966+
ggml_metal_cv_free(cv);
967+
968+
return res;
969+
}
970+
921971
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
922972
ggml_metal_library_t lib,
923973
const ggml_tensor * op,
924974
bool has_mask,
925975
bool has_sinks,
926976
bool has_bias,
927977
bool has_scap,
978+
bool has_kvpad,
928979
int32_t nsg) {
929980
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
930981

@@ -937,18 +988,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
937988
const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
938989
const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
939990

991+
// do bounds checks for the mask?
992+
const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
993+
940994
snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
941995
"flash_attn_ext",
942996
ggml_type_name(op->src[1]->type),
943997
dk,
944998
dv);
945999

946-
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
1000+
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
9471001
base,
9481002
has_mask,
9491003
has_sinks,
9501004
has_bias,
9511005
has_scap,
1006+
has_kvpad,
1007+
bc_mask,
9521008
ns10,
9531009
ns20,
9541010
nsg);
@@ -964,6 +1020,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
9641020
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
9651021
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
9661022
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
1023+
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
1024+
1025+
ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
9671026

9681027
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
9691028
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
@@ -983,6 +1042,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
9831042
bool has_sinks,
9841043
bool has_bias,
9851044
bool has_scap,
1045+
bool has_kvpad,
9861046
int32_t nsg,
9871047
int32_t nwg) {
9881048
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1002,12 +1062,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
10021062
dk,
10031063
dv);
10041064

1005-
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
1065+
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
10061066
base,
10071067
has_mask,
10081068
has_sinks,
10091069
has_bias,
10101070
has_scap,
1071+
has_kvpad,
10111072
ns10,
10121073
ns20,
10131074
nsg, nwg);
@@ -1023,6 +1084,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
10231084
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
10241085
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
10251086
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
1087+
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
10261088

10271089
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
10281090
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_me
135135
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
136136
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
137137

138+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
139+
ggml_metal_library_t lib,
140+
const struct ggml_tensor * op,
141+
bool has_mask,
142+
int32_t ncpsg);
143+
138144
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
139145
ggml_metal_library_t lib,
140146
const struct ggml_tensor * op,
141147
bool has_mask,
142148
bool has_sinks,
143149
bool has_bias,
144150
bool has_scap,
151+
bool has_kvpad,
145152
int32_t nsg);
146153

147154
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
@@ -151,6 +158,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
151158
bool has_sinks,
152159
bool has_bias,
153160
bool has_scap,
161+
bool has_kvpad,
154162
int32_t nsg,
155163
int32_t nwg);
156164

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -776,9 +776,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
776776
};
777777
}
778778
case GGML_OP_GET_ROWS:
779-
{
780-
return op->ne[3] == 1;
781-
}
779+
return true;
782780
case GGML_OP_SET_ROWS:
783781
{
784782
if (op->src[0]->type != GGML_TYPE_F32) {

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,12 @@
6969
#define N_SG_IQ4_XS 2
7070

7171
// function constants offsets
72-
#define FC_FLASH_ATTN_EXT 100
73-
#define FC_FLASH_ATTN_EXT_VEC 200
74-
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
75-
#define FC_MUL_MV 400
76-
#define FC_MUL_MM 500
72+
#define FC_FLASH_ATTN_EXT_PAD 100
73+
#define FC_FLASH_ATTN_EXT 200
74+
#define FC_FLASH_ATTN_EXT_VEC 300
75+
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 400
76+
#define FC_MUL_MV 500
77+
#define FC_MUL_MM 600
7778

7879
// kernel argument structs
7980
//
@@ -178,6 +179,7 @@ typedef struct {
178179
} ggml_metal_kargs_clamp;
179180

180181
typedef struct {
182+
int64_t nk0;
181183
int64_t ne00;
182184
int64_t ne01;
183185
int64_t ne02;
@@ -243,6 +245,24 @@ typedef struct {
243245
int32_t sect_3;
244246
} ggml_metal_kargs_rope;
245247

248+
typedef struct {
249+
int32_t ne11;
250+
int32_t ne_12_2; // assume K and V are same shape
251+
int32_t ne_12_3;
252+
uint64_t nb11;
253+
uint64_t nb12;
254+
uint64_t nb13;
255+
uint64_t nb21;
256+
uint64_t nb22;
257+
uint64_t nb23;
258+
int32_t ne31;
259+
int32_t ne32;
260+
int32_t ne33;
261+
uint64_t nb31;
262+
uint64_t nb32;
263+
uint64_t nb33;
264+
} ggml_metal_kargs_flash_attn_ext_pad;
265+
246266
typedef struct {
247267
int32_t ne01;
248268
int32_t ne02;
@@ -261,6 +281,7 @@ typedef struct {
261281
uint64_t nb21;
262282
uint64_t nb22;
263283
uint64_t nb23;
284+
int32_t ne31;
264285
int32_t ne32;
265286
int32_t ne33;
266287
uint64_t nb31;
@@ -295,6 +316,7 @@ typedef struct {
295316
uint64_t nb21;
296317
uint64_t nb22;
297318
uint64_t nb23;
319+
int32_t ne31;
298320
int32_t ne32;
299321
int32_t ne33;
300322
uint64_t nb31;
@@ -572,32 +594,45 @@ typedef struct {
572594
int64_t n_seq_tokens;
573595
int64_t n_seqs;
574596
uint64_t s_off;
597+
uint64_t nb00;
575598
uint64_t nb01;
576599
uint64_t nb02;
577600
uint64_t nb03;
601+
uint64_t nb10;
578602
uint64_t nb11;
579603
uint64_t nb12;
604+
uint64_t ns12;
580605
uint64_t nb13;
606+
uint64_t nb20;
581607
uint64_t nb21;
608+
uint64_t ns21;
582609
uint64_t nb22;
610+
int64_t ne30;
583611
uint64_t nb31;
584612
uint64_t nb41;
585613
uint64_t nb42;
614+
uint64_t ns42;
586615
uint64_t nb43;
587616
uint64_t nb51;
588617
uint64_t nb52;
618+
uint64_t ns52;
589619
uint64_t nb53;
620+
uint64_t nb0;
590621
} ggml_metal_kargs_ssm_scan;
591622

592623
typedef struct {
593-
int64_t ne00;
624+
int32_t ne00t;
625+
int32_t ne00;
594626
uint64_t nb01;
595627
uint64_t nb02;
596-
int64_t ne10;
628+
uint64_t nb03;
629+
int32_t ne10;
597630
uint64_t nb10;
598631
uint64_t nb11;
632+
uint64_t nb12;
599633
uint64_t nb1;
600634
uint64_t nb2;
635+
uint64_t nb3;
601636
} ggml_metal_kargs_get_rows;
602637

603638
typedef struct {

0 commit comments

Comments
 (0)