@@ -7909,10 +7909,10 @@ void ggml_compute_forward_argsort(
 
 // ggml_compute_forward_flash_attn_ext
 
-static void ggml_compute_forward_flash_attn_ext_f16(
+static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         const ggml_compute_params * params,
-        ggml_tensor * dst) {
-
+        ggml_tensor * dst,
+        int ir0, int ir1) {
     const ggml_tensor * q     = dst->src[0];
     const ggml_tensor * k     = dst->src[1];
     const ggml_tensor * v     = dst->src[2];
@@ -7928,9 +7928,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
     GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
 
-    const int ith = params->ith;
-    const int nth = params->nth;
-
     const int64_t DK = nek0;
     const int64_t DV = nev0;
     const int64_t N  = neq1;
@@ -7964,16 +7961,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
     // parallelize by q rows using ggml_vec_dot_f32
 
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
     float scale         = 1.0f;
     float max_bias      = 0.0f;
     float logit_softcap = 0.0f;
@@ -8000,6 +7987,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT((                            q_to_vec_dot) && "fattn: unsupported K-type");
     GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && "fattn: unsupported V-type");
 
+    int ith = params->ith;
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
@@ -8147,6 +8136,91 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     }
 }
 
+static void ggml_compute_forward_flash_attn_ext_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * q     = dst->src[0];
+    const ggml_tensor * k     = dst->src[1];
+    const ggml_tensor * v     = dst->src[2];
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N  = neq1;
+
+    GGML_ASSERT(ne0 == DV);
+    GGML_ASSERT(ne2 == N);
+
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+    GGML_ASSERT(neq0 == DK);
+    GGML_ASSERT(nek0 == DK);
+    GGML_ASSERT(nev0 == DV);
+
+    GGML_ASSERT(neq1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int64_t nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // disable for NUMA
+    const bool disable_chunking = ggml_is_numa();
+
+    // 4x chunks per thread
+    int nth_scaled = nth * 4;
+    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+    int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
+
+    if (nth == 1 || nchunk < nth || disable_chunking) {
+        nchunk = nth;
+    }
+
+    if (ith == 0) {
+        // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+        ggml_threadpool_chunk_set(params->threadpool, nth);
+    }
+
+    ggml_barrier(params->threadpool);
+
+    // The number of elements in each chunk
+    const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+    // The first chunk comes from our thread_id, the rest will get auto-assigned.
+    int current_chunk = ith;
+
+    while (current_chunk < nchunk) {
+        const int64_t ir0 = dr * current_chunk;
+        const int64_t ir1 = MIN(ir0 + dr, nr);
+
+        ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+
+        current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+    }
+}
+
 void ggml_compute_forward_flash_attn_ext(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
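
For reference, the new wrapper replaces the old static split (a fixed dr-row slice per thread id) with dynamic chunking: each thread first processes the chunk matching its own thread id, then keeps pulling chunk indices from a shared counter in the threadpool until all nchunk chunks are claimed. Below is a minimal standalone sketch of that scheduling pattern, assuming ggml_threadpool_chunk_set() stores a value into a shared counter and ggml_threadpool_chunk_add() is an atomic fetch-and-add that returns the previous value; the std::atomic/std::thread scaffolding is illustrative and not the ggml threadpool API.

// Standalone sketch of the dynamic chunk scheduling used by the patch above.
#include <algorithm>
#include <atomic>
#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    const int     nth = 4;     // worker threads
    const int64_t nr  = 1000;  // total rows of work (neq1*neq2*neq3 in the kernel)

    // 4x chunks per thread, as in the patch
    const int64_t chunk_size = (nr + nth*4 - 1) / (nth*4);
    const int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
    const int64_t dr         = (nr + nchunk - 1) / nchunk;  // rows per chunk

    // shared counter: the first unclaimed chunk after every thread has taken its own index
    std::atomic<int> next_chunk{nth};

    std::vector<std::thread> workers;
    for (int ith = 0; ith < nth; ++ith) {
        workers.emplace_back([&, ith]() {
            int current_chunk = ith;  // the first chunk comes from the thread id
            while (current_chunk < nchunk) {
                const int64_t ir0 = dr * current_chunk;
                const int64_t ir1 = std::min(ir0 + dr, nr);

                // the _one_chunk kernel would process rows [ir0, ir1) here
                std::printf("thread %d: rows [%lld, %lld)\n", ith, (long long) ir0, (long long) ir1);

                // claim the next chunk (fetch_add returns the previous value)
                current_chunk = next_chunk.fetch_add(1, std::memory_order_relaxed);
            }
        });
    }
    for (auto & t : workers) {
        t.join();
    }
    return 0;
}

With nth = 4 and nr = 1000 this gives chunk_size = 63, nchunk = 16 and dr = 63, so each thread starts with one chunk and the remaining 12 are handed out on demand, which evens out the load when some q rows cost more than others; the NUMA check in the patch falls back to one chunk per thread, presumably to keep each thread's rows on its local memory node.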