@@ -7909,10 +7909,10 @@ void ggml_compute_forward_argsort(
 
 // ggml_compute_forward_flash_attn_ext
 
-static void ggml_compute_forward_flash_attn_ext_f16(
+static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         const ggml_compute_params * params,
-        ggml_tensor * dst) {
-
+        ggml_tensor * dst,
+        int ir0, int ir1) {
     const ggml_tensor * q = dst->src[0];
     const ggml_tensor * k = dst->src[1];
     const ggml_tensor * v = dst->src[2];
@@ -7928,9 +7928,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
     GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
 
-    const int ith = params->ith;
-    const int nth = params->nth;
-
     const int64_t DK = nek0;
     const int64_t DV = nev0;
     const int64_t N  = neq1;
@@ -7964,16 +7961,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
     // parallelize by q rows using ggml_vec_dot_f32
 
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
     float scale         = 1.0f;
     float max_bias      = 0.0f;
     float logit_softcap = 0.0f;
@@ -8000,6 +7987,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT((                            q_to_vec_dot) && "fattn: unsupported K-type");
     GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && "fattn: unsupported V-type");
 
+    int ith = params->ith;
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
@@ -8147,6 +8136,91 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     }
 }
 
+static void ggml_compute_forward_flash_attn_ext_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * q = dst->src[0];
+    const ggml_tensor * k = dst->src[1];
+    const ggml_tensor * v = dst->src[2];
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N  = neq1;
+
+    GGML_ASSERT(ne0 == DV);
+    GGML_ASSERT(ne2 == N);
+
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+    GGML_ASSERT(neq0 == DK);
+    GGML_ASSERT(nek0 == DK);
+    GGML_ASSERT(nev0 == DV);
+
+    GGML_ASSERT(neq1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int64_t nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // disable for NUMA
+    const bool disable_chunking = ggml_is_numa();
+
+    // 4x chunks per thread
+    int nth_scaled = nth * 4;
+    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+    int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
+
+    if (nth == 1 || nchunk < nth || disable_chunking) {
+        nchunk = nth;
+    }
+
+    if (ith == 0) {
+        // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+        ggml_threadpool_chunk_set(params->threadpool, nth);
+    }
+
+    ggml_barrier(params->threadpool);
+
+    // The number of elements in each chunk
+    const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+    // The first chunk comes from our thread_id, the rest will get auto-assigned.
+    int current_chunk = ith;
+
+    while (current_chunk < nchunk) {
+        const int64_t ir0 = dr * current_chunk;
+        const int64_t ir1 = MIN(ir0 + dr, nr);
+
+        ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+
+        current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+    }
+}
+
 void ggml_compute_forward_flash_attn_ext(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
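
For readers unfamiliar with the pattern, below is a minimal standalone sketch of the chunked work-claiming scheme that the new ggml_compute_forward_flash_attn_ext_f16 wrapper uses. It is not part of the diff: the threadpool helpers ggml_threadpool_chunk_set / ggml_threadpool_chunk_add are modeled with a plain std::atomic counter (assuming fetch-and-add semantics that return the previous value), process_rows stands in for ggml_compute_forward_flash_attn_ext_f16_one_chunk, and the NUMA fallback is omitted.

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <thread>
#include <vector>

// Stand-in for ggml_compute_forward_flash_attn_ext_f16_one_chunk: process rows [ir0, ir1).
static void process_rows(int64_t ir0, int64_t ir1) {
    (void) ir0;
    (void) ir1;
}

int main() {
    const int     nth = 8;     // worker threads (params->nth in the kernel)
    const int64_t nr  = 1000;  // total q rows (neq1*neq2*neq3 in the kernel)

    // same chunk sizing as the diff: aim for ~4 chunks per thread
    const int     nth_scaled = nth*4;
    const int64_t chunk_size = (nr + nth_scaled - 1)/nth_scaled;
    int64_t       nchunk     = (nr + chunk_size - 1)/chunk_size;
    if (nth == 1 || nchunk < nth) {
        nchunk = nth; // fall back to one chunk per thread
    }

    const int64_t dr = (nr + nchunk - 1)/nchunk; // rows per chunk

    // Chunks 0..nth-1 are taken implicitly (thread i starts on chunk i),
    // so the shared counter starts at nth, the first unclaimed chunk.
    std::atomic<int64_t> next_chunk{nth};

    std::vector<std::thread> workers;
    for (int ith = 0; ith < nth; ++ith) {
        workers.emplace_back([&, ith] {
            int64_t current_chunk = ith;
            while (current_chunk < nchunk) {
                const int64_t ir0 = dr*current_chunk;
                const int64_t ir1 = std::min(ir0 + dr, nr);
                process_rows(ir0, ir1);
                // claim the next chunk; fetch_add returns the pre-increment value
                current_chunk = next_chunk.fetch_add(1);
            }
        });
    }
    for (auto & w : workers) {
        w.join();
    }
    return 0;
}

The point of the scheme is that a thread which finishes its statically assigned chunk early keeps pulling further chunks from the shared counter, which evens out load when some q rows are cheaper than others, while the old code fixed each thread's row range up front.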