Skip to content

Commit 089404f

Browse files
committed
metal : fattn args
ggml-ci
1 parent 996e479 commit 089404f

File tree

3 files changed

+113
-125
lines changed

3 files changed

+113
-125
lines changed

ggml/src/ggml-common.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,30 @@ typedef struct {
446446
float beta_fast;
447447
float beta_slow;
448448
} ggml_metal_kargs_rope;
449+
450+
typedef struct {
451+
int32_t ne01;
452+
int32_t ne02;
453+
int32_t ne03;
454+
uint64_t nb01;
455+
uint64_t nb02;
456+
uint64_t nb03;
457+
int32_t ne11;
458+
int32_t ne_12_2; // assume K and V are same shape
459+
int32_t ne_12_3;
460+
uint64_t nb_12_1;
461+
uint64_t nb_12_2;
462+
uint64_t nb_12_3;
463+
uint64_t nb31;
464+
int32_t ne1;
465+
int32_t ne2;
466+
float scale;
467+
float max_bias;
468+
float m0;
469+
float m1;
470+
uint16_t n_head_log2;
471+
float logit_softcap;
472+
} ggml_metal_kargs_flash_attn_ext;
449473
#endif
450474

451475
#endif // GGML_COMMON_DECL

ggml/src/ggml-metal.m

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3228,37 +3228,41 @@ static void ggml_metal_encode_node(
32283228
}
32293229
}
32303230

3231+
ggml_metal_kargs_flash_attn_ext args = {
3232+
.ne01 = ne01,
3233+
.ne02 = ne02,
3234+
.ne03 = ne03,
3235+
.nb01 = nb01,
3236+
.nb02 = nb02,
3237+
.nb03 = nb03,
3238+
.ne11 = ne11,
3239+
.ne_12_2 = ne12,
3240+
.ne_12_3 = ne13,
3241+
.nb_12_1 = nb11,
3242+
.nb_12_2 = nb12,
3243+
.nb_12_3 = nb13,
3244+
.nb31 = nb31,
3245+
.ne1 = ne1,
3246+
.ne2 = ne2,
3247+
.scale = scale,
3248+
.max_bias = max_bias,
3249+
.m0 = m0,
3250+
.m1 = m1,
3251+
.n_head_log2 = n_head_log2,
3252+
.logit_softcap = logit_softcap,
3253+
};
3254+
32313255
[encoder setComputePipelineState:pipeline];
3232-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3233-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3234-
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
3256+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3257+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3258+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
32353259
if (id_src3) {
3236-
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
3260+
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
32373261
} else {
3238-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
3262+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
32393263
}
3240-
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
3241-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
3242-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
3243-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
3244-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
3245-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
3246-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
3247-
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
3248-
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
3249-
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
3250-
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
3251-
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
3252-
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
3253-
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17];
3254-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18];
3255-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19];
3256-
[encoder setBytes:&scale length:sizeof( float) atIndex:20];
3257-
[encoder setBytes:&max_bias length:sizeof( float) atIndex:21];
3258-
[encoder setBytes:&m0 length:sizeof(m0) atIndex:22];
3259-
[encoder setBytes:&m1 length:sizeof(m1) atIndex:23];
3260-
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24];
3261-
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
3264+
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
3265+
[encoder setBytes:&args length:sizeof(args) atIndex:5];
32623266

32633267
if (!use_vec_kernel) {
32643268
// half8x8 kernel

0 commit comments

Comments
 (0)