Skip to content

Commit 29fbf41

Browse files
committed
templateify ggml_compute_forward_rope_f32 and _f16
1 parent 97b0baf commit 29fbf41

File tree

1 file changed

+28
-207
lines changed

1 file changed

+28
-207
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 28 additions & 207 deletions
Original file line numberDiff line numberDiff line change
@@ -5543,25 +5543,28 @@ static void ggml_mrope_cache_init(
55435543
}
55445544
}
55455545

5546-
static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const float * src_data, float * dst_data, const int scale = 2) {
5546+
5547+
template<typename T>
5548+
static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
55475549
for (int64_t i0 = 0; i0 < n; i0 += 2) {
55485550
const int64_t ic = i0/scale; //hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
55495551

55505552
const float cos_theta = cache[i0 + 0];
55515553
const float sin_theta = cache[i0 + 1];
55525554

5553-
const float * const src = src_data + ic;
5554-
float * dst = dst_data + ic;
5555+
const T * const src = src_data + ic;
5556+
T * dst = dst_data + ic;
55555557

5556-
const float x0 = src[0];
5557-
const float x1 = src[n_offset];
5558+
const float x0 = type_conversion_table<T>::to_f32(src[0]);
5559+
const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
55585560

5559-
dst[0] = x0*cos_theta - x1*sin_theta;
5560-
dst[n_offset] = x0*sin_theta + x1*cos_theta;
5561+
dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
5562+
dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
55615563
}
55625564
}
55635565

5564-
static void ggml_compute_forward_rope_f32(
5566+
template<typename T> //float or ggml_fp16_t
5567+
static void ggml_compute_forward_rope_flt(
55655568
const ggml_compute_params * params,
55665569
ggml_tensor * dst,
55675570
const bool forward) {
@@ -5570,6 +5573,9 @@ static void ggml_compute_forward_rope_f32(
55705573
const ggml_tensor * src1 = dst->src[1];
55715574
const ggml_tensor * src2 = dst->src[2];
55725575

5576+
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
5577+
GGML_ASSERT(src1->type == GGML_TYPE_I32);
5578+
55735579
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
55745580
int sections[4];
55755581

@@ -5592,7 +5598,8 @@ static void ggml_compute_forward_rope_f32(
55925598
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
55935599
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
55945600

5595-
GGML_ASSERT(nb00 == sizeof(float));
5601+
GGML_ASSERT(nb0 == nb00);
5602+
GGML_ASSERT(nb0 == sizeof(T));
55965603

55975604
const int ith = params->ith;
55985605
const int nth = params->nth;
@@ -5665,30 +5672,30 @@ static void ggml_compute_forward_rope_f32(
56655672
if (ir++ < ir0) continue;
56665673
if (ir > ir1) break;
56675674

5668-
float * src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
5669-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
5675+
T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
5676+
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
56705677

56715678
switch (mode) {
56725679
case GGML_ROPE_TYPE_NORMAL:
5673-
rotate_pairs(n_dims, 1, cache, src, dst_data, 1);
5680+
rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
56745681
break;
56755682
case GGML_ROPE_TYPE_NEOX:
56765683
case GGML_ROPE_TYPE_MROPE: //pure, not vision
5677-
rotate_pairs(n_dims, n_dims/2, cache, src, dst_data);
5684+
rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
56785685
break;
56795686
case GGML_ROPE_TYPE_VISION:
5680-
rotate_pairs(ne0, n_dims, cache, src, dst_data);
5687+
rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
56815688
break;
56825689
default:
56835690
//rope type not supported, silently default to NORMAL
5684-
rotate_pairs(n_dims, 1, cache, src, dst_data, 1);
5691+
rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
56855692
}
56865693

56875694
if (!is_vision) {
56885695
// fill the remain channels with data from src tensor
56895696
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5690-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5691-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5697+
const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5698+
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
56925699

56935700
dst_data[0] = src[0];
56945701
dst_data[1] = src[1];
@@ -5699,192 +5706,6 @@ static void ggml_compute_forward_rope_f32(
56995706
}
57005707
}
57015708

5702-
// TODO: deduplicate f16/f32 code
5703-
static void ggml_compute_forward_rope_f16(
5704-
const ggml_compute_params * params,
5705-
ggml_tensor * dst,
5706-
const bool forward) {
5707-
5708-
const ggml_tensor * src0 = dst->src[0];
5709-
const ggml_tensor * src1 = dst->src[1];
5710-
const ggml_tensor * src2 = dst->src[2];
5711-
5712-
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5713-
int sections[4];
5714-
5715-
//const int n_past = ((int32_t *) dst->op_params)[0];
5716-
const int n_dims = ((int32_t *) dst->op_params)[1];
5717-
const int mode = ((int32_t *) dst->op_params)[2];
5718-
//const int n_ctx = ((int32_t *) dst->op_params)[3];
5719-
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
5720-
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
5721-
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
5722-
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
5723-
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
5724-
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
5725-
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
5726-
memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
5727-
5728-
5729-
GGML_TENSOR_UNARY_OP_LOCALS
5730-
5731-
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
5732-
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
5733-
5734-
GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
5735-
5736-
const int ith = params->ith;
5737-
const int nth = params->nth;
5738-
5739-
const int nr = ggml_nrows(dst);
5740-
5741-
GGML_ASSERT(n_dims <= ne0);
5742-
GGML_ASSERT(n_dims % 2 == 0);
5743-
5744-
// rows per thread
5745-
const int dr = (nr + nth - 1)/nth;
5746-
5747-
// row range for this thread
5748-
const int ir0 = dr*ith;
5749-
const int ir1 = MIN(ir0 + dr, nr);
5750-
5751-
// row index used to determine which thread to use
5752-
int ir = 0;
5753-
5754-
const float theta_scale = powf(freq_base, -2.0f/n_dims);
5755-
5756-
float corr_dims[2];
5757-
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
5758-
5759-
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
5760-
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
5761-
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
5762-
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
5763-
5764-
if (is_mrope) {
5765-
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
5766-
}
5767-
5768-
if (is_vision) {
5769-
GGML_ASSERT(n_dims == ne0/2);
5770-
}
5771-
5772-
const float * freq_factors = NULL;
5773-
if (src2 != NULL) {
5774-
GGML_ASSERT(src2->type == GGML_TYPE_F32);
5775-
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
5776-
freq_factors = (const float *) src2->data;
5777-
}
5778-
5779-
// backward process uses inverse rotation by cos and sin.
5780-
// cos and sin build a rotation matrix, where the inverse is the transpose.
5781-
// this essentially just switches the sign of sin.
5782-
const float sin_sign = forward ? 1.0f : -1.0f;
5783-
5784-
const int32_t * pos = (const int32_t *) src1->data;
5785-
5786-
for (int64_t i3 = 0; i3 < ne3; i3++) {
5787-
for (int64_t i2 = 0; i2 < ne2; i2++) {
5788-
5789-
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5790-
if (!is_mrope) {
5791-
const int64_t p = pos[i2];
5792-
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5793-
}
5794-
else {
5795-
const int64_t p_t = pos[i2];
5796-
const int64_t p_h = pos[i2 + ne2];
5797-
const int64_t p_w = pos[i2 + ne2 * 2];
5798-
const int64_t p_e = pos[i2 + ne2 * 3];
5799-
ggml_mrope_cache_init(
5800-
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
5801-
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5802-
}
5803-
5804-
for (int64_t i1 = 0; i1 < ne1; i1++) {
5805-
if (ir++ < ir0) continue;
5806-
if (ir > ir1) break;
5807-
5808-
if (is_neox || is_mrope) {
5809-
if (is_vision) {
5810-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5811-
const int64_t ic = i0/2;
5812-
5813-
const float cos_theta = cache[i0 + 0];
5814-
const float sin_theta = cache[i0 + 1];
5815-
5816-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5817-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5818-
5819-
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
5820-
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
5821-
5822-
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5823-
dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5824-
}
5825-
} else {
5826-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5827-
const int64_t ic = i0/2;
5828-
5829-
const float cos_theta = cache[i0 + 0];
5830-
const float sin_theta = cache[i0 + 1];
5831-
5832-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5833-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5834-
5835-
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
5836-
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
5837-
5838-
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5839-
dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5840-
}
5841-
}
5842-
} else {
5843-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5844-
const float cos_theta = cache[i0 + 0];
5845-
const float sin_theta = cache[i0 + 1];
5846-
5847-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5848-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5849-
5850-
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
5851-
const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
5852-
5853-
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5854-
dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5855-
}
5856-
}
5857-
5858-
if (is_vision) {
5859-
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5860-
const int64_t ic = i0/2;
5861-
5862-
const float cos_theta = cache[i0 + 0];
5863-
const float sin_theta = cache[i0 + 1];
5864-
5865-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5866-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5867-
5868-
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
5869-
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
5870-
5871-
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5872-
dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5873-
}
5874-
} else {
5875-
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5876-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5877-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5878-
5879-
dst_data[0] = src[0];
5880-
dst_data[1] = src[1];
5881-
}
5882-
}
5883-
}
5884-
}
5885-
}
5886-
}
5887-
58885709
void ggml_compute_forward_rope(
58895710
const ggml_compute_params * params,
58905711
ggml_tensor * dst) {
@@ -5894,11 +5715,11 @@ void ggml_compute_forward_rope(
58945715
switch (src0->type) {
58955716
case GGML_TYPE_F16:
58965717
{
5897-
ggml_compute_forward_rope_f16(params, dst, true);
5718+
ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
58985719
} break;
58995720
case GGML_TYPE_F32:
59005721
{
5901-
ggml_compute_forward_rope_f32(params, dst, true);
5722+
ggml_compute_forward_rope_flt<float>(params, dst, true);
59025723
} break;
59035724
default:
59045725
{
@@ -5918,11 +5739,11 @@ void ggml_compute_forward_rope_back(
59185739
switch (src0->type) {
59195740
case GGML_TYPE_F16:
59205741
{
5921-
ggml_compute_forward_rope_f16(params, dst, false);
5742+
ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
59225743
} break;
59235744
case GGML_TYPE_F32:
59245745
{
5925-
ggml_compute_forward_rope_f32(params, dst, false);
5746+
ggml_compute_forward_rope_flt<float>(params, dst, false);
59265747
} break;
59275748
default:
59285749
{

0 commit comments

Comments
 (0)