@@ -1918,19 +1918,19 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
     const bool has_mask = op->src[3] != nullptr;
 
     if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
-        const bool has_kvpad = ne11 % 32 != 0;
+        const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
 
         if (has_kvpad) {
-            res += 32*(
+            res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
                     nb11*ne12*ne13 +
                     nb21*ne22*ne23 +
                     (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
         }
     } else {
-        const bool has_kvpad = ne11 % 64 != 0;
+        const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
 
         if (has_kvpad) {
-            res += 64*(
+            res += OP_FLASH_ATTN_EXT_NCPSG*(
                     nb11*ne12*ne13 +
                     nb21*ne22*ne23 +
                     (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
@@ -1940,6 +1940,44 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
     return res;
 }
 
+size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    // GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    // GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    // GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    // GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+    // GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+    size_t res = 0;
+
+    const bool has_mask = op->src[3] != nullptr;
+
+    if (!has_mask) {
+        return res;
+    }
+
+    const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
+
+    // this optimization is not useful for the vector kernels
+    if (is_vec) {
+        return res;
+    }
+
+    const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
+    const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
+
+    const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
+    const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
+
+    res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
+
+    return res;
+}
+
 size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
     assert(op->op == GGML_OP_FLASH_ATTN_EXT);
 
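Aside (not part of the diff): a rough worked example of the size returned by ggml_metal_op_flash_attn_ext_extra_blk() above, assuming the non-vec block sizes 8 and 64 (the literals that this diff replaces with OP_FLASH_ATTN_EXT_NQPTG/NCPSG) and hypothetical tensor shapes. It is only an illustrative sketch, not code from the commit:

// illustrative sketch; mirrors the size computation above with hypothetical shapes
#include <cstdint>
#include <cstdio>

// stand-in for GGML_PAD: round x up to a multiple of n
static size_t pad_up(size_t x, size_t n) { return ((x + n - 1)/n)*n; }

int main() {
    const int nqptg = 8;  // assumed value of OP_FLASH_ATTN_EXT_NQPTG
    const int ncpsg = 64; // assumed value of OP_FLASH_ATTN_EXT_NCPSG

    // hypothetical shapes: 512 query rows, 4096 mask columns, single mask head/batch
    const int64_t ne01 = 512;
    const int64_t ne30 = 4096;
    const int64_t ne32 = 1, ne33 = 1;

    const int64_t ne1 = (ne01 + nqptg - 1)/nqptg; // 64 query blocks
    const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg; // 64 KV blocks

    // one int8 entry per (KV block, query block) per ne32*ne33, padded to 32 bytes
    const size_t res = pad_up(sizeof(int8_t)*ne0*ne1*ne32*ne33, 32);

    printf("extra_blk = %zu bytes\n", res); // 4096 bytes for these shapes
    return 0;
}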
@@ -2034,18 +2072,23 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
     ggml_metal_buffer_id bid_pad = bid_dst;
     bid_pad.offs += ggml_nbytes(op);
 
-    ggml_metal_buffer_id bid_tmp = bid_pad;
-    bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
+    ggml_metal_buffer_id bid_blk = bid_pad;
+    bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
+
+    ggml_metal_buffer_id bid_tmp = bid_blk;
+    bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);
 
     if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
         // half8x8 kernel
-        const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
-        const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
+        const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
+        const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
 
         GGML_ASSERT(nqptg <= 32);
         GGML_ASSERT(nqptg % 8  == 0);
         GGML_ASSERT(ncpsg % 32 == 0);
 
+        bool need_sync = false;
+
         const bool has_kvpad = ne11 % ncpsg != 0;
 
         if (has_kvpad) {
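Aside (not part of the diff): the offset chaining above lays the kv-pad, block-mask and tmp scratch regions out back to back after the destination tensor, each starting where the previous one ends. A hedged sketch of that layout, with placeholder byte sizes standing in for ggml_nbytes(op) and the extra_*() helpers, and offsets taken relative to the start of dst:

// illustrative layout sketch; sizes are placeholders, not values from the commit
#include <cstddef>
#include <cstdio>

int main() {
    const size_t dst_bytes = 1u << 20; // placeholder for ggml_nbytes(op)
    const size_t pad_bytes = 1u << 16; // placeholder for ggml_metal_op_flash_attn_ext_extra_pad(op)
    const size_t blk_bytes = 1u << 12; // placeholder for ggml_metal_op_flash_attn_ext_extra_blk(op)

    // offsets chain the same way as bid_pad -> bid_blk -> bid_tmp in the hunk above
    const size_t offs_pad = dst_bytes;            // bid_pad.offs
    const size_t offs_blk = offs_pad + pad_bytes; // bid_blk.offs
    const size_t offs_tmp = offs_blk + blk_bytes; // bid_tmp.offs

    printf("pad @ %zu, blk @ %zu, tmp @ %zu\n", offs_pad, offs_blk, offs_tmp);
    return 0;
}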
@@ -2083,11 +2126,46 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 
             ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
 
-            ggml_metal_op_concurrency_reset(ctx);
+            need_sync = true;
         } else {
             assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
         }
 
+        if (has_mask) {
+            assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
+
+            ggml_metal_kargs_flash_attn_ext_blk args0 = {
+                /*.ne01 =*/ ne01,
+                /*.ne30 =*/ ne30,
+                /*.ne31 =*/ ne31,
+                /*.ne32 =*/ ne32,
+                /*.ne33 =*/ ne33,
+                /*.nb31 =*/ nb31,
+                /*.nb32 =*/ nb32,
+                /*.nb33 =*/ nb33,
+            };
+
+            ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
+
+            ggml_metal_encoder_set_pipeline(enc, pipeline0);
+            ggml_metal_encoder_set_bytes(enc, &args0, sizeof(args0), 0);
+            ggml_metal_encoder_set_buffer(enc, bid_src3, 1);
+            ggml_metal_encoder_set_buffer(enc, bid_blk, 2);
+
+            const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
+            const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
+
+            ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
+
+            need_sync = true;
+        } else {
+            assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
+        }
+
+        if (need_sync) {
+            ggml_metal_op_concurrency_reset(ctx);
+        }
+
         const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
 
         // 2*(2*ncpsg)
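Aside (not part of the diff): the nblk0 x nblk1 dispatch grid of the blk kernel above uses the same rounding as the ne0/ne1 terms that size the scratch buffer in ggml_metal_op_flash_attn_ext_extra_blk(), which suggests one int8 entry per (KV block, query block) pair. A small self-check under that assumption, with arbitrary lengths:

// sanity sketch: dispatch grid vs. scratch sizing use the same block counts (assumption)
#include <cassert>
#include <cstdint>

int main() {
    const int nqptg = 8, ncpsg = 64;       // assumed OP_FLASH_ATTN_EXT_NQPTG/NCPSG
    const int64_t ne01 = 1000, ne30 = 777; // arbitrary query/mask lengths

    // block counts used when sizing the scratch buffer
    const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
    const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;

    // block counts used for the threadgroup grid
    const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
    const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);

    assert(nblk0 == ne0 && nblk1 == ne1); // one scratch byte per threadgroup
    return 0;
}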
@@ -2164,22 +2242,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
         ggml_metal_encoder_set_buffer(enc, bid_src3, 4);
         ggml_metal_encoder_set_buffer(enc, bid_src4, 5);
         ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
-        ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
+        ggml_metal_encoder_set_buffer(enc, bid_blk, 7);
+        ggml_metal_encoder_set_buffer(enc, bid_dst, 8);
 
         ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
         ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03, 32, nsg, 1);
 #undef FATTN_SMEM
     } else {
         // half4x4 kernel
-        const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
-        const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
-        const int64_t nkpsg = 1*ncpsg;
+        const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
+        const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
+        const int nkpsg = 1*ncpsg;
 
         GGML_ASSERT(nqptg <= 32);
         GGML_ASSERT(nqptg % 1  == 0);
         GGML_ASSERT(ncpsg % 32 == 0);
 
+        bool need_sync = false;
+
         const bool has_kvpad = ne11 % ncpsg != 0;
 
         if (has_kvpad) {
@@ -2217,11 +2298,15 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 
             ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
 
-            ggml_metal_op_concurrency_reset(ctx);
+            need_sync = true;
         } else {
             assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
        }
 
+        if (need_sync) {
+            ggml_metal_op_concurrency_reset(ctx);
+        }
+
         // ne00 + 2*ncpsg*(nsg)
         // for each query, we load it as f16 in shared memory (ne00)
         // and store the soft_max values and the mask