@@ -5232,14 +5232,17 @@ static void ggml_compute_forward_soft_max_f32(
     memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
-    // TODO: handle transposed/permuted matrices
-
     const int ith = params->ith;
     const int nth = params->nth;
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    // const int64_t ne11 = src1 ? src1->ne[1] : 1;
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
 
     // TODO: is this supposed to be ceil instead of floor?
     //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
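For context on the new locals: in ggml, ne[] holds the element count per dimension and nb[] holds the stride per dimension in bytes, so nb11/nb12/nb13 are the mask's row/head/batch byte strides and ne12/ne13 its extents along dims 2 and 3 (the `: 1` fallbacks only cover the src1 == NULL case). A minimal sketch of that addressing convention, using a simplified hypothetical stand-in rather than the real ggml_tensor:

    // Sketch (not part of the commit): ggml-style ne/nb addressing with a
    // simplified stand-in struct; the names here are illustrative only.
    #include <stdint.h>
    #include <stddef.h>

    struct toy_tensor {
        int64_t ne[4]; // elements per dimension
        size_t  nb[4]; // byte stride per dimension
        char  * data;
    };

    // element (i0, i1, i2, i3) as a float, via byte arithmetic like the kernels
    static inline float * toy_get_f32(const struct toy_tensor * t,
                                      int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
        return (float *)(t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3]);
    }
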
@@ -5249,68 +5252,66 @@ static void ggml_compute_forward_soft_max_f32(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+    float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        // ALiBi
-        const uint32_t h = (i1/ne01)%ne02; // head
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
-        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
-
-        // broadcast the mask across rows
-        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-
-        ggml_vec_cpy_f32  (nc, wp, sp);
-        ggml_vec_scale_f32(nc, wp, scale);
-        if (mp_f32) {
-            if (use_f16) {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
-                }
-            } else {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*mp_f32[i];
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const int64_t i11 = i01;
+                const int64_t i12 = i02%ne12;
+                const int64_t i13 = i03%ne13;
+
+                // ALiBi
+                const uint32_t h = i02; // head
+                const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+                float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                float * dp = (float *)((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3);
+
+                // broadcast the mask across rows
+                ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+                float       * mp_f32 = src1 ? (float       *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+
+                ggml_vec_cpy_f32  (ne00, wp, sp);
+                ggml_vec_scale_f32(ne00, wp, scale);
+                if (mp_f32) {
+                    if (use_f16) {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
+                        }
+                    } else {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*mp_f32[i];
+                        }
+                    }
                 }
-            }
-        }
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(wp[i]));
-        }
+                for (int i = 0; i < ne00; ++i) {
+                    //printf("p[%d] = %f\n", i, p[i]);
+                    assert(!isnan(wp[i]));
+                }
 #endif
 
-        float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, wp);
+                float max = -INFINITY;
+                ggml_vec_max_f32(ne00, &max, wp);
 
-        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
-        assert(sum > 0.0);
+                ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
+                assert(sum > 0.0);
 
-        sum = 1.0/sum;
-        ggml_vec_scale_f32(nc, dp, sum);
+                sum = 1.0/sum;
+                ggml_vec_scale_f32(ne00, dp, sum);
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            assert(!isnan(dp[i]));
-            assert(!isinf(dp[i]));
-        }
+                for (int i = 0; i < ne00; ++i) {
+                    assert(!isnan(dp[i]));
+                    assert(!isinf(dp[i]));
+                }
 #endif
+            }
+        }
     }
 }
 
@@ -7766,7 +7767,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    ggml_type    const k_vec_dot_type      = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
+    ggml_type         const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
     ggml_from_float_t const q_to_vec_dot   = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
     ggml_vec_dot_t    const kq_vec_dot     = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t   const v_to_float     = ggml_get_type_traits(v->type)->to_float;
@@ -7798,7 +7799,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             memset(VKQ32, 0, DV*sizeof(float));
         }
 
-        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
 
         // k indices
         const int ik3 = iq3 / rk3;
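
The flash-attention hunk applies the same idea to the mask row pointer: the added (iq3%mask->ne[2])*mask->nb[2] term steps into the mask slice for the current batch, wrapping so a mask with ne[2] == 1 is shared across all batches. A sketch of that offset computation, reusing the hypothetical toy_tensor from the sketch above:

    // Sketch (not from the commit): broadcast-aware mask row lookup; the
    // helper name and struct are hypothetical.
    static inline const void * mask_row(const struct toy_tensor * mask,
                                        int64_t iq1, int64_t iq3) {
        // row iq1, batch iq3; iq3 wraps into the mask's dim-2 extent, so a
        // mask with ne[2] == 1 yields the same rows for every batch
        return mask->data + iq1*mask->nb[1] + (iq3 % mask->ne[2])*mask->nb[2];
    }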