@@ -2422,6 +2422,7 @@ inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x
 // TODO: optimize performance
 inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
 inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
+inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
 
 static const float GELU_COEF_A = 0.044715f;
 static const float GELU_QUICK_COEF = -1.702f;
@@ -2932,6 +2933,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "WIN_UNPART",
     "GET_REL_POS",
     "ADD_REL_POS",
+    "RWKV_WKV",
 
     "UNARY",
 
@@ -2950,7 +2952,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3024,6 +3026,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "win_unpart(x)",
     "get_rel_pos(x)",
     "add_rel_pos(x)",
+    "rwkv_wkv(k, v, r, tf, td, s)",
 
     "unary(x)",
 
@@ -3042,7 +3045,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -3061,9 +3064,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "SILU",
     "HARDSWISH",
     "HARDSIGMOID",
+    "EXP",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
+static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -5466,6 +5470,19 @@ struct ggml_tensor * ggml_hardsigmoid(
     return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID);
 }
 
+// ggml_exp
+struct ggml_tensor * ggml_exp(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_EXP);
+}
+
+struct ggml_tensor * ggml_exp_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
+}
+
 // ggml_norm
 
 static struct ggml_tensor * ggml_norm_impl(
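
For context, a minimal sketch of how the new op might be exercised once this lands. This is not part of the patch: the context size, tensor shape, and thread count below are arbitrary illustration, using only long-standing ggml API calls.

```c
#include "ggml.h"

// Hypothetical caller: builds y = exp(x) over a small F32 tensor and runs it.
static void example_exp(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,  // 16 MB scratch, plenty for this toy graph
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ggml_set_f32(x, 1.0f);                      // x = [1, 1, 1, 1]

    struct ggml_tensor * y = ggml_exp(ctx, x);  // lowers to GGML_UNARY_OP_EXP

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

    // each element is now expf(1.0f), roughly 2.71828f
    (void) ggml_get_f32_1d(y, 0);

    ggml_free(ctx);
}
```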
@@ -7734,6 +7751,59 @@ struct ggml_tensor * ggml_add_rel_pos_inplace(
     return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
 }
 
+// ggml_rwkv_wkv
+
+struct ggml_tensor * ggml_rwkv_wkv(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * r,
+        struct ggml_tensor  * tf,
+        struct ggml_tensor  * td,
+        struct ggml_tensor  * state) {
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(r));
+    GGML_ASSERT(ggml_is_contiguous(tf));
+    GGML_ASSERT(ggml_is_contiguous(td));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[2];
+    const int64_t n_tokens = k->ne[3];
+    const int64_t n_seqs = state->ne[1];
+    {
+        GGML_ASSERT(k->ne[1] == 1);
+        GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
+        GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
+        // TODO: RWKV v4 and v5
+        GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    }
+
+    bool is_node = false;
+
+    if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
+        GGML_ABORT("fatal error"); // TODO: implement backward
+        is_node = true;
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op     = GGML_OP_RWKV_WKV;
+    result->grad   = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = k;
+    result->src[1] = v;
+    result->src[2] = r;
+    result->src[3] = tf;
+    result->src[4] = td;
+    result->src[5] = state;
+
+    return result;
+}
+
 // ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
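
A sketch of the tensor shapes this constructor expects, reconstructed from the asserts above (S = head size, H = head count, T = tokens). `ctx` is assumed to be an initialized `ggml_context *`, and the `[S, H]` layout for `tf` is an inference — only `k`/`v`/`r`/`td` and the state's element count are shape-checked here.

```c
// Hypothetical shape setup; values chosen only for illustration.
const int64_t S = 64, H = 8, T = 16, n_seqs = 1;

struct ggml_tensor * k  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, T);
struct ggml_tensor * v  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, T);
struct ggml_tensor * r  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, T);
struct ggml_tensor * tf = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S, H);          // assumed layout
struct ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, T);
struct ggml_tensor * s  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S * S * H, n_seqs);

// result packs output and new state: the first T rows are the wkv output
// [S*H, T], the remaining S*n_seqs rows carry the updated state
struct ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);
```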
@@ -12114,6 +12184,48 @@ static void ggml_compute_forward_hardsigmoid(
     }
 }
 
+static void ggml_compute_forward_exp_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_exp_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_exp(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_exp_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 
 // ggml_compute_forward_norm
 
@@ -16692,6 +16804,10 @@ static void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_hardsigmoid(params, dst);
             } break;
+        case GGML_UNARY_OP_EXP:
+            {
+                ggml_compute_forward_exp(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -16827,6 +16943,96 @@ static void ggml_compute_forward_add_rel_pos(
     }
 }
 
+// ggml_compute_forward_rwkv_wkv
+
+static void ggml_compute_forward_rwkv_wkv_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const size_t T = dst->src[1]->ne[3];
+    const size_t C = dst->ne[0];
+    const size_t H = dst->src[1]->ne[2];
+    const size_t n_seqs = dst->src[5]->ne[1];
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    memset(dst_data, 0, T * C * sizeof(float));
+
+    float * k          = (float *) dst->src[0]->data;
+    float * v          = (float *) dst->src[1]->data;
+    float * r          = (float *) dst->src[2]->data;
+    float * time_faaaa = (float *) dst->src[3]->data;
+    float * time_decay = (float *) dst->src[4]->data;
+
+    size_t t_stride = H * (C / H);
+
+    size_t h_stride = C / H;
+    size_t h_stride_2d = (C / H) * (C / H);
+
+    // basically fused operations:
+    // dst = r @ (time_faaaa * (k @ v) + state),
+    // state = time_decay * state + (k @ v),
+    // recursive through each token
+    for (size_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+        size_t state_offset = (C / H) * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+        for (size_t h = 0; h < H; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (size_t i = 0; i < C / H; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_i_offset = h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float r_val = r[t_h_i_offset];
+                float time_faaaa_val = time_faaaa[h_i_offset];
+                // RWKV v6: different time_decay for each token.
+                float time_decay_val = time_decay[t_h_i_offset];
+
+                for (size_t j = 0; j < C / H; j++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                    dst_data[t_h_j_offset] += temp_val * r_val;
+                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rwkv_wkv(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
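
Written out per token $t$ and head, with $u = \texttt{time\_faaaa}$, $w_t = \texttt{time\_decay}$, and $S$ the per-head $(C/H) \times (C/H)$ state matrix, the three inner loops above compute (this is a reading of the code, not notation from the patch):

$$
\mathrm{dst}_t[j] \;=\; \sum_{i} r_t[i]\,\bigl(u[i]\,k_t[i]\,v_t[j] + S[i,j]\bigr),
\qquad
S[i,j] \;\leftarrow\; w_t[i]\,S[i,j] + k_t[i]\,v_t[j]
$$

The `state_prev` selection resets $S$ to the sequence's incoming state at the first token of each sequence, after which the recurrence updates `state_cur` in place.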
@@ -17478,6 +17684,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add_rel_pos(params, tensor);
             } break;
+        case GGML_OP_RWKV_WKV:
+            {
+                ggml_compute_forward_rwkv_wkv(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -18591,12 +18801,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                         zero_table);
                             }
                         } break;
+                    case GGML_UNARY_OP_EXP:
+                        {
+                            if (src0->grad) {
+                                src0->grad = ggml_add_or_set(ctx,
+                                        src0->grad,
+                                        ggml_mul(ctx, tensor, tensor->grad),
+                                        zero_table);
+                            }
+                        } break;
                     default:
                         GGML_ABORT("fatal error");
                 }
             } break;
         case GGML_OP_GET_REL_POS:
         case GGML_OP_ADD_REL_POS:
+        case GGML_OP_RWKV_WKV:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
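
The EXP backward reuses the forward output rather than recomputing the exponential: since $y = e^{x}$, the derivative is $e^{x} = y$ itself, so the chain rule reduces to an elementwise multiply with the saved result — exactly the `ggml_mul(ctx, tensor, tensor->grad)` above:

$$
y = e^{x} \quad\Rightarrow\quad \frac{\partial y}{\partial x} = e^{x} = y,
\qquad
\nabla_{x} L = y \odot \nabla_{y} L
$$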
@@ -19021,6 +19241,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_HARDSIGMOID:
+                case GGML_UNARY_OP_EXP:
                     {
                         n_tasks = 1;
                     } break;
@@ -19112,6 +19333,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
+        case GGML_OP_RWKV_WKV:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32: