Skip to content

Commit f9c84ce

Browse files
committed
llama : fix rwkv inference
Signed-off-by: Molly Sophia <[email protected]>
1 parent 74b0807 commit f9c84ce

File tree

3 files changed

+429
-368
lines changed

3 files changed

+429
-368
lines changed

src/llama-context.cpp

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1970,6 +1970,229 @@ ggml_tensor * llama_context::build_mamba_layer(
19701970
}
19711971

19721972

1973+
// Load the per-sequence RWKV token-shift states for layer `il` out of the
// recurrent KV cache (stored in the K tensor) and expose them to the layer
// graph shaped as [n_embd, token_shift_count, n_seqs].
ggml_tensor * llama_context::build_rwkv_token_shift_load(
        ggml_context * ctx0,
        ggml_cgraph * graph,
        ggml_tensor * state_copy,
        ggml_tensor * state_mask,
        const llama_ubatch & ubatch,
        int il,
        bool worst_case) {
    const auto & hparams = model.hparams;

    const int64_t n_seqs = ubatch.n_seqs;

    // gather/clear the relevant state rows for the sequences in this ubatch;
    // token-shift state lives in kv_self.k_l for recurrent models
    struct ggml_tensor * shift_states = build_copy_mask_state(
            ctx0, graph, kv_self.k_l[il], state_copy, state_mask,
            ubatch.n_tokens, hparams.n_embd_k_s(), n_seqs, worst_case);

    // one [n_embd x token_shift_count] slab per sequence
    return ggml_reshape_3d(ctx0, shift_states, hparams.n_embd, hparams.token_shift_count, n_seqs);
}
1998+
1999+
2000+
// Write the updated RWKV token-shift states for layer `il` back into the
// recurrent KV cache (K tensor) at the current cache head. Returns the copy
// op so the caller can add it to the graph.
ggml_tensor * llama_context::build_rwkv_token_shift_store(
        ggml_context * ctx0,
        ggml_tensor * token_shift,
        const llama_ubatch & ubatch,
        int il,
        bool worst_case) {
    const auto & hparams = model.hparams;

    const auto & n_tokens = ubatch.n_tokens;
    const int64_t n_seqs  = ubatch.n_seqs;

    // during worst-case graph reservation the head is pinned at 0 for
    // recurrent caches, otherwise placed at the end of the cache
    const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;

    // flatten the [n_embd, token_shift_count, n_seqs] state tensor ...
    struct ggml_tensor * src = ggml_view_1d(
            ctx0, token_shift,
            hparams.n_embd * n_seqs * hparams.token_shift_count, 0);

    // ... and copy it over the per-layer state slots starting at kv_head
    struct ggml_tensor * dst = ggml_view_1d(
            ctx0, kv_self.k_l[il],
            hparams.n_embd_k_s() * n_seqs,
            hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]));

    return ggml_cpy(ctx0, src, dst);
}
2022+
2023+
2024+
// Build the RWKV6 time-mix (attention-equivalent) block for layer `il`.
//
// Inputs:
//   cur     - current hidden states            (presumably [n_embd, n_tokens] — flattened below)
//   x_prev  - token-shifted hidden states, same shape as cur
//   state_copy/state_mask - recurrent-state bookkeeping tensors
// Returns the time-mix output projected back to [n_embd, n_tokens].
//
// Handles both regular RWKV6 and QRWKV6 (Qwen2-RWKV hybrid); the latter is
// detected by the absence of the `time_mix_first` tensor.
ggml_tensor * llama_context::build_rwkv6_time_mix(
        ggml_context * ctx0,
        ggml_cgraph * graph,
        ggml_tensor * cur,
        ggml_tensor * x_prev,
        ggml_tensor * state_copy,
        ggml_tensor * state_mask,
        const llama_ubatch & ubatch,
        int il,
        bool worst_case) {
    const auto & hparams = model.hparams;

    const auto n_tokens  = ubatch.n_tokens;
    const auto n_seqs    = ubatch.n_seqs;
    const auto n_embd    = hparams.n_embd;
    const auto head_size = hparams.wkv_head_size;
    const auto n_head    = n_embd / head_size;
    const auto n_head_kv = hparams.n_head_kv(il);

    // worst-case graph reservation pins the head (recurrent) or places it at
    // the end of the cache, mirroring build_rwkv_token_shift_store
    const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;

    const auto layer = &model.layers[il];

    // QRWKV models have no "first" (bonus) tensor and use gated linear attention
    bool is_qrwkv = layer->time_mix_first == nullptr;

    // delta between the shifted and current hidden states; all the x{w,k,v,r,g}
    // inputs are lerps between cur and x_prev along this delta
    struct ggml_tensor * sx  = ggml_sub(ctx0, x_prev, cur);
    struct ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->time_mix_lerp_x), cur);

    // data-dependent lerp coefficients: low-rank projection w1 -> tanh -> w2,
    // producing 5 sets (w, k, v, r, g) of per-token mixing weights
    xxx = ggml_reshape_4d(
        ctx0,
        ggml_tanh(
            ctx0,
            ggml_mul_mat(ctx0, layer->time_mix_w1, xxx)
        ),
        layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
    );

    xxx = ggml_cont(ctx0, ggml_permute(ctx0, xxx, 0, 1, 3, 2));

    xxx = ggml_mul_mat(
        ctx0,
        ggml_reshape_4d(
            ctx0,
            layer->time_mix_w2,
            layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5
        ),
        xxx
    );

    struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
    if (layer->time_mix_lerp_fused) {
        // fusing these weights makes some performance improvement
        sx  = ggml_reshape_3d(ctx0, sx, n_embd, 1, n_tokens);
        cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
        xxx = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xxx, layer->time_mix_lerp_fused), sx), cur);
        // slice the fused result into the five mixed inputs
        xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
        xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
        xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
        xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
        xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
    } else {
        // for backward compatibility: apply each lerp tensor separately
        xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
        xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
        xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
        xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
        xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));

        xw = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xw, layer->time_mix_lerp_w), sx), cur);
        xk = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xk, layer->time_mix_lerp_k), sx), cur);
        xv = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xv, layer->time_mix_lerp_v), sx), cur);
        xr = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xr, layer->time_mix_lerp_r), sx), cur);
        xg = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xg, layer->time_mix_lerp_g), sx), cur);
    }

    // receptance / key / value projections (+ optional biases, QRWKV)
    struct ggml_tensor * r = build_lora_mm(ctx0, layer->time_mix_receptance, xr);
    struct ggml_tensor * k = build_lora_mm(ctx0, layer->time_mix_key,        xk);
    struct ggml_tensor * v = build_lora_mm(ctx0, layer->time_mix_value,      xv);
    if (layer->time_mix_receptance_b) {
        r = ggml_add(ctx0, r, layer->time_mix_receptance_b);
    }
    if (layer->time_mix_key_b) {
        k = ggml_add(ctx0, k, layer->time_mix_key_b);
    }
    if (layer->time_mix_value_b) {
        v = ggml_add(ctx0, v, layer->time_mix_value_b);
    }

    // output gate: sigmoid for QRWKV, silu for regular RWKV6
    struct ggml_tensor * g = build_lora_mm(ctx0, layer->time_mix_gate, xg);
    if (is_qrwkv) {
        g = ggml_sigmoid(ctx0, g);
    } else {
        g = ggml_silu(ctx0, g);
    }

    // GQA-style head broadcast: repeat k/v heads up to n_head if needed
    if (n_head_kv != 0 && n_head_kv != n_head) {
        GGML_ASSERT(n_head % n_head_kv == 0);
        k = ggml_reshape_4d(ctx0, k, head_size, 1, n_head_kv, n_tokens);
        v = ggml_reshape_4d(ctx0, v, head_size, 1, n_head_kv, n_tokens);
        struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_size, n_head / n_head_kv, n_head_kv, n_tokens);
        k = ggml_repeat(ctx0, k, tmp);
        v = ggml_repeat(ctx0, v, tmp);
    }

    k = ggml_reshape_3d(ctx0, k, head_size, n_head, n_tokens);
    v = ggml_reshape_3d(ctx0, v, head_size, n_head, n_tokens);
    r = ggml_reshape_3d(ctx0, r, head_size, n_head, n_tokens);

    // data-dependent decay: low-rank w1 -> tanh -> w2, then
    // w = exp(-exp(w + decay)) per the RWKV6 formulation
    struct ggml_tensor * w = ggml_mul_mat(
        ctx0,
        layer->time_mix_decay_w2,
        ggml_tanh(
            ctx0,
            ggml_mul_mat(ctx0, layer->time_mix_decay_w1, xw)
        )
    );

    w = ggml_add(ctx0, w, layer->time_mix_decay);
    w = ggml_exp(ctx0, ggml_neg(ctx0, ggml_exp(ctx0, w)));
    w = ggml_reshape_3d(ctx0, w, head_size, n_head, n_tokens);

    if (is_qrwkv) {
        // k = k * (1 - w)
        k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
    }

    // per-sequence wkv recurrent state lives in the V cache
    struct ggml_tensor * wkv_state = build_copy_mask_state(
            ctx0, graph, kv_self.v_l[il], state_copy, state_mask,
            n_tokens, hparams.n_embd_v_s(), n_seqs, worst_case);

    struct ggml_tensor * wkv_output;
    if (is_qrwkv) {
        wkv_output = ggml_gated_linear_attn(ctx0, k, v, r, w, wkv_state, pow(head_size, -0.5f));
    } else {
        wkv_output = ggml_rwkv_wkv6(ctx0, k, v, r, layer->time_mix_first, w, wkv_state);
    }
    // the op packs [outputs | final state] in one tensor; split it
    cur       = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
    wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));

    // persist the updated wkv state back into the cache at kv_head
    ggml_build_forward_expand(
        graph,
        ggml_cpy(
            ctx0,
            wkv_state,
            ggml_view_1d(
                ctx0,
                kv_self.v_l[il],
                hparams.n_embd_v_s() * n_seqs,
                hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
            )
        )
    );

    if (!is_qrwkv) {
        // group norm with head_count groups
        cur = ggml_reshape_3d(ctx0, cur, n_embd / n_head, n_head, n_tokens);
        cur = ggml_norm(ctx0, cur, 64e-5f);

        // Convert back to regular vectors.
        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
        cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer->time_mix_ln), layer->time_mix_ln_b);
    } else {
        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
    }

    cur = ggml_mul(ctx0, cur, g);
    cur = build_lora_mm(ctx0, layer->time_mix_output, cur);

    return cur;
}
2195+
19732196
// llama output
19742197

19752198
size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {

src/llama-context.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,33 @@ struct llama_context {
248248
int il,
249249
bool worst_case);
250250

251+
// load the RWKV token-shift states for layer `il` from the recurrent KV
// cache, shaped [n_embd, token_shift_count, n_seqs]
ggml_tensor * build_rwkv_token_shift_load(
        ggml_context * ctx0,
        ggml_cgraph * graph,
        ggml_tensor * state_copy,
        ggml_tensor * state_mask,
        const llama_ubatch & ubatch,
        int il,
        bool worst_case);

// store the updated token-shift states back into the recurrent KV cache;
// returns the copy op to be added to the graph by the caller
ggml_tensor * build_rwkv_token_shift_store(
        ggml_context * ctx0,
        ggml_tensor * token_shift,
        const llama_ubatch & ubatch,
        int il,
        bool worst_case);

// build the RWKV6 time-mix block (handles both RWKV6 and QRWKV6 variants)
ggml_tensor * build_rwkv6_time_mix(
        ggml_context * ctx0,
        ggml_cgraph * graph,
        ggml_tensor * cur,
        ggml_tensor * x_prev,
        ggml_tensor * state_copy,
        ggml_tensor * state_mask,
        const llama_ubatch & ubatch,
        int il,
        bool worst_case);
277+
251278
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
252279
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
253280

0 commit comments

Comments
 (0)