@@ -6778,6 +6778,7 @@ struct ggml_tensor * ggml_rope_impl(
         int                   n_past,
         int                   n_dims,
         int                   mode,
+        int                   n_ctx,
         bool                  inplace) {
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
@@ -6790,11 +6791,12 @@ struct ggml_tensor * ggml_rope_impl(

     ggml_scratch_save(ctx);

-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);

     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_dims;
     ((int32_t *) b->data)[2] = mode;
+    ((int32_t *) b->data)[3] = n_ctx;

     ggml_scratch_load(ctx);

@@ -6811,17 +6813,19 @@ struct ggml_tensor * ggml_rope(
         struct ggml_tensor  * a,
         int                   n_past,
         int                   n_dims,
-        int                   mode) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, false);
+        int                   mode,
+        int                   n_ctx) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false);
 }

 struct ggml_tensor * ggml_rope_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         int                   n_past,
         int                   n_dims,
-        int                   mode) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, true);
+        int                   mode,
+        int                   n_ctx) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true);
 }

 // ggml_rope_back
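Every ggml_rope / ggml_rope_inplace call site now has to pass the context length as well; only the new GLM path (mode & 4) actually reads it, the other modes simply carry the extra parameter through ggml_rope_impl unchanged. A minimal sketch of an updated caller, with illustrative tensor and variable names that are assumptions rather than part of this diff:

    // hypothetical call site: rotary-embed the current query tensor in the
    // default mode (0); n_ctx is forwarded but unused unless mode & 4 is set
    struct ggml_tensor * Qcur_rot =
        ggml_rope_inplace(ctx0, Qcur, n_past, n_rot, /*mode =*/ 0, /*n_ctx =*/ n_ctx);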
@@ -12440,7 +12444,7 @@ static void ggml_compute_forward_rope_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 3);
+    GGML_ASSERT(ggml_nelements(src1) == 4);

     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -12449,6 +12453,7 @@ static void ggml_compute_forward_rope_f32(
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
+    const int n_ctx  = ((int32_t *) src1->data)[3];

     assert(n_past >= 0);

@@ -12493,6 +12498,7 @@ static void ggml_compute_forward_rope_f32(
     const float theta_scale = powf(10000.0, -2.0f/n_dims);

     const bool is_neox = mode & 2;
+    const bool is_glm  = mode & 4;

     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
@@ -12503,7 +12509,32 @@ static void ggml_compute_forward_rope_f32(

                 float theta = (float)p;

-                if (!is_neox) {
+                if (is_glm) {
+                    theta = MIN(p, n_ctx - 2);
+                    float block_theta = MAX(p - (n_ctx - 2), 0);
+                    for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
+                        const float cos_theta = cosf(theta);
+                        const float sin_theta = sinf(theta);
+                        const float cos_block_theta = cosf(block_theta);
+                        const float sin_block_theta = sinf(block_theta);
+
+                        theta       *= theta_scale;
+                        block_theta *= theta_scale;
+
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        const float x0 = src[0];
+                        const float x1 = src[n_dims/2];
+                        const float x2 = src[n_dims];
+                        const float x3 = src[n_dims/2*3];
+
+                        dst_data[0]          = x0*cos_theta - x1*sin_theta;
+                        dst_data[n_dims/2]   = x0*sin_theta + x1*cos_theta;
+                        dst_data[n_dims]     = x2*cos_block_theta - x3*sin_block_theta;
+                        dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
+                    }
+                } else if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                         const float cos_theta = cosf(theta);
                         const float sin_theta = sinf(theta);
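The is_glm branch drives two rotations per row instead of one: the first angle uses the token position clamped at n_ctx - 2, and the second "block" angle only becomes nonzero once the position passes that boundary, which is ChatGLM's two-channel positional scheme. The standalone sketch below just prints the two base angles for a run of positions; the toy n_ctx value and loop range are assumptions for illustration, not taken from the diff:

    /* sketch of the angle bookkeeping used by the is_glm branch above */
    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))
    #define MAX(a, b) ((a) > (b) ? (a) : (b))

    int main(void) {
        const int n_ctx = 8;                      /* assumed toy context length */
        for (int p = 0; p < 12; p++) {
            /* first rotary stream: position clamped to n_ctx - 2 */
            const float theta       = MIN(p, n_ctx - 2);
            /* second ("block") stream: counts positions past that boundary */
            const float block_theta = MAX(p - (n_ctx - 2), 0);
            printf("p=%2d  theta=%4.1f  block_theta=%4.1f\n", p, theta, block_theta);
        }
        return 0;
    }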
@@ -12553,7 +12584,7 @@ static void ggml_compute_forward_rope_f16(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 3);
+    GGML_ASSERT(ggml_nelements(src1) == 4);

     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -12562,6 +12593,7 @@ static void ggml_compute_forward_rope_f16(
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
+    const int n_ctx  = ((int32_t *) src1->data)[3];

     assert(n_past >= 0);

@@ -12606,6 +12638,7 @@ static void ggml_compute_forward_rope_f16(
     const float theta_scale = powf(10000.0, -2.0f/n_dims);

     const bool is_neox = mode & 2;
+    const bool is_glm  = mode & 4;

     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
@@ -12616,7 +12649,32 @@ static void ggml_compute_forward_rope_f16(

                 float theta = (float)p;

-                if (!is_neox) {
+                if (is_glm) {
+                    theta = MIN(p, n_ctx - 2);
+                    float block_theta = MAX(p - (n_ctx - 2), 0);
+                    for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
+                        const float cos_theta = cosf(theta);
+                        const float sin_theta = sinf(theta);
+                        const float cos_block_theta = cosf(block_theta);
+                        const float sin_block_theta = sinf(block_theta);
+
+                        theta       *= theta_scale;
+                        block_theta *= theta_scale;
+
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+                        const float x2 = GGML_FP16_TO_FP32(src[n_dims]);
+                        const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]);
+
+                        dst_data[0]          = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims/2]   = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                        dst_data[n_dims]     = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
+                        dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
+                    }
+                } else if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                         const float cos_theta = cosf(theta);
                         const float sin_theta = sinf(theta);
@@ -16189,17 +16247,19 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 if (src0->grad) {
                     assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 3);
+                    assert(ggml_nelements(src1) == 4);
                     const int n_past = ((int32_t *) src1->data)[0];
                     const int n_dims = ((int32_t *) src1->data)[1];
                     const int mode   = ((int32_t *) src1->data)[2];
+                    const int n_ctx  = ((int32_t *) src1->data)[3];
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
                             ggml_rope(ctx,
                                 tensor->grad,
                                 n_past,
                                 n_dims,
-                                mode),
+                                mode,
+                                n_ctx),
                             inplace);
                 }
                 if (src1->grad) {