@@ -2007,11 +2007,20 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2007
2007
2008
2008
GGML_ASSERT (ne01 < 65536 );
2009
2009
2010
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id (op->src [0 ]);
2011
+ ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id (op->src [1 ]);
2012
+ ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id (op->src [2 ]);
2013
+ ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id (op->src [3 ]) : bid_src0;
2014
+ ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id (op->src [4 ]) : bid_src0;
2015
+
2010
2016
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id (op);
2011
2017
2012
2018
ggml_metal_buffer_id bid_pad = bid_dst;
2013
2019
bid_pad.offs += ggml_nbytes (op);
2014
2020
2021
+ ggml_metal_buffer_id bid_tmp = bid_pad;
2022
+ bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_pad (op);
2023
+
2015
2024
if (!ggml_metal_op_flash_attn_ext_use_vec (op)) {
2016
2025
// half8x8 kernel
2017
2026
const int64_t nqptg = 8 ; // queries per threadgroup !! sync with kernel template arguments !!
@@ -2048,14 +2057,10 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2048
2057
2049
2058
ggml_metal_encoder_set_pipeline (enc, pipeline0);
2050
2059
ggml_metal_encoder_set_bytes (enc, &args0, sizeof (args0), 0 );
2051
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), 1 );
2052
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [2 ]), 2 );
2053
- if (op->src [3 ]) {
2054
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [3 ]), 3 );
2055
- } else {
2056
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 3 );
2057
- }
2058
- ggml_metal_encoder_set_buffer (enc, bid_pad, 4 );
2060
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 1 );
2061
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 2 );
2062
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 3 );
2063
+ ggml_metal_encoder_set_buffer (enc, bid_pad, 4 );
2059
2064
2060
2065
assert (ne12 == ne22);
2061
2066
assert (ne13 == ne23);
@@ -2137,21 +2142,13 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2137
2142
2138
2143
ggml_metal_encoder_set_pipeline (enc, pipeline);
2139
2144
ggml_metal_encoder_set_bytes (enc, &args, sizeof (args), 0 );
2140
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 1 );
2141
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), 2 );
2142
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [2 ]), 3 );
2143
- if (op->src [3 ]) {
2144
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [3 ]), 4 );
2145
- } else {
2146
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 4 );
2147
- }
2148
- if (op->src [4 ]) {
2149
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [4 ]), 5 );
2150
- } else {
2151
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 5 );
2152
- }
2153
- ggml_metal_encoder_set_buffer (enc, bid_pad, 6 );
2154
- ggml_metal_encoder_set_buffer (enc, bid_dst, 7 );
2145
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1 );
2146
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 2 );
2147
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 3 );
2148
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 4 );
2149
+ ggml_metal_encoder_set_buffer (enc, bid_src4, 5 );
2150
+ ggml_metal_encoder_set_buffer (enc, bid_pad, 6 );
2151
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 7 );
2155
2152
2156
2153
ggml_metal_encoder_set_threadgroup_memory_size (enc, smem, 0 );
2157
2154
@@ -2194,14 +2191,10 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2194
2191
2195
2192
ggml_metal_encoder_set_pipeline (enc, pipeline0);
2196
2193
ggml_metal_encoder_set_bytes (enc, &args0, sizeof (args0), 0 );
2197
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), 1 );
2198
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [2 ]), 2 );
2199
- if (op->src [3 ]) {
2200
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [3 ]), 3 );
2201
- } else {
2202
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 3 );
2203
- }
2204
- ggml_metal_encoder_set_buffer (enc, bid_pad, 4 );
2194
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 1 );
2195
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 2 );
2196
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 3 );
2197
+ ggml_metal_encoder_set_buffer (enc, bid_pad, 4 );
2205
2198
2206
2199
assert (ne12 == ne22);
2207
2200
assert (ne13 == ne23);
@@ -2300,26 +2293,20 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2300
2293
2301
2294
ggml_metal_encoder_set_pipeline (enc, pipeline);
2302
2295
ggml_metal_encoder_set_bytes (enc, &args, sizeof (args), 0 );
2303
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 1 );
2304
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), 2 );
2305
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [2 ]), 3 );
2306
- if (op->src [3 ]) {
2307
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [3 ]), 4 );
2308
- } else {
2309
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 4 );
2310
- }
2311
- if (op->src [4 ]) {
2312
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [4 ]), 5 );
2313
- } else {
2314
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 5 );
2315
- }
2296
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1 );
2297
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 2 );
2298
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 3 );
2299
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 4 );
2300
+ ggml_metal_encoder_set_buffer (enc, bid_src4, 5 );
2316
2301
2317
2302
const size_t smem = FATTN_SMEM (nsg);
2318
2303
2319
2304
// printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, props_dev->max_theadgroup_memory_size, (int) nsg, (int) nsgmax);
2320
2305
GGML_ASSERT (smem <= props_dev->max_theadgroup_memory_size );
2321
2306
2322
2307
if (nwg == 1 ) {
2308
+ assert (ggml_metal_op_flash_attn_ext_extra_tmp (op) == 0 );
2309
+
2323
2310
// using 1 workgroup -> write the result directly into dst
2324
2311
ggml_metal_encoder_set_buffer (enc, bid_pad, 6 );
2325
2312
ggml_metal_encoder_set_buffer (enc, bid_dst, 7 );
@@ -2329,13 +2316,12 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2329
2316
ggml_metal_encoder_dispatch_threadgroups (enc, (ne01 + nqptg - 1 )/nqptg, ne02, ne03*nwg, 32 , nsg, 1 );
2330
2317
} else {
2331
2318
// sanity checks
2319
+ assert (ggml_metal_op_flash_attn_ext_extra_tmp (op) != 0 );
2320
+
2332
2321
GGML_ASSERT (ne01*ne02*ne03 == ne1*ne2*ne3);
2333
2322
GGML_ASSERT ((uint64_t )ne1*ne2*ne3 <= (1u << 31 ));
2334
2323
2335
2324
// write the results from each workgroup into a temp buffer
2336
- ggml_metal_buffer_id bid_tmp = bid_dst;
2337
- bid_tmp.offs += ggml_nbytes (op) + ggml_metal_op_flash_attn_ext_extra_pad (op);
2338
-
2339
2325
ggml_metal_encoder_set_buffer (enc, bid_pad, 6 );
2340
2326
ggml_metal_encoder_set_buffer (enc, bid_tmp, 7 );
2341
2327
0 commit comments