@@ -3228,37 +3228,41 @@ static void ggml_metal_encode_node(
32283228 }
32293229 }
32303230
3231+ ggml_metal_kargs_flash_attn_ext args = {
3232+ .ne01 = ne01,
3233+ .ne02 = ne02,
3234+ .ne03 = ne03,
3235+ .nb01 = nb01,
3236+ .nb02 = nb02,
3237+ .nb03 = nb03,
3238+ .ne11 = ne11,
3239+ .ne_12_2 = ne12,
3240+ .ne_12_3 = ne13,
3241+ .nb_12_1 = nb11,
3242+ .nb_12_2 = nb12,
3243+ .nb_12_3 = nb13,
3244+ .nb31 = nb31,
3245+ .ne1 = ne1,
3246+ .ne2 = ne2,
3247+ .scale = scale,
3248+ .max_bias = max_bias,
3249+ .m0 = m0,
3250+ .m1 = m1,
3251+ .n_head_log2 = n_head_log2,
3252+ .logit_softcap = logit_softcap,
3253+ };
3254+
32313255 [encoder setComputePipelineState: pipeline];
3232- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3233- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
3234- [encoder setBuffer: id_src2 offset: offs_src2 atIndex: 2 ];
3256+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3257+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
3258+ [encoder setBuffer: id_src2 offset: offs_src2 atIndex: 2 ];
32353259 if (id_src3) {
3236- [encoder setBuffer: id_src3 offset: offs_src3 atIndex: 3 ];
3260+ [encoder setBuffer: id_src3 offset: offs_src3 atIndex: 3 ];
32373261 } else {
3238- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 3 ];
3262+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 3 ];
32393263 }
3240- [encoder setBuffer: id_dst offset: offs_dst atIndex: 4 ];
3241- [encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 5 ];
3242- [encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 6 ];
3243- [encoder setBytes: &ne03 length: sizeof ( int64_t ) atIndex: 7 ];
3244- [encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 8 ];
3245- [encoder setBytes: &nb02 length: sizeof (uint64_t ) atIndex: 9 ];
3246- [encoder setBytes: &nb03 length: sizeof (uint64_t ) atIndex: 10 ];
3247- [encoder setBytes: &ne11 length: sizeof ( int64_t ) atIndex: 11 ];
3248- [encoder setBytes: &ne12 length: sizeof ( int64_t ) atIndex: 12 ];
3249- [encoder setBytes: &ne13 length: sizeof ( int64_t ) atIndex: 13 ];
3250- [encoder setBytes: &nb11 length: sizeof (uint64_t ) atIndex: 14 ];
3251- [encoder setBytes: &nb12 length: sizeof (uint64_t ) atIndex: 15 ];
3252- [encoder setBytes: &nb13 length: sizeof (uint64_t ) atIndex: 16 ];
3253- [encoder setBytes: &nb31 length: sizeof (uint64_t ) atIndex: 17 ];
3254- [encoder setBytes: &ne1 length: sizeof ( int64_t ) atIndex: 18 ];
3255- [encoder setBytes: &ne2 length: sizeof ( int64_t ) atIndex: 19 ];
3256- [encoder setBytes: &scale length: sizeof ( float ) atIndex: 20 ];
3257- [encoder setBytes: &max_bias length: sizeof ( float ) atIndex: 21 ];
3258- [encoder setBytes: &m0 length: sizeof (m0) atIndex: 22 ];
3259- [encoder setBytes: &m1 length: sizeof (m1) atIndex: 23 ];
3260- [encoder setBytes: &n_head_log2 length: sizeof (n_head_log2) atIndex: 24 ];
3261- [encoder setBytes: &logit_softcap length: sizeof (logit_softcap) atIndex: 25 ];
3264+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 4 ];
3265+ [encoder setBytes: &args length: sizeof (args) atIndex: 5 ];
32623266
32633267 if (!use_vec_kernel) {
32643268 // half8x8 kernel
0 commit comments