@@ -5543,25 +5543,28 @@ static void ggml_mrope_cache_init(
55435543 }
55445544}
55455545
5546- static void rotate_pairs (const int64_t n, const int64_t n_offset, const float * cache, const float * src_data, float * dst_data, const int scale = 2 ) {
5546+
5547+ template <typename T>
5548+ static void rotate_pairs (const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2 ) {
55475549 for (int64_t i0 = 0 ; i0 < n; i0 += 2 ) {
55485550 const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
55495551
55505552 const float cos_theta = cache[i0 + 0 ];
55515553 const float sin_theta = cache[i0 + 1 ];
55525554
5553- const float * const src = src_data + ic;
5554- float * dst = dst_data + ic;
5555+ const T * const src = src_data + ic;
5556+ T * dst = dst_data + ic;
55555557
5556- const float x0 = src[0 ];
5557- const float x1 = src[n_offset];
5558+ const float x0 = type_conversion_table<T>:: to_f32 ( src[0 ]) ;
5559+ const float x1 = type_conversion_table<T>:: to_f32 ( src[n_offset]) ;
55585560
5559- dst[0 ] = x0*cos_theta - x1*sin_theta;
5560- dst[n_offset] = x0*sin_theta + x1*cos_theta;
5561+ dst[0 ] = type_conversion_table<T>:: from_f32 ( x0*cos_theta - x1*sin_theta) ;
5562+ dst[n_offset] = type_conversion_table<T>:: from_f32 ( x0*sin_theta + x1*cos_theta) ;
55615563 }
55625564}
55635565
5564- static void ggml_compute_forward_rope_f32 (
5566+ template <typename T> // float or ggml_fp16_t
5567+ static void ggml_compute_forward_rope_flt (
55655568 const ggml_compute_params * params,
55665569 ggml_tensor * dst,
55675570 const bool forward) {
@@ -5570,6 +5573,9 @@ static void ggml_compute_forward_rope_f32(
55705573 const ggml_tensor * src1 = dst->src [1 ];
55715574 const ggml_tensor * src2 = dst->src [2 ];
55725575
5576+ GGML_ASSERT (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
5577+ GGML_ASSERT (src1->type == GGML_TYPE_I32);
5578+
55735579 float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
55745580 int sections[4 ];
55755581
@@ -5592,7 +5598,8 @@ static void ggml_compute_forward_rope_f32(
55925598 // printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
55935599 // printf("n_past = %d, ne2 = %d\n", n_past, ne2);
55945600
5595- GGML_ASSERT (nb00 == sizeof (float ));
5601+ GGML_ASSERT (nb0 == nb00);
5602+ GGML_ASSERT (nb0 == sizeof (T));
55965603
55975604 const int ith = params->ith ;
55985605 const int nth = params->nth ;
@@ -5665,30 +5672,30 @@ static void ggml_compute_forward_rope_f32(
56655672 if (ir++ < ir0) continue ;
56665673 if (ir > ir1) break ;
56675674
5668- float * src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
5669- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
5675+ T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
5676+ T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
56705677
56715678 switch (mode) {
56725679 case GGML_ROPE_TYPE_NORMAL:
5673- rotate_pairs (n_dims, 1 , cache, src, dst_data, 1 );
5680+ rotate_pairs<T> (n_dims, 1 , cache, src, dst_data, 1 );
56745681 break ;
56755682 case GGML_ROPE_TYPE_NEOX:
56765683 case GGML_ROPE_TYPE_MROPE: // pure, not vision
5677- rotate_pairs (n_dims, n_dims/2 , cache, src, dst_data);
5684+ rotate_pairs<T> (n_dims, n_dims/2 , cache, src, dst_data);
56785685 break ;
56795686 case GGML_ROPE_TYPE_VISION:
5680- rotate_pairs (ne0, n_dims, cache, src, dst_data);
5687+ rotate_pairs<T> (ne0, n_dims, cache, src, dst_data);
56815688 break ;
56825689 default :
56835690 // rope type not supported, silently default to NORMAL
5684- rotate_pairs (n_dims, 1 , cache, src, dst_data, 1 );
5691+ rotate_pairs<T> (n_dims, 1 , cache, src, dst_data, 1 );
56855692 }
56865693
56875694 if (!is_vision) {
56885695 // fill the remain channels with data from src tensor
56895696 for (int64_t i0 = n_dims; i0 < ne0; i0 += 2 ) {
5690- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5691- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5697+ const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5698+ T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
56925699
56935700 dst_data[0 ] = src[0 ];
56945701 dst_data[1 ] = src[1 ];
@@ -5699,192 +5706,6 @@ static void ggml_compute_forward_rope_f32(
56995706 }
57005707}
57015708
5702- // TODO: deduplicate f16/f32 code
5703- static void ggml_compute_forward_rope_f16 (
5704- const ggml_compute_params * params,
5705- ggml_tensor * dst,
5706- const bool forward) {
5707-
5708- const ggml_tensor * src0 = dst->src [0 ];
5709- const ggml_tensor * src1 = dst->src [1 ];
5710- const ggml_tensor * src2 = dst->src [2 ];
5711-
5712- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5713- int sections[4 ];
5714-
5715- // const int n_past = ((int32_t *) dst->op_params)[0];
5716- const int n_dims = ((int32_t *) dst->op_params )[1 ];
5717- const int mode = ((int32_t *) dst->op_params )[2 ];
5718- // const int n_ctx = ((int32_t *) dst->op_params)[3];
5719- const int n_ctx_orig = ((int32_t *) dst->op_params )[4 ];
5720- memcpy (&freq_base, (int32_t *) dst->op_params + 5 , sizeof (float ));
5721- memcpy (&freq_scale, (int32_t *) dst->op_params + 6 , sizeof (float ));
5722- memcpy (&ext_factor, (int32_t *) dst->op_params + 7 , sizeof (float ));
5723- memcpy (&attn_factor, (int32_t *) dst->op_params + 8 , sizeof (float ));
5724- memcpy (&beta_fast, (int32_t *) dst->op_params + 9 , sizeof (float ));
5725- memcpy (&beta_slow, (int32_t *) dst->op_params + 10 , sizeof (float ));
5726- memcpy (§ions, (int32_t *) dst->op_params + 11 , sizeof (int )*4 );
5727-
5728-
5729- GGML_TENSOR_UNARY_OP_LOCALS
5730-
5731- // printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
5732- // printf("n_past = %d, ne2 = %d\n", n_past, ne2);
5733-
5734- GGML_ASSERT (nb0 == sizeof (ggml_fp16_t ));
5735-
5736- const int ith = params->ith ;
5737- const int nth = params->nth ;
5738-
5739- const int nr = ggml_nrows (dst);
5740-
5741- GGML_ASSERT (n_dims <= ne0);
5742- GGML_ASSERT (n_dims % 2 == 0 );
5743-
5744- // rows per thread
5745- const int dr = (nr + nth - 1 )/nth;
5746-
5747- // row range for this thread
5748- const int ir0 = dr*ith;
5749- const int ir1 = MIN (ir0 + dr, nr);
5750-
5751- // row index used to determine which thread to use
5752- int ir = 0 ;
5753-
5754- const float theta_scale = powf (freq_base, -2 .0f /n_dims);
5755-
5756- float corr_dims[2 ];
5757- ggml_rope_yarn_corr_dims (n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
5758-
5759- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
5760- const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
5761- const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
5762- const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
5763-
5764- if (is_mrope) {
5765- GGML_ASSERT (sections[0 ] > 0 || sections[1 ] > 0 || sections[2 ] > 0 );
5766- }
5767-
5768- if (is_vision) {
5769- GGML_ASSERT (n_dims == ne0/2 );
5770- }
5771-
5772- const float * freq_factors = NULL ;
5773- if (src2 != NULL ) {
5774- GGML_ASSERT (src2->type == GGML_TYPE_F32);
5775- GGML_ASSERT (src2->ne [0 ] >= n_dims / 2 );
5776- freq_factors = (const float *) src2->data ;
5777- }
5778-
5779- // backward process uses inverse rotation by cos and sin.
5780- // cos and sin build a rotation matrix, where the inverse is the transpose.
5781- // this essentially just switches the sign of sin.
5782- const float sin_sign = forward ? 1 .0f : -1 .0f ;
5783-
5784- const int32_t * pos = (const int32_t *) src1->data ;
5785-
5786- for (int64_t i3 = 0 ; i3 < ne3; i3++) {
5787- for (int64_t i2 = 0 ; i2 < ne2; i2++) {
5788-
5789- float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5790- if (!is_mrope) {
5791- const int64_t p = pos[i2];
5792- ggml_rope_cache_init (p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5793- }
5794- else {
5795- const int64_t p_t = pos[i2];
5796- const int64_t p_h = pos[i2 + ne2];
5797- const int64_t p_w = pos[i2 + ne2 * 2 ];
5798- const int64_t p_e = pos[i2 + ne2 * 3 ];
5799- ggml_mrope_cache_init (
5800- p_t , p_h, p_w, p_e, sections, is_imrope, is_vision,
5801- freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5802- }
5803-
5804- for (int64_t i1 = 0 ; i1 < ne1; i1++) {
5805- if (ir++ < ir0) continue ;
5806- if (ir > ir1) break ;
5807-
5808- if (is_neox || is_mrope) {
5809- if (is_vision) {
5810- for (int64_t i0 = 0 ; i0 < n_dims; i0 += 2 ) {
5811- const int64_t ic = i0/2 ;
5812-
5813- const float cos_theta = cache[i0 + 0 ];
5814- const float sin_theta = cache[i0 + 1 ];
5815-
5816- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5817- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5818-
5819- const float x0 = GGML_CPU_FP16_TO_FP32 (src[0 ]);
5820- const float x1 = GGML_CPU_FP16_TO_FP32 (src[n_dims]);
5821-
5822- dst_data[0 ] = GGML_CPU_FP32_TO_FP16 (x0*cos_theta - x1*sin_theta);
5823- dst_data[n_dims] = GGML_CPU_FP32_TO_FP16 (x0*sin_theta + x1*cos_theta);
5824- }
5825- } else {
5826- for (int64_t i0 = 0 ; i0 < n_dims; i0 += 2 ) {
5827- const int64_t ic = i0/2 ;
5828-
5829- const float cos_theta = cache[i0 + 0 ];
5830- const float sin_theta = cache[i0 + 1 ];
5831-
5832- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5833- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5834-
5835- const float x0 = GGML_CPU_FP16_TO_FP32 (src[0 ]);
5836- const float x1 = GGML_CPU_FP16_TO_FP32 (src[n_dims/2 ]);
5837-
5838- dst_data[0 ] = GGML_CPU_FP32_TO_FP16 (x0*cos_theta - x1*sin_theta);
5839- dst_data[n_dims/2 ] = GGML_CPU_FP32_TO_FP16 (x0*sin_theta + x1*cos_theta);
5840- }
5841- }
5842- } else {
5843- for (int64_t i0 = 0 ; i0 < n_dims; i0 += 2 ) {
5844- const float cos_theta = cache[i0 + 0 ];
5845- const float sin_theta = cache[i0 + 1 ];
5846-
5847- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5848- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5849-
5850- const float x0 = GGML_CPU_FP16_TO_FP32 (src[0 ]);
5851- const float x1 = GGML_CPU_FP16_TO_FP32 (src[1 ]);
5852-
5853- dst_data[0 ] = GGML_CPU_FP32_TO_FP16 (x0*cos_theta - x1*sin_theta);
5854- dst_data[1 ] = GGML_CPU_FP32_TO_FP16 (x0*sin_theta + x1*cos_theta);
5855- }
5856- }
5857-
5858- if (is_vision) {
5859- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2 ) {
5860- const int64_t ic = i0/2 ;
5861-
5862- const float cos_theta = cache[i0 + 0 ];
5863- const float sin_theta = cache[i0 + 1 ];
5864-
5865- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5866- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5867-
5868- const float x0 = GGML_CPU_FP16_TO_FP32 (src[0 ]);
5869- const float x1 = GGML_CPU_FP16_TO_FP32 (src[n_dims]);
5870-
5871- dst_data[0 ] = GGML_CPU_FP32_TO_FP16 (x0*cos_theta - x1*sin_theta);
5872- dst_data[n_dims] = GGML_CPU_FP32_TO_FP16 (x0*sin_theta + x1*cos_theta);
5873- }
5874- } else {
5875- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2 ) {
5876- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5877- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5878-
5879- dst_data[0 ] = src[0 ];
5880- dst_data[1 ] = src[1 ];
5881- }
5882- }
5883- }
5884- }
5885- }
5886- }
5887-
58885709void ggml_compute_forward_rope (
58895710 const ggml_compute_params * params,
58905711 ggml_tensor * dst) {
@@ -5894,11 +5715,11 @@ void ggml_compute_forward_rope(
58945715 switch (src0->type ) {
58955716 case GGML_TYPE_F16:
58965717 {
5897- ggml_compute_forward_rope_f16 (params, dst, true );
5718+ ggml_compute_forward_rope_flt< ggml_fp16_t > (params, dst, true );
58985719 } break ;
58995720 case GGML_TYPE_F32:
59005721 {
5901- ggml_compute_forward_rope_f32 (params, dst, true );
5722+ ggml_compute_forward_rope_flt< float > (params, dst, true );
59025723 } break ;
59035724 default :
59045725 {
@@ -5918,11 +5739,11 @@ void ggml_compute_forward_rope_back(
59185739 switch (src0->type ) {
59195740 case GGML_TYPE_F16:
59205741 {
5921- ggml_compute_forward_rope_f16 (params, dst, false );
5742+ ggml_compute_forward_rope_flt< ggml_fp16_t > (params, dst, false );
59225743 } break ;
59235744 case GGML_TYPE_F32:
59245745 {
5925- ggml_compute_forward_rope_f32 (params, dst, false );
5746+ ggml_compute_forward_rope_flt< float > (params, dst, false );
59265747 } break ;
59275748 default :
59285749 {
0 commit comments