Skip to content

Commit d977902

Browse files
committed
ggml : add support for ChatGLM RoPE
1 parent d38e451 commit d977902

File tree

2 files changed

+76
-13
lines changed

2 files changed

+76
-13
lines changed

ggml.c

Lines changed: 71 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6778,6 +6778,7 @@ struct ggml_tensor * ggml_rope_impl(
67786778
int n_past,
67796779
int n_dims,
67806780
int mode,
6781+
int n_ctx,
67816782
bool inplace) {
67826783
GGML_ASSERT(n_past >= 0);
67836784
bool is_node = false;
@@ -6790,11 +6791,12 @@ struct ggml_tensor * ggml_rope_impl(
67906791

67916792
ggml_scratch_save(ctx);
67926793

6793-
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
6794+
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
67946795

67956796
((int32_t *) b->data)[0] = n_past;
67966797
((int32_t *) b->data)[1] = n_dims;
67976798
((int32_t *) b->data)[2] = mode;
6799+
((int32_t *) b->data)[3] = n_ctx;
67986800

67996801
ggml_scratch_load(ctx);
68006802

@@ -6811,17 +6813,19 @@ struct ggml_tensor * ggml_rope(
68116813
struct ggml_tensor * a,
68126814
int n_past,
68136815
int n_dims,
6814-
int mode) {
6815-
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, false);
6816+
int mode,
6817+
int n_ctx) {
6818+
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false);
68166819
}
68176820

68186821
struct ggml_tensor * ggml_rope_inplace(
68196822
struct ggml_context * ctx,
68206823
struct ggml_tensor * a,
68216824
int n_past,
68226825
int n_dims,
6823-
int mode) {
6824-
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, true);
6826+
int mode,
6827+
int n_ctx) {
6828+
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true);
68256829
}
68266830

68276831
// ggml_rope_back
@@ -12440,7 +12444,7 @@ static void ggml_compute_forward_rope_f32(
1244012444
const struct ggml_tensor * src1,
1244112445
struct ggml_tensor * dst) {
1244212446
GGML_ASSERT(src1->type == GGML_TYPE_I32);
12443-
GGML_ASSERT(ggml_nelements(src1) == 3);
12447+
GGML_ASSERT(ggml_nelements(src1) == 4);
1244412448

1244512449
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
1244612450
return;
@@ -12449,6 +12453,7 @@ static void ggml_compute_forward_rope_f32(
1244912453
const int n_past = ((int32_t *) src1->data)[0];
1245012454
const int n_dims = ((int32_t *) src1->data)[1];
1245112455
const int mode = ((int32_t *) src1->data)[2];
12456+
const int n_ctx = ((int32_t *) src1->data)[3];
1245212457

1245312458
assert(n_past >= 0);
1245412459

@@ -12493,6 +12498,7 @@ static void ggml_compute_forward_rope_f32(
1249312498
const float theta_scale = powf(10000.0, -2.0f/n_dims);
1249412499

1249512500
const bool is_neox = mode & 2;
12501+
const bool is_glm = mode & 4;
1249612502

1249712503
for (int64_t i3 = 0; i3 < ne3; i3++) {
1249812504
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
@@ -12503,7 +12509,32 @@ static void ggml_compute_forward_rope_f32(
1250312509

1250412510
float theta = (float)p;
1250512511

12506-
if (!is_neox) {
12512+
if (is_glm) {
12513+
theta = MIN(p, n_ctx - 2);
12514+
float block_theta = MAX(p - (n_ctx - 2), 0);
12515+
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
12516+
const float cos_theta = cosf(theta);
12517+
const float sin_theta = sinf(theta);
12518+
const float cos_block_theta = cosf(block_theta);
12519+
const float sin_block_theta = sinf(block_theta);
12520+
12521+
theta *= theta_scale;
12522+
block_theta *= theta_scale;
12523+
12524+
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
12525+
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
12526+
12527+
const float x0 = src[0];
12528+
const float x1 = src[n_dims/2];
12529+
const float x2 = src[n_dims];
12530+
const float x3 = src[n_dims/2*3];
12531+
12532+
dst_data[0] = x0*cos_theta - x1*sin_theta;
12533+
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
12534+
dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta;
12535+
dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
12536+
}
12537+
} else if (!is_neox) {
1250712538
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
1250812539
const float cos_theta = cosf(theta);
1250912540
const float sin_theta = sinf(theta);
@@ -12553,7 +12584,7 @@ static void ggml_compute_forward_rope_f16(
1255312584
const struct ggml_tensor * src1,
1255412585
struct ggml_tensor * dst) {
1255512586
GGML_ASSERT(src1->type == GGML_TYPE_I32);
12556-
GGML_ASSERT(ggml_nelements(src1) == 3);
12587+
GGML_ASSERT(ggml_nelements(src1) == 4);
1255712588

1255812589
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
1255912590
return;
@@ -12562,6 +12593,7 @@ static void ggml_compute_forward_rope_f16(
1256212593
const int n_past = ((int32_t *) src1->data)[0];
1256312594
const int n_dims = ((int32_t *) src1->data)[1];
1256412595
const int mode = ((int32_t *) src1->data)[2];
12596+
const int n_ctx = ((int32_t *) src1->data)[3];
1256512597

1256612598
assert(n_past >= 0);
1256712599

@@ -12606,6 +12638,7 @@ static void ggml_compute_forward_rope_f16(
1260612638
const float theta_scale = powf(10000.0, -2.0f/n_dims);
1260712639

1260812640
const bool is_neox = mode & 2;
12641+
const bool is_glm = mode & 4;
1260912642

1261012643
for (int64_t i3 = 0; i3 < ne3; i3++) {
1261112644
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
@@ -12616,7 +12649,32 @@ static void ggml_compute_forward_rope_f16(
1261612649

1261712650
float theta = (float)p;
1261812651

12619-
if (!is_neox) {
12652+
if (is_glm) {
12653+
theta = MIN(p, n_ctx - 2);
12654+
float block_theta = MAX(p - (n_ctx - 2), 0);
12655+
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
12656+
const float cos_theta = cosf(theta);
12657+
const float sin_theta = sinf(theta);
12658+
const float cos_block_theta = cosf(block_theta);
12659+
const float sin_block_theta = sinf(block_theta);
12660+
12661+
theta *= theta_scale;
12662+
block_theta *= theta_scale;
12663+
12664+
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
12665+
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
12666+
12667+
const float x0 = GGML_FP16_TO_FP32(src[0]);
12668+
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
12669+
const float x2 = GGML_FP16_TO_FP32(src[n_dims]);
12670+
const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]);
12671+
12672+
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
12673+
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
12674+
dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
12675+
dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
12676+
}
12677+
} if (!is_neox) {
1262012678
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
1262112679
const float cos_theta = cosf(theta);
1262212680
const float sin_theta = sinf(theta);
@@ -16189,17 +16247,19 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
1618916247
{
1619016248
if (src0->grad) {
1619116249
assert(src1->type == GGML_TYPE_I32);
16192-
assert(ggml_nelements(src1) == 3);
16250+
assert(ggml_nelements(src1) == 4);
1619316251
const int n_past = ((int32_t *) src1->data)[0];
1619416252
const int n_dims = ((int32_t *) src1->data)[1];
1619516253
const int mode = ((int32_t *) src1->data)[2];
16254+
const int n_ctx = ((int32_t *) src1->data)[3];
1619616255
src0->grad = ggml_add_impl(ctx,
1619716256
src0->grad,
1619816257
ggml_rope(ctx,
1619916258
tensor->grad,
1620016259
n_past,
1620116260
n_dims,
16202-
mode),
16261+
mode,
16262+
n_ctx),
1620316263
inplace);
1620416264
}
1620516265
if (src1->grad) {

ggml.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,21 +1036,24 @@ extern "C" {
10361036
// rotary position embedding
10371037
// if (mode & 1) != 0, skip n_past elements
10381038
// if (mode & 2) != 0, GPT-NeoX style
1039+
// if (mode & 4) != 0, ChatGLM style
10391040
// TODO: avoid creating a new tensor every time
10401041
GGML_API struct ggml_tensor * ggml_rope(
10411042
struct ggml_context * ctx,
10421043
struct ggml_tensor * a,
10431044
int n_past,
10441045
int n_dims,
1045-
int mode);
1046+
int mode,
1047+
int n_ctx);
10461048

10471049
// in-place, returns view(a)
10481050
GGML_API struct ggml_tensor * ggml_rope_inplace(
10491051
struct ggml_context * ctx,
10501052
struct ggml_tensor * a,
10511053
int n_past,
10521054
int n_dims,
1053-
int mode);
1055+
int mode,
1056+
int n_ctx);
10541057

10551058
// rotary position embedding backward, i.e compute dx from dy
10561059
// a - dy

0 commit comments

Comments
 (0)