@@ -5531,25 +5531,28 @@ static void ggml_mrope_cache_init(
55315531    }
55325532}
55335533
5534- static  void  rotate_pairs (const  int64_t  n, const  int64_t  n_offset, const  float  * cache, const  float  * src_data, float  * dst_data, const  int  scale = 2 ) {
5534+ 
5535+ template <typename  T>
5536+ static  void  rotate_pairs (const  int64_t  n, const  int64_t  n_offset, const  float  * cache, const  T * src_data, T * dst_data, const  int  scale = 2 ) {
55355537  for  (int64_t  i0 = 0 ; i0 < n; i0 += 2 ) {
55365538    const  int64_t  ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
55375539
55385540    const  float  cos_theta = cache[i0 + 0 ];
55395541    const  float  sin_theta = cache[i0 + 1 ];
55405542
5541-     const  float  * const  src = src_data + ic;
5542-     float  * dst             = dst_data + ic;
5543+     const  T  * const  src = src_data + ic;
5544+     T  * dst             = dst_data + ic;
55435545
5544-     const  float  x0 = src[0 ];
5545-     const  float  x1 = src[n_offset];
5546+     const  float  x0 = type_conversion_table<T>:: to_f32 ( src[0 ]) ;
5547+     const  float  x1 = type_conversion_table<T>:: to_f32 ( src[n_offset]) ;
55465548
5547-     dst[0 ]        = x0*cos_theta - x1*sin_theta;
5548-     dst[n_offset] = x0*sin_theta + x1*cos_theta;
5549+     dst[0 ]        = type_conversion_table<T>:: from_f32 ( x0*cos_theta - x1*sin_theta) ;
5550+     dst[n_offset] = type_conversion_table<T>:: from_f32 ( x0*sin_theta + x1*cos_theta) ;
55495551  }
55505552}
55515553
5552- static  void  ggml_compute_forward_rope_f32 (
5554+ template <typename  T> // float or ggml_fp16_t
5555+ static  void  ggml_compute_forward_rope_flt (
55535556        const  ggml_compute_params * params,
55545557        ggml_tensor * dst,
55555558        const  bool  forward) {
@@ -5558,6 +5561,9 @@ static void ggml_compute_forward_rope_f32(
55585561    const  ggml_tensor * src1 = dst->src [1 ];
55595562    const  ggml_tensor * src2 = dst->src [2 ];
55605563
5564+     GGML_ASSERT (src0->type  == GGML_TYPE_F32 || src0->type  == GGML_TYPE_F16);
5565+     GGML_ASSERT (src1->type  == GGML_TYPE_I32);
5566+ 
55615567    float  freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
55625568    int  sections[4 ];
55635569
@@ -5580,7 +5586,8 @@ static void ggml_compute_forward_rope_f32(
55805586    // printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
55815587    // printf("n_past = %d, ne2 = %d\n", n_past, ne2);
55825588
5583-     GGML_ASSERT (nb00 == sizeof (float ));
5589+     GGML_ASSERT (nb0 == nb00);
5590+     GGML_ASSERT (nb0 == sizeof (T));
55845591
55855592    const  int  ith = params->ith ;
55865593    const  int  nth = params->nth ;
@@ -5652,30 +5659,30 @@ static void ggml_compute_forward_rope_f32(
56525659                if  (ir++ < ir0) continue ;
56535660                if  (ir   > ir1) break ;
56545661
5655-                 float  * src = (float  *)((char  *) src0->data  + i3*nb03 + i2*nb02 + i1*nb01);
5656-                 float  * dst_data  = (float  *)((char  *)  dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
5662+                 T  * src = (T  *)((char  *) src0->data  + i3*nb03 + i2*nb02 + i1*nb01);
5663+                 T  * dst_data  = (T  *)((char  *)  dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
56575664
56585665                switch  (mode) {
56595666                  case  GGML_ROPE_TYPE_NORMAL:
5660-                     rotate_pairs (n_dims, 1 , cache, src, dst_data, 1 );
5667+                     rotate_pairs<T> (n_dims, 1 , cache, src, dst_data, 1 );
56615668                    break ;
56625669                  case  GGML_ROPE_TYPE_NEOX:
56635670                  case  GGML_ROPE_TYPE_MROPE: // pure, not vision
5664-                     rotate_pairs (n_dims, n_dims/2 , cache, src, dst_data);
5671+                     rotate_pairs<T> (n_dims, n_dims/2 , cache, src, dst_data);
56655672                    break ;
56665673                  case  GGML_ROPE_TYPE_VISION:
5667-                     rotate_pairs (ne0, n_dims, cache, src, dst_data);
5674+                     rotate_pairs<T> (ne0, n_dims, cache, src, dst_data);
56685675                    break ;
56695676                  default :
56705677                    // rope type not supported, silently default to NORMAL
5671-                     rotate_pairs (n_dims, 1 , cache, src, dst_data, 1 );
5678+                     rotate_pairs<T> (n_dims, 1 , cache, src, dst_data, 1 );
56725679                }
56735680
56745681                if  (!is_vision) {
56755682                    //  fill the remain channels with data from src tensor
56765683                    for  (int64_t  i0 = n_dims; i0 < ne0; i0 += 2 ) {
5677-                         const  float  * const  src = (float  *)((char  *) src0->data  + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5678-                         float  * dst_data  = (float  *)((char  *)  dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
5684+                         const  T  * const  src = (T  *)((char  *) src0->data  + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5685+                         T  * dst_data  = (T  *)((char  *)  dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
56795686
56805687                        dst_data[0 ] = src[0 ];
56815688                        dst_data[1 ] = src[1 ];
@@ -5686,190 +5693,6 @@ static void ggml_compute_forward_rope_f32(
56865693    }
56875694}
56885695
5689- //  TODO: deduplicate f16/f32 code
5690- static  void  ggml_compute_forward_rope_f16 (
5691-         const  ggml_compute_params * params,
5692-         ggml_tensor * dst,
5693-         const  bool  forward) {
5694- 
5695-     const  ggml_tensor * src0 = dst->src [0 ];
5696-     const  ggml_tensor * src1 = dst->src [1 ];
5697-     const  ggml_tensor * src2 = dst->src [2 ];
5698- 
5699-     float  freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5700-     int  sections[4 ];
5701- 
5702-     // const int n_past     = ((int32_t *) dst->op_params)[0];
5703-     const  int  n_dims     = ((int32_t  *) dst->op_params )[1 ];
5704-     const  int  mode       = ((int32_t  *) dst->op_params )[2 ];
5705-     // const int n_ctx      = ((int32_t *) dst->op_params)[3];
5706-     const  int  n_ctx_orig = ((int32_t  *) dst->op_params )[4 ];
5707-     memcpy (&freq_base,   (int32_t  *) dst->op_params  +  5 , sizeof (float ));
5708-     memcpy (&freq_scale,  (int32_t  *) dst->op_params  +  6 , sizeof (float ));
5709-     memcpy (&ext_factor,  (int32_t  *) dst->op_params  +  7 , sizeof (float ));
5710-     memcpy (&attn_factor, (int32_t  *) dst->op_params  +  8 , sizeof (float ));
5711-     memcpy (&beta_fast,   (int32_t  *) dst->op_params  +  9 , sizeof (float ));
5712-     memcpy (&beta_slow,   (int32_t  *) dst->op_params  + 10 , sizeof (float ));
5713-     memcpy (§ions,    (int32_t  *) dst->op_params  + 11 , sizeof (int )*4 );
5714- 
5715- 
5716-     GGML_TENSOR_UNARY_OP_LOCALS
5717- 
5718-     // printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
5719-     // printf("n_past = %d, ne2 = %d\n", n_past, ne2);
5720- 
5721-     GGML_ASSERT (nb0 == sizeof (ggml_fp16_t ));
5722- 
5723-     const  int  ith = params->ith ;
5724-     const  int  nth = params->nth ;
5725- 
5726-     const  int  nr = ggml_nrows (dst);
5727- 
5728-     GGML_ASSERT (n_dims <= ne0);
5729-     GGML_ASSERT (n_dims % 2  == 0 );
5730- 
5731-     //  rows per thread
5732-     const  int  dr = (nr + nth - 1 )/nth;
5733- 
5734-     //  row range for this thread
5735-     const  int  ir0 = dr*ith;
5736-     const  int  ir1 = MIN (ir0 + dr, nr);
5737- 
5738-     //  row index used to determine which thread to use
5739-     int  ir = 0 ;
5740- 
5741-     const  float  theta_scale = powf (freq_base, -2 .0f /n_dims);
5742- 
5743-     float  corr_dims[2 ];
5744-     ggml_rope_yarn_corr_dims (n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
5745- 
5746-     const  bool  is_neox = mode & GGML_ROPE_TYPE_NEOX;
5747-     const  bool  is_mrope = mode & GGML_ROPE_TYPE_MROPE;
5748-     const  bool  is_vision = mode == GGML_ROPE_TYPE_VISION;
5749- 
5750-     if  (is_mrope) {
5751-         GGML_ASSERT (sections[0 ] > 0  || sections[1 ] > 0  || sections[2 ] > 0 );
5752-     }
5753- 
5754-     if  (is_vision) {
5755-         GGML_ASSERT (n_dims == ne0/2 );
5756-     }
5757- 
5758-     const  float  * freq_factors = NULL ;
5759-     if  (src2 != NULL ) {
5760-         GGML_ASSERT (src2->type  == GGML_TYPE_F32);
5761-         GGML_ASSERT (src2->ne [0 ] >= n_dims / 2 );
5762-         freq_factors = (const  float  *) src2->data ;
5763-     }
5764- 
5765-     //  backward process uses inverse rotation by cos and sin.
5766-     //  cos and sin build a rotation matrix, where the inverse is the transpose.
5767-     //  this essentially just switches the sign of sin.
5768-     const  float  sin_sign = forward ? 1 .0f  : -1 .0f ;
5769- 
5770-     const  int32_t  * pos = (const  int32_t  *) src1->data ;
5771- 
5772-     for  (int64_t  i3 = 0 ; i3 < ne3; i3++) {
5773-         for  (int64_t  i2 = 0 ; i2 < ne2; i2++) {
5774- 
5775-             float  * cache = (float  *) params->wdata  + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5776-             if  (!is_mrope) {
5777-                 const  int64_t  p = pos[i2];
5778-                 ggml_rope_cache_init (p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5779-             }
5780-             else  {
5781-                 const  int64_t  p_t  = pos[i2];
5782-                 const  int64_t  p_h = pos[i2 + ne2];
5783-                 const  int64_t  p_w = pos[i2 + ne2 * 2 ];
5784-                 const  int64_t  p_e = pos[i2 + ne2 * 3 ];
5785-                 ggml_mrope_cache_init (
5786-                     p_t , p_h, p_w, p_e, sections, is_vision,
5787-                     freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5788-             }
5789- 
5790-             for  (int64_t  i1 = 0 ; i1 < ne1; i1++) {
5791-                 if  (ir++ < ir0) continue ;
5792-                 if  (ir   > ir1) break ;
5793- 
5794-                 if  (is_neox || is_mrope) {
5795-                     if  (is_vision) {
5796-                         for  (int64_t  i0 = 0 ; i0 < n_dims; i0 += 2 ) {
5797-                             const  int64_t  ic = i0/2 ;
5798- 
5799-                             const  float  cos_theta = cache[i0 + 0 ];
5800-                             const  float  sin_theta = cache[i0 + 1 ];
5801- 
5802-                             const  ggml_fp16_t  * const  src = (ggml_fp16_t  *)((char  *) src0->data  + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5803-                             ggml_fp16_t  * dst_data  = (ggml_fp16_t  *)((char  *)  dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
5804- 
5805-                             const  float  x0 = GGML_CPU_FP16_TO_FP32 (src[0 ]);
5806-                             const  float  x1 = GGML_CPU_FP16_TO_FP32 (src[n_dims]);
5807- 
5808-                             dst_data[0 ]      = GGML_CPU_FP32_TO_FP16 (x0*cos_theta - x1*sin_theta);
5809-                             dst_data[n_dims] = GGML_CPU_FP32_TO_FP16 (x0*sin_theta + x1*cos_theta);
5810-                         }
5811-                     } else  {
5812-                         for  (int64_t  i0 = 0 ; i0 < n_dims; i0 += 2 ) {
5813-                             const  int64_t  ic = i0/2 ;
5814- 
5815-                             const  float  cos_theta = cache[i0 + 0 ];
5816-                             const  float  sin_theta = cache[i0 + 1 ];
5817- 
5818-                             const  ggml_fp16_t  * const  src = (ggml_fp16_t  *)((char  *) src0->data  + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5819-                             ggml_fp16_t  * dst_data  = (ggml_fp16_t  *)((char  *)  dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
5820- 
5821-                             const  float  x0 = GGML_CPU_FP16_TO_FP32 (src[0 ]);
5822-                             const  float  x1 = GGML_CPU_FP16_TO_FP32 (src[n_dims/2 ]);
5823- 
5824-                             dst_data[0 ]        = GGML_CPU_FP32_TO_FP16 (x0*cos_theta - x1*sin_theta);
5825-                             dst_data[n_dims/2 ] = GGML_CPU_FP32_TO_FP16 (x0*sin_theta + x1*cos_theta);
5826-                         }
5827-                     }
5828-                 } else  {
5829-                     for  (int64_t  i0 = 0 ; i0 < n_dims; i0 += 2 ) {
5830-                         const  float  cos_theta = cache[i0 + 0 ];
5831-                         const  float  sin_theta = cache[i0 + 1 ];
5832- 
5833-                         const  ggml_fp16_t  * const  src = (ggml_fp16_t  *)((char  *) src0->data  + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5834-                               ggml_fp16_t  * dst_data  = (ggml_fp16_t  *)((char  *)  dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
5835- 
5836-                         const  float  x0 = GGML_CPU_FP16_TO_FP32 (src[0 ]);
5837-                         const  float  x1 = GGML_CPU_FP16_TO_FP32 (src[1 ]);
5838- 
5839-                         dst_data[0 ] = GGML_CPU_FP32_TO_FP16 (x0*cos_theta - x1*sin_theta);
5840-                         dst_data[1 ] = GGML_CPU_FP32_TO_FP16 (x0*sin_theta + x1*cos_theta);
5841-                     }
5842-                 }
5843- 
5844-                 if  (is_vision) {
5845-                     for  (int64_t  i0 = n_dims; i0 < ne0; i0 += 2 ) {
5846-                         const  int64_t  ic = i0/2 ;
5847- 
5848-                         const  float  cos_theta = cache[i0 + 0 ];
5849-                         const  float  sin_theta = cache[i0 + 1 ];
5850- 
5851-                         const  ggml_fp16_t  * const  src = (ggml_fp16_t  *)((char  *) src0->data  + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5852-                         ggml_fp16_t  * dst_data  = (ggml_fp16_t  *)((char  *)  dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
5853- 
5854-                         const  float  x0 = GGML_CPU_FP16_TO_FP32 (src[0 ]);
5855-                         const  float  x1 = GGML_CPU_FP16_TO_FP32 (src[n_dims]);
5856- 
5857-                         dst_data[0 ]      = GGML_CPU_FP32_TO_FP16 (x0*cos_theta - x1*sin_theta);
5858-                         dst_data[n_dims] = GGML_CPU_FP32_TO_FP16 (x0*sin_theta + x1*cos_theta);
5859-                     }
5860-                 } else  {
5861-                     for  (int64_t  i0 = n_dims; i0 < ne0; i0 += 2 ) {
5862-                         const  ggml_fp16_t  * const  src = (ggml_fp16_t  *)((char  *) src0->data  + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5863-                         ggml_fp16_t  * dst_data  = (ggml_fp16_t  *)((char  *)  dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
5864- 
5865-                         dst_data[0 ] = src[0 ];
5866-                         dst_data[1 ] = src[1 ];
5867-                     }
5868-                 }
5869-             }
5870-         }
5871-     }
5872- }
58735696
58745697void  ggml_compute_forward_rope (
58755698        const  ggml_compute_params * params,
@@ -5880,11 +5703,11 @@ void ggml_compute_forward_rope(
58805703    switch  (src0->type ) {
58815704        case  GGML_TYPE_F16:
58825705            {
5883-                 ggml_compute_forward_rope_f16 (params, dst, true );
5706+                 ggml_compute_forward_rope_flt< ggml_fp16_t > (params, dst, true );
58845707            } break ;
58855708        case  GGML_TYPE_F32:
58865709            {
5887-                 ggml_compute_forward_rope_f32 (params, dst, true );
5710+                 ggml_compute_forward_rope_flt< float > (params, dst, true );
58885711            } break ;
58895712        default :
58905713            {
@@ -5904,11 +5727,11 @@ void ggml_compute_forward_rope_back(
59045727    switch  (src0->type ) {
59055728        case  GGML_TYPE_F16:
59065729            {
5907-                 ggml_compute_forward_rope_f16 (params, dst, false );
5730+                 ggml_compute_forward_rope_flt< ggml_fp16_t > (params, dst, false );
59085731            } break ;
59095732        case  GGML_TYPE_F32:
59105733            {
5911-                 ggml_compute_forward_rope_f32 (params, dst, false );
5734+                 ggml_compute_forward_rope_flt< float > (params, dst, false );
59125735            } break ;
59135736        default :
59145737            {
0 commit comments