File tree Expand file tree Collapse file tree 4 files changed +1404
-1314
lines changed Expand file tree Collapse file tree 4 files changed +1404
-1314
lines changed Original file line number Diff line number Diff line change 2020#define N_R0_Q5_1 4
2121#define N_SG_Q5_1 2
2222
23- #define N_R0_Q8_0 4
24- #define N_SG_Q8_0 2
23+ #define N_R0_Q8_0 2
24+ #define N_SG_Q8_0 4
2525
2626#define N_R0_MXFP4 2
2727#define N_SG_MXFP4 2
6868#define N_R0_IQ4_XS 2
6969#define N_SG_IQ4_XS 2
7070
71+ // function constants offsets: one base value per flash-attention kernel family, spaced 100 apart (NOTE(review): presumably individual function constants are addressed as base + k, so each family should stay below 100 entries - confirm against the shader source)
72+ #define FC_FLASH_ATTN_EXT 100
73+ #define FC_FLASH_ATTN_EXT_VEC 200
74+ #define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
75+
7176// kernel argument structs
7277//
7378// - element counters (e.g. ne00) typically use int32_t to reduce register usage
@@ -236,9 +241,11 @@ typedef struct {
236241 int32_t ne11 ;
237242 int32_t ne_12_2 ; // assume K and V are same shape
238243 int32_t ne_12_3 ;
244+ int32_t ns10 ;
239245 uint64_t nb11 ;
240246 uint64_t nb12 ;
241247 uint64_t nb13 ;
248+ int32_t ns20 ;
242249 uint64_t nb21 ;
243250 uint64_t nb22 ;
244251 uint64_t nb23 ;
@@ -258,10 +265,43 @@ typedef struct {
258265 float logit_softcap ;
259266} ggml_metal_kargs_flash_attn_ext ;
260267
268+ typedef struct { // kernel argument struct for the flash_attn_ext_vec kernel; NOTE(review): the visible fields parallel ggml_metal_kargs_flash_attn_ext - confirm whether the two must be kept in sync
269+ int32_t ne01 ; // ne*: element counters, kept as int32_t to reduce register usage
270+ int32_t ne02 ;
271+ int32_t ne03 ;
272+ uint64_t nb01 ; // nb*: 64-bit values; presumably byte strides - TODO confirm
273+ uint64_t nb02 ;
274+ uint64_t nb03 ;
275+ int32_t ne11 ;
276+ int32_t ne_12_2 ; // assume K and V are same shape
277+ int32_t ne_12_3 ;
278+ int32_t ns10 ; // NOTE(review): meaning of ns10 is not visible in this chunk - confirm against the kernel that consumes it
279+ uint64_t nb11 ;
280+ uint64_t nb12 ;
281+ uint64_t nb13 ;
282+ int32_t ns20 ; // NOTE(review): counterpart of ns10 for the nb2x group; a 32-bit field between 64-bit fields, so implicit padding is expected before nb21
283+ uint64_t nb21 ;
284+ uint64_t nb22 ;
285+ uint64_t nb23 ;
286+ int32_t ne32 ;
287+ int32_t ne33 ;
288+ uint64_t nb31 ;
289+ uint64_t nb32 ;
290+ uint64_t nb33 ;
291+ int32_t ne1 ;
292+ int32_t ne2 ;
293+ int32_t ne3 ;
294+ float scale ;
295+ float max_bias ;
296+ float m0 ;
297+ float m1 ;
298+ int32_t n_head_log2 ;
299+ float logit_softcap ;
300+ } ggml_metal_kargs_flash_attn_ext_vec ;
301+
261302typedef struct {
262303 int32_t nrows ;
263- int32_t ne20 ;
264- } ggml_metal_kargs_flash_attn_ext_reduce ;
304+ } ggml_metal_kargs_flash_attn_ext_vec_reduce ;
265305
266306typedef struct {
267307 int32_t ne00 ;
You can’t perform that action at this time.
0 commit comments