
Commit c032b8c

MollySophia, LaylBongers, compilade, and ggerganov committed
llama : support RWKV v6 models (llama/8980)
* convert_hf_to_gguf: Add support for RWKV v6 Signed-off-by: Molly Sophia <[email protected]>
* Add RWKV tokenization
* Fix build Signed-off-by: Molly Sophia <[email protected]>
* Do not use special tokens when matching in RWKV tokenizer
* Fix model loading
* Add (broken) placeholder graph builder for RWKV
* Add workaround for kv cache
* Add logits conversion to rwkv5
* Add rwkv5 layer norms
* Add time mix KVRG & correct merge mistake
* Add remaining time mix parameters
* Add time mix output loading
* Add placeholder llm_build_time_mix
* Fix build Signed-off-by: Molly Sophia <[email protected]>
* Load more tensors for rwkv v6 Signed-off-by: Molly Sophia <[email protected]>
* Fix rwkv tokenizer Signed-off-by: Molly Sophia <[email protected]>
* ggml: Add unary operator Exp Signed-off-by: Molly Sophia <[email protected]>
* RWKV v6 graph building Signed-off-by: Molly Sophia <[email protected]>
* Add ``rescale_every_n_layers`` parameter Signed-off-by: Molly Sophia <[email protected]>
* Add ``wkv.head_size`` key for RWKV so it doesn't reuse Mamba ssm parameters Signed-off-by: Molly Sophia <[email protected]>
* Fix offloading layers to CUDA Signed-off-by: Molly Sophia <[email protected]>
* Fix parallel inferencing for RWKV Signed-off-by: Molly Sophia <[email protected]>
* Remove trailing whitespaces Signed-off-by: Molly Sophia <[email protected]>
* build_rwkv: Avoid using inplace operations Signed-off-by: Molly Sophia <[email protected]>
* convert_hf_to_gguf: rwkv: Avoid using ``eval`` Signed-off-by: Molly Sophia <[email protected]>
* convert_hf_to_gguf: rwkv tokenizer: Don't escape sequences manually Signed-off-by: Molly Sophia <[email protected]>
* Update convert_hf_to_gguf.py Co-authored-by: compilade <[email protected]>
* ggml: Add backward computation for unary op ``exp`` Signed-off-by: Molly Sophia <[email protected]>
* Update convert_hf_to_gguf.py Co-authored-by: compilade <[email protected]>
* Update convert_hf_to_gguf.py Co-authored-by: compilade <[email protected]>
* Use MODEL_ARCH.RWKV6 instead of MODEL_ARCH.RWKV Signed-off-by: Molly Sophia <[email protected]>
* build_rwkv6: Simplify graph Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Detect model.type Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Fix tensor loading for 7B/14B models Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Fix group_norm assertion failure with Metal Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Clean up Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Add quantization tensor exclusion Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Use the new advanced batch splits Signed-off-by: Molly Sophia <[email protected]>
* Update src/llama.cpp Co-authored-by: compilade <[email protected]>
* llama: rwkv6: Use ``ggml_norm`` instead of ``ggml_group_norm`` Co-authored-by: compilade <[email protected]>
* llama: rwkv6: Apply code style and misc changes Signed-off-by: Molly Sophia <[email protected]>
* converter: Use class name ``Rwkv6Model`` Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Make use of key ``feed_forward_length`` Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Add kv ``time_mix_extra_dim`` and ``time_decay_extra_dim`` Signed-off-by: Molly Sophia <[email protected]>
* converter: Match ``new_name`` instead of ``name`` for float32 explicit tensors Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Keep ``time_mix_w1/w2`` as F32 Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Remove unused nodes Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Apply code format changes Signed-off-by: Molly Sophia <[email protected]>
* llama: rwkv6: Add lora for some supported tensors (currently att.key/receptance/value/gate/output, ffn.receptance/key/value, as well as head.weight) Signed-off-by: Molly Sophia <[email protected]>
* rwkv : speed-up tokenization using trie
* minor : style + indentation
* llama: rwkv6: Avoid division by zero Co-authored-by: compilade <[email protected]>
* ggml: rwkv_wkv: Avoid copying the state Signed-off-by: Molly Sophia <[email protected]>

---------

Signed-off-by: Molly Sophia <[email protected]>
Co-authored-by: Layl Bongers <[email protected]>
Co-authored-by: compilade <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent c584042 commit c032b8c

2 files changed: +244 lines added, -3 lines removed


include/ggml.h

Lines changed: 19 additions & 0 deletions
@@ -514,6 +514,7 @@ extern "C" {
         GGML_OP_WIN_UNPART,
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
+        GGML_OP_RWKV_WKV,
 
         GGML_OP_UNARY,
 
@@ -548,6 +549,7 @@ extern "C" {
         GGML_UNARY_OP_SILU,
         GGML_UNARY_OP_HARDSWISH,
         GGML_UNARY_OP_HARDSIGMOID,
+        GGML_UNARY_OP_EXP,
 
         GGML_UNARY_OP_COUNT,
     };
@@ -1165,6 +1167,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_exp(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_exp_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
@@ -1913,6 +1923,15 @@ extern "C" {
             struct ggml_tensor  * pw,
             struct ggml_tensor  * ph);
 
+    GGML_API struct ggml_tensor * ggml_rwkv_wkv(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * r,
+            struct ggml_tensor  * tf,
+            struct ggml_tensor  * td,
+            struct ggml_tensor  * state);
+
     // custom operators
 
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
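
For orientation (not part of this commit), a minimal sketch of how the new GGML_UNARY_OP_EXP operator could be exercised through this API; the context size, thread count, and printed values are illustrative assumptions:

// Illustrative sketch only: compute y = exp(x) with the new unary op.
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,   // 16 MiB arena, assumed to be plenty for this toy graph
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    for (int i = 0; i < 4; ++i) ((float *) x->data)[i] = (float) i;      // 0, 1, 2, 3

    struct ggml_tensor * y = ggml_exp(ctx, x);                           // the new operator

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    ggml_graph_compute_with_ctx(ctx, gf, 1);                             // EXP runs with n_tasks = 1

    for (int i = 0; i < 4; ++i) printf("%f\n", ggml_get_data_f32(y)[i]); // ~1, 2.718, 7.389, 20.086

    ggml_free(ctx);
    return 0;
}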

src/ggml.c

Lines changed: 225 additions & 3 deletions
@@ -2422,6 +2422,7 @@ inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x
 // TODO: optimize performance
 inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
 inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
+inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
 
 static const float GELU_COEF_A     = 0.044715f;
 static const float GELU_QUICK_COEF = -1.702f;
@@ -2932,6 +2933,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "WIN_UNPART",
     "GET_REL_POS",
     "ADD_REL_POS",
+    "RWKV_WKV",
 
     "UNARY",
 
@@ -2950,7 +2952,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3024,6 +3026,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "win_unpart(x)",
     "get_rel_pos(x)",
     "add_rel_pos(x)",
+    "rwkv_wkv(k, v, r, tf, td, s)",
 
     "unary(x)",
 
@@ -3042,7 +3045,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -3061,9 +3064,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "SILU",
     "HARDSWISH",
     "HARDSIGMOID",
+    "EXP",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
+static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -5466,6 +5470,19 @@ struct ggml_tensor * ggml_hardsigmoid(
     return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID);
 }
 
+// ggml exp
+struct ggml_tensor * ggml_exp(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_EXP);
+}
+
+struct ggml_tensor * ggml_exp_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
+}
+
 // ggml_norm
 
 static struct ggml_tensor * ggml_norm_impl(
@@ -7734,6 +7751,59 @@ struct ggml_tensor * ggml_add_rel_pos_inplace(
     return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
 }
 
+// ggml_rwkv_wkv
+
+struct ggml_tensor * ggml_rwkv_wkv(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * r,
+        struct ggml_tensor  * tf,
+        struct ggml_tensor  * td,
+        struct ggml_tensor  * state) {
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(r));
+    GGML_ASSERT(ggml_is_contiguous(tf));
+    GGML_ASSERT(ggml_is_contiguous(td));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[2];
+    const int64_t n_tokens = k->ne[3];
+    const int64_t n_seqs = state->ne[1];
+    {
+        GGML_ASSERT(k->ne[1] == 1);
+        GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
+        GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
+        // TODO: RWKV v4 and v5
+        GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    }
+
+    bool is_node = false;
+
+    if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
+        GGML_ABORT("fatal error"); // TODO: implement backward
+        is_node = true;
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op     = GGML_OP_RWKV_WKV;
+    result->grad   = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = k;
+    result->src[1] = v;
+    result->src[2] = r;
+    result->src[3] = tf;
+    result->src[4] = td;
+    result->src[5] = state;
+
+    return result;
+}
+
 // ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
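
As a reading aid (not part of the diff): the asserts above imply k has ne = {S, 1, H, n_tokens}, v/r/td have ne = {1, S, H, n_tokens}, and state carries S*S*H values per sequence, while the result packs the per-token outputs and the updated state into one tensor of ne = {S*H, n_tokens + S*n_seqs}. A hedged sketch of how a caller might split that result, assuming the view offsets below match this layout:

// Illustrative sketch only: split the concatenated ggml_rwkv_wkv result.
#include "ggml.h"

static void split_rwkv_wkv_result(
        struct ggml_context * ctx,
        struct ggml_tensor  * wkv_out,      // result of ggml_rwkv_wkv, ne = { S*H, n_tokens + S*n_seqs, 1, 1 }
        int64_t S, int64_t H, int64_t n_tokens, int64_t n_seqs,
        struct ggml_tensor ** out_tokens,   // per-token output, C = S*H channels per token
        struct ggml_tensor ** out_state) {  // updated wkv state, S*S*H values per sequence
    // first n_tokens rows: the per-token output
    *out_tokens = ggml_view_2d(ctx, wkv_out, S * H, n_tokens, wkv_out->nb[1], 0);

    // remaining S*n_seqs rows: the new state to be written back into the recurrent cache
    *out_state = ggml_view_1d(ctx, wkv_out, S * S * H * n_seqs,
            (size_t)(S * H * n_tokens) * ggml_element_size(wkv_out));
}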
@@ -12114,6 +12184,48 @@ static void ggml_compute_forward_hardsigmoid(
     }
 }
 
+static void ggml_compute_forward_exp_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_exp_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_exp(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_exp_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 
 // ggml_compute_forward_norm
 
@@ -16692,6 +16804,10 @@ static void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_hardsigmoid(params, dst);
             } break;
+        case GGML_UNARY_OP_EXP:
+            {
+                ggml_compute_forward_exp(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -16827,6 +16943,96 @@ static void ggml_compute_forward_add_rel_pos(
     }
 }
 
+// ggml_compute_forward_rwkv_wkv
+
+static void ggml_compute_forward_rwkv_wkv_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const size_t T = dst->src[1]->ne[3];
+    const size_t C = dst->ne[0];
+    const size_t H = dst->src[1]->ne[2];
+    const size_t n_seqs = dst->src[5]->ne[1];
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    memset(dst_data, 0, T * C * sizeof(float));
+
+    float * k =          (float *) dst->src[0]->data;
+    float * v =          (float *) dst->src[1]->data;
+    float * r =          (float *) dst->src[2]->data;
+    float * time_faaaa = (float *) dst->src[3]->data;
+    float * time_decay = (float *) dst->src[4]->data;
+
+    size_t t_stride = H * (C / H);
+
+    size_t h_stride = C / H;
+    size_t h_stride_2d = (C / H) * (C / H);
+
+    // basically fused operations:
+    // dst = r @ (time_faaaa * (k @ v) + state),
+    // state = time_decay * state + (k @ v),
+    // recursive through each token
+    for (size_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+        size_t state_offset = (C / H) * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+        for (size_t h = 0; h < H; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (size_t i = 0; i < C / H; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_i_offset = h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float r_val = r[t_h_i_offset];
+                float time_faaaa_val = time_faaaa[h_i_offset];
+                // RWKV v6: different time_decay for each token.
+                float time_decay_val = time_decay[t_h_i_offset];
+
+                for (size_t j = 0; j < C / H; j ++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                    dst_data[t_h_j_offset] += temp_val * r_val;
+                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rwkv_wkv(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
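
Spelled out (a restatement of the fused-operation comment above, not new semantics): per head and per token t, writing u for time_faaaa and w_t for the per-token time_decay of RWKV v6, the loops compute

    y_t[j]   = \sum_i r_t[i] \, ( u[i] \, k_t[i] \, v_t[j] + S_{t-1}[i,j] )
    S_t[i,j] = w_t[i] \, S_{t-1}[i,j] + k_t[i] \, v_t[j]

where i and j run over the head dimension C/H, S_{t-1} is read from dst->src[5] at the first token of each sequence (t % (T / n_seqs) == 0), and the state is carried in-place for the remaining tokens of that sequence.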
@@ -17478,6 +17684,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add_rel_pos(params, tensor);
             } break;
+        case GGML_OP_RWKV_WKV:
+            {
+                ggml_compute_forward_rwkv_wkv(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -18591,12 +18801,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                     zero_table);
                             }
                         } break;
+                    case GGML_UNARY_OP_EXP:
+                        {
+                            if (src0->grad) {
+                                src0->grad = ggml_add_or_set(ctx,
+                                        src0->grad,
+                                        ggml_mul(ctx, tensor, tensor->grad),
+                                        zero_table);
+                            }
+                        } break;
                     default:
                         GGML_ABORT("fatal error");
                 }
             } break;
         case GGML_OP_GET_REL_POS:
         case GGML_OP_ADD_REL_POS:
+        case GGML_OP_RWKV_WKV:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
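
For reference (a restatement, not part of the diff): the new GGML_UNARY_OP_EXP backward case follows from d/dx exp(x) = exp(x). Since tensor already holds y = exp(src0), the accumulation above is

    \frac{\partial L}{\partial x} \mathrel{+}= y \cdot \frac{\partial L}{\partial y}

which is exactly the product ggml_mul(ctx, tensor, tensor->grad) fed into ggml_add_or_set.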
@@ -19021,6 +19241,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_HARDSIGMOID:
+                case GGML_UNARY_OP_EXP:
                     {
                         n_tasks = 1;
                     } break;
@@ -19112,6 +19333,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
+        case GGML_OP_RWKV_WKV:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
