Commit a0bf696

templateify ggml_compute_forward_rope_f32 and _f16
1 parent 4b540ab commit a0bf696

1 file changed: ggml/src/ggml-cpu/ops.cpp (28 additions, 205 deletions)

@@ -5531,25 +5531,28 @@ static void ggml_mrope_cache_init(
     }
 }
 
-static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const float * src_data, float * dst_data, const int scale = 2) {
+
+template<typename T>
+static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
     for (int64_t i0 = 0; i0 < n; i0 += 2) {
         const int64_t ic = i0/scale; //hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
 
         const float cos_theta = cache[i0 + 0];
         const float sin_theta = cache[i0 + 1];
 
-        const float * const src = src_data + ic;
-        float * dst = dst_data + ic;
+        const T * const src = src_data + ic;
+        T * dst = dst_data + ic;
 
-        const float x0 = src[0];
-        const float x1 = src[n_offset];
+        const float x0 = type_conversion_table<T>::to_f32(src[0]);
+        const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
 
-        dst[0] = x0*cos_theta - x1*sin_theta;
-        dst[n_offset] = x0*sin_theta + x1*cos_theta;
+        dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
+        dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
     }
 }
 
-static void ggml_compute_forward_rope_f32(
+template<typename T> //float or ggml_fp16_t
+static void ggml_compute_forward_rope_flt(
         const ggml_compute_params * params,
         ggml_tensor * dst,
         const bool forward) {
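
The templated rotate_pairs above relies on a type_conversion_table<T> helper that is not shown in this diff; it presumably lives elsewhere in ggml-cpu. As a minimal, hypothetical sketch only (the real trait may differ), it could specialize float as a pass-through and route ggml_fp16_t through the same FP16 conversion macros used by the removed f16 path further down:

// Illustrative sketch, not the actual ggml implementation: a conversion trait
// shaped like the type_conversion_table<T> used by rotate_pairs<T> above.
// float passes through unchanged; ggml_fp16_t round-trips through the CPU
// FP16 macros that appear in the deleted f16 code below.
template<typename T> struct type_conversion_table;

template<> struct type_conversion_table<float> {
    static inline float to_f32(float x)   { return x; }
    static inline float from_f32(float x) { return x; }
};

template<> struct type_conversion_table<ggml_fp16_t> {
    static inline float       to_f32(ggml_fp16_t x) { return GGML_CPU_FP16_TO_FP32(x); }
    static inline ggml_fp16_t from_f32(float x)     { return GGML_CPU_FP32_TO_FP16(x); }
};
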
@@ -5558,6 +5561,9 @@ static void ggml_compute_forward_rope_f32(
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * src2 = dst->src[2];
 
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
     int sections[4];
 
@@ -5580,7 +5586,8 @@ static void ggml_compute_forward_rope_f32(
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
 
-    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb0 == nb00);
+    GGML_ASSERT(nb0 == sizeof(T));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -5652,30 +5659,30 @@ static void ggml_compute_forward_rope_f32(
                 if (ir++ < ir0) continue;
                 if (ir > ir1) break;
 
-                float * src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-                float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+                T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
 
                 switch (mode) {
                     case GGML_ROPE_TYPE_NORMAL:
-                        rotate_pairs(n_dims, 1, cache, src, dst_data, 1);
+                        rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
                         break;
                     case GGML_ROPE_TYPE_NEOX:
                     case GGML_ROPE_TYPE_MROPE: //pure, not vision
-                        rotate_pairs(n_dims, n_dims/2, cache, src, dst_data);
+                        rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
                         break;
                     case GGML_ROPE_TYPE_VISION:
-                        rotate_pairs(ne0, n_dims, cache, src, dst_data);
+                        rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
                         break;
                     default:
                         //rope type not supported, silently default to NORMAL
-                        rotate_pairs(n_dims, 1, cache, src, dst_data, 1);
+                        rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
                 }
 
                 if (!is_vision) {
                     // fill the remain channels with data from src tensor
                     for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                        float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                        T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
                         dst_data[0] = src[0];
                         dst_data[1] = src[1];
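
The three rotate_pairs<T> call shapes above differ only in how element pairs are addressed: GGML_ROPE_TYPE_NORMAL rotates adjacent pairs (scale 1, offset 1), NEOX and pure MROPE pair the first half of the n_dims channels with the second half (scale 2, offset n_dims/2), and VISION applies the same half-split pairing across the full row of ne0 elements with offset n_dims. A small standalone illustration of the index pairing, using an arbitrary n_dims = 8 that is not taken from the commit:

#include <cstdint>
#include <cstdio>

// Prints which element indices a rotate_pairs-style loop would rotate
// together for the two main pairing layouts used above. Purely illustrative.
int main() {
    const int64_t n_dims = 8; // arbitrary even value for demonstration

    printf("NORMAL layout (scale = 1, n_offset = 1):\n");
    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
        const int64_t ic = i0/1; // scale = 1 -> ic = i0
        printf("  pair (x[%2lld], x[%2lld])\n", (long long) ic, (long long) (ic + 1));
    }

    printf("NEOX/MROPE layout (scale = 2, n_offset = n_dims/2):\n");
    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
        const int64_t ic = i0/2; // scale = 2 -> ic = i0/2
        printf("  pair (x[%2lld], x[%2lld])\n", (long long) ic, (long long) (ic + n_dims/2));
    }
    return 0;
}
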
@@ -5686,190 +5693,6 @@ static void ggml_compute_forward_rope_f32(
     }
 }
 
-// TODO: deduplicate f16/f32 code
-static void ggml_compute_forward_rope_f16(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const bool forward) {
-
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-    const ggml_tensor * src2 = dst->src[2];
-
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-    int sections[4];
-
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode = ((int32_t *) dst->op_params)[2];
-    //const int n_ctx = ((int32_t *) dst->op_params)[3];
-    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-    memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
-    memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
-    memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
-    memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
-
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
-    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
-
-    GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = ggml_nrows(dst);
-
-    GGML_ASSERT(n_dims <= ne0);
-    GGML_ASSERT(n_dims % 2 == 0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    // row index used to determine which thread to use
-    int ir = 0;
-
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
-    float corr_dims[2];
-    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
-    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
-    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
-
-    if (is_mrope) {
-        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
-    }
-
-    if (is_vision) {
-        GGML_ASSERT(n_dims == ne0/2);
-    }
-
-    const float * freq_factors = NULL;
-    if (src2 != NULL) {
-        GGML_ASSERT(src2->type == GGML_TYPE_F32);
-        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
-        freq_factors = (const float *) src2->data;
-    }
-
-    // backward process uses inverse rotation by cos and sin.
-    // cos and sin build a rotation matrix, where the inverse is the transpose.
-    // this essentially just switches the sign of sin.
-    const float sin_sign = forward ? 1.0f : -1.0f;
-
-    const int32_t * pos = (const int32_t *) src1->data;
-
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = 0; i2 < ne2; i2++) {
-
-            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            if (!is_mrope) {
-                const int64_t p = pos[i2];
-                ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
-            else {
-                const int64_t p_t = pos[i2];
-                const int64_t p_h = pos[i2 + ne2];
-                const int64_t p_w = pos[i2 + ne2 * 2];
-                const int64_t p_e = pos[i2 + ne2 * 3];
-                ggml_mrope_cache_init(
-                    p_t, p_h, p_w, p_e, sections, is_vision,
-                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
-
-            for (int64_t i1 = 0; i1 < ne1; i1++) {
-                if (ir++ < ir0) continue;
-                if (ir > ir1) break;
-
-                if (is_neox || is_mrope) {
-                    if (is_vision) {
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
-
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-                            const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
-                            const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
-
-                            dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                            dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                        }
-                    } else {
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
-
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-                            const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
-                            const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
-
-                            dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                            dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                        }
-                    }
-                } else {
-                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                        ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-                        const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
-                        const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
-
-                        dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    }
-                }
-
-                if (is_vision) {
-                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                        const int64_t ic = i0/2;
-
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                        ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-                        const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
-                        const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
-
-                        dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    }
-                } else {
-                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                        ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-                        dst_data[0] = src[0];
-                        dst_data[1] = src[1];
-                    }
-                }
-            }
-        }
-    }
-}
 
 void ggml_compute_forward_rope(
         const ggml_compute_params * params,
@@ -5880,11 +5703,11 @@ void ggml_compute_forward_rope(
     switch (src0->type) {
        case GGML_TYPE_F16:
            {
-                ggml_compute_forward_rope_f16(params, dst, true);
+                ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
            } break;
        case GGML_TYPE_F32:
            {
-                ggml_compute_forward_rope_f32(params, dst, true);
+                ggml_compute_forward_rope_flt<float>(params, dst, true);
            } break;
        default:
            {
@@ -5904,11 +5727,11 @@ void ggml_compute_forward_rope_back(
     switch (src0->type) {
        case GGML_TYPE_F16:
            {
-                ggml_compute_forward_rope_f16(params, dst, false);
+                ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
            } break;
        case GGML_TYPE_F32:
            {
-                ggml_compute_forward_rope_f32(params, dst, false);
+                ggml_compute_forward_rope_flt<float>(params, dst, false);
            } break;
        default:
            {
