@@ -2009,11 +2009,20 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2009
2009
2010
2010
GGML_ASSERT (ne01 < 65536 );
2011
2011
2012
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id (op->src [0 ]);
2013
+ ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id (op->src [1 ]);
2014
+ ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id (op->src [2 ]);
2015
+ ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id (op->src [3 ]) : bid_src0;
2016
+ ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id (op->src [4 ]) : bid_src0;
2017
+
2012
2018
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id (op);
2013
2019
2014
2020
ggml_metal_buffer_id bid_pad = bid_dst;
2015
2021
bid_pad.offs += ggml_nbytes (op);
2016
2022
2023
+ ggml_metal_buffer_id bid_tmp = bid_pad;
2024
+ bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_pad (op);
2025
+
2017
2026
if (!ggml_metal_op_flash_attn_ext_use_vec (op)) {
2018
2027
// half8x8 kernel
2019
2028
const int64_t nqptg = 8 ; // queries per threadgroup !! sync with kernel template arguments !!
@@ -2050,14 +2059,10 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2050
2059
2051
2060
ggml_metal_encoder_set_pipeline (enc, pipeline0);
2052
2061
ggml_metal_encoder_set_bytes (enc, &args0, sizeof (args0), 0 );
2053
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), 1 );
2054
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [2 ]), 2 );
2055
- if (op->src [3 ]) {
2056
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [3 ]), 3 );
2057
- } else {
2058
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 3 );
2059
- }
2060
- ggml_metal_encoder_set_buffer (enc, bid_pad, 4 );
2062
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 1 );
2063
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 2 );
2064
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 3 );
2065
+ ggml_metal_encoder_set_buffer (enc, bid_pad, 4 );
2061
2066
2062
2067
assert (ne12 == ne22);
2063
2068
assert (ne13 == ne23);
@@ -2139,21 +2144,13 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2139
2144
2140
2145
ggml_metal_encoder_set_pipeline (enc, pipeline);
2141
2146
ggml_metal_encoder_set_bytes (enc, &args, sizeof (args), 0 );
2142
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 1 );
2143
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), 2 );
2144
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [2 ]), 3 );
2145
- if (op->src [3 ]) {
2146
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [3 ]), 4 );
2147
- } else {
2148
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 4 );
2149
- }
2150
- if (op->src [4 ]) {
2151
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [4 ]), 5 );
2152
- } else {
2153
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 5 );
2154
- }
2155
- ggml_metal_encoder_set_buffer (enc, bid_pad, 6 );
2156
- ggml_metal_encoder_set_buffer (enc, bid_dst, 7 );
2147
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1 );
2148
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 2 );
2149
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 3 );
2150
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 4 );
2151
+ ggml_metal_encoder_set_buffer (enc, bid_src4, 5 );
2152
+ ggml_metal_encoder_set_buffer (enc, bid_pad, 6 );
2153
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 7 );
2157
2154
2158
2155
ggml_metal_encoder_set_threadgroup_memory_size (enc, smem, 0 );
2159
2156
@@ -2196,14 +2193,10 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2196
2193
2197
2194
ggml_metal_encoder_set_pipeline (enc, pipeline0);
2198
2195
ggml_metal_encoder_set_bytes (enc, &args0, sizeof (args0), 0 );
2199
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), 1 );
2200
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [2 ]), 2 );
2201
- if (op->src [3 ]) {
2202
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [3 ]), 3 );
2203
- } else {
2204
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 3 );
2205
- }
2206
- ggml_metal_encoder_set_buffer (enc, bid_pad, 4 );
2196
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 1 );
2197
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 2 );
2198
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 3 );
2199
+ ggml_metal_encoder_set_buffer (enc, bid_pad, 4 );
2207
2200
2208
2201
assert (ne12 == ne22);
2209
2202
assert (ne13 == ne23);
@@ -2302,26 +2295,20 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2302
2295
2303
2296
ggml_metal_encoder_set_pipeline (enc, pipeline);
2304
2297
ggml_metal_encoder_set_bytes (enc, &args, sizeof (args), 0 );
2305
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 1 );
2306
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), 2 );
2307
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [2 ]), 3 );
2308
- if (op->src [3 ]) {
2309
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [3 ]), 4 );
2310
- } else {
2311
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 4 );
2312
- }
2313
- if (op->src [4 ]) {
2314
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [4 ]), 5 );
2315
- } else {
2316
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 5 );
2317
- }
2298
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1 );
2299
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 2 );
2300
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 3 );
2301
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 4 );
2302
+ ggml_metal_encoder_set_buffer (enc, bid_src4, 5 );
2318
2303
2319
2304
const size_t smem = FATTN_SMEM (nsg);
2320
2305
2321
2306
// printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, props_dev->max_theadgroup_memory_size, (int) nsg, (int) nsgmax);
2322
2307
GGML_ASSERT (smem <= props_dev->max_theadgroup_memory_size );
2323
2308
2324
2309
if (nwg == 1 ) {
2310
+ assert (ggml_metal_op_flash_attn_ext_extra_tmp (op) == 0 );
2311
+
2325
2312
// using 1 workgroup -> write the result directly into dst
2326
2313
ggml_metal_encoder_set_buffer (enc, bid_pad, 6 );
2327
2314
ggml_metal_encoder_set_buffer (enc, bid_dst, 7 );
@@ -2331,13 +2318,12 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2331
2318
ggml_metal_encoder_dispatch_threadgroups (enc, (ne01 + nqptg - 1 )/nqptg, ne02, ne03*nwg, 32 , nsg, 1 );
2332
2319
} else {
2333
2320
// sanity checks
2321
+ assert (ggml_metal_op_flash_attn_ext_extra_tmp (op) != 0 );
2322
+
2334
2323
GGML_ASSERT (ne01*ne02*ne03 == ne1*ne2*ne3);
2335
2324
GGML_ASSERT ((uint64_t )ne1*ne2*ne3 <= (1u << 31 ));
2336
2325
2337
2326
// write the results from each workgroup into a temp buffer
2338
- ggml_metal_buffer_id bid_tmp = bid_dst;
2339
- bid_tmp.offs += ggml_nbytes (op) + ggml_metal_op_flash_attn_ext_extra_pad (op);
2340
-
2341
2327
ggml_metal_encoder_set_buffer (enc, bid_pad, 6 );
2342
2328
ggml_metal_encoder_set_buffer (enc, bid_tmp, 7 );
2343
2329
0 commit comments