Skip to content

Commit 6669297

Browse files
committed
cont : simplify
1 parent 5d0d2d2 commit 6669297

File tree

2 files changed

+35
-48
lines changed

2 files changed

+35
-48
lines changed

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 33 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2007,11 +2007,20 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
20072007

20082008
GGML_ASSERT(ne01 < 65536);
20092009

2010+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
2011+
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
2012+
ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
2013+
ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
2014+
ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
2015+
20102016
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
20112017

20122018
ggml_metal_buffer_id bid_pad = bid_dst;
20132019
bid_pad.offs += ggml_nbytes(op);
20142020

2021+
ggml_metal_buffer_id bid_tmp = bid_pad;
2022+
bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
2023+
20152024
if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
20162025
// half8x8 kernel
20172026
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
@@ -2048,14 +2057,10 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
20482057

20492058
ggml_metal_encoder_set_pipeline(enc, pipeline0);
20502059
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2051-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
2052-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 2);
2053-
if (op->src[3]) {
2054-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 3);
2055-
} else {
2056-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3);
2057-
}
2058-
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
2060+
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
2061+
ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
2062+
ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
2063+
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
20592064

20602065
assert(ne12 == ne22);
20612066
assert(ne13 == ne23);
@@ -2137,21 +2142,13 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
21372142

21382143
ggml_metal_encoder_set_pipeline(enc, pipeline);
21392144
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2140-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
2141-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
2142-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
2143-
if (op->src[3]) {
2144-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
2145-
} else {
2146-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
2147-
}
2148-
if (op->src[4]) {
2149-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
2150-
} else {
2151-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
2152-
}
2153-
ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
2154-
ggml_metal_encoder_set_buffer (enc, bid_dst, 7);
2145+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2146+
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2147+
ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
2148+
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
2149+
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
2150+
ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
2151+
ggml_metal_encoder_set_buffer (enc, bid_dst, 7);
21552152

21562153
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
21572154

@@ -2194,14 +2191,10 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
21942191

21952192
ggml_metal_encoder_set_pipeline(enc, pipeline0);
21962193
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2197-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
2198-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 2);
2199-
if (op->src[3]) {
2200-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 3);
2201-
} else {
2202-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3);
2203-
}
2204-
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
2194+
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
2195+
ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
2196+
ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
2197+
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
22052198

22062199
assert(ne12 == ne22);
22072200
assert(ne13 == ne23);
@@ -2300,26 +2293,20 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
23002293

23012294
ggml_metal_encoder_set_pipeline(enc, pipeline);
23022295
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2303-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
2304-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
2305-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
2306-
if (op->src[3]) {
2307-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
2308-
} else {
2309-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
2310-
}
2311-
if (op->src[4]) {
2312-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
2313-
} else {
2314-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
2315-
}
2296+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2297+
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2298+
ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
2299+
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
2300+
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
23162301

23172302
const size_t smem = FATTN_SMEM(nsg);
23182303

23192304
//printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, props_dev->max_theadgroup_memory_size, (int) nsg, (int) nsgmax);
23202305
GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
23212306

23222307
if (nwg == 1) {
2308+
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
2309+
23232310
// using 1 workgroup -> write the result directly into dst
23242311
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
23252312
ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
@@ -2329,13 +2316,12 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
23292316
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
23302317
} else {
23312318
// sanity checks
2319+
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
2320+
23322321
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
23332322
GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
23342323

23352324
// write the results from each workgroup into a temp buffer
2336-
ggml_metal_buffer_id bid_tmp = bid_dst;
2337-
bid_tmp.offs += ggml_nbytes(op) + ggml_metal_op_flash_attn_ext_extra_pad(op);
2338-
23392325
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
23402326
ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
23412327

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4591,7 +4591,6 @@ void kernel_flash_attn_ext_impl(
45914591

45924592
// mask storage in shared mem
45934593
threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C);
4594-
threadgroup half * sm = (threadgroup half *) (sm2);
45954594

45964595
// per-query mask pointers
45974596
device const half2 * pm2[NQ];
@@ -4676,6 +4675,8 @@ void kernel_flash_attn_ext_impl(
46764675
v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
46774676

46784677
if (!FC_flash_attn_ext_has_mask) {
4678+
threadgroup half * sm = (threadgroup half *) (sm2);
4679+
46794680
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
46804681
const short j = jj*NSG + sgitg;
46814682

0 commit comments

Comments
 (0)