Skip to content

Commit 1eca891

Browse files
authored
llama : fix rwkv inference (#11618)
Signed-off-by: Molly Sophia <[email protected]>
1 parent 74b0807 commit 1eca891

File tree

3 files changed

+428
-368
lines changed

3 files changed

+428
-368
lines changed

src/llama-context.cpp

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1970,6 +1970,228 @@ ggml_tensor * llama_context::build_mamba_layer(
19701970
}
19711971

19721972

1973+
ggml_tensor * llama_context::build_rwkv_token_shift_load(
1974+
ggml_context * ctx0,
1975+
ggml_cgraph * graph,
1976+
ggml_tensor * state_copy,
1977+
ggml_tensor * state_mask,
1978+
const llama_ubatch & ubatch,
1979+
int il,
1980+
bool worst_case) {
1981+
const auto & hparams = model.hparams;
1982+
1983+
const auto token_shift_count = hparams.token_shift_count;
1984+
1985+
const auto & n_tokens = ubatch.n_tokens;
1986+
const int64_t n_seqs = ubatch.n_seqs;
1987+
1988+
struct ggml_tensor * token_shift_all = kv_self.k_l[il];
1989+
1990+
struct ggml_tensor * token_shift = build_copy_mask_state(
1991+
ctx0, graph, token_shift_all, state_copy, state_mask,
1992+
n_tokens, hparams.n_embd_k_s(), n_seqs, worst_case);
1993+
1994+
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
1995+
1996+
return token_shift;
1997+
}
1998+
1999+
2000+
ggml_tensor * llama_context::build_rwkv_token_shift_store(
2001+
ggml_context * ctx0,
2002+
ggml_tensor * token_shift,
2003+
const llama_ubatch & ubatch,
2004+
int il,
2005+
bool worst_case) {
2006+
const auto & hparams = model.hparams;
2007+
2008+
const auto token_shift_count = hparams.token_shift_count;
2009+
const auto n_embd = hparams.n_embd;
2010+
2011+
const auto & n_tokens = ubatch.n_tokens;
2012+
const int64_t n_seqs = ubatch.n_seqs;
2013+
2014+
const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
2015+
2016+
return ggml_cpy(
2017+
ctx0,
2018+
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
2019+
ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
2020+
);
2021+
}
2022+
2023+
2024+
ggml_tensor * llama_context::build_rwkv6_time_mix(
2025+
ggml_context * ctx0,
2026+
ggml_cgraph * graph,
2027+
ggml_tensor * cur,
2028+
ggml_tensor * x_prev,
2029+
ggml_tensor * state_copy,
2030+
ggml_tensor * state_mask,
2031+
const llama_ubatch & ubatch,
2032+
int il,
2033+
bool worst_case) {
2034+
const auto & hparams = model.hparams;
2035+
2036+
const auto n_tokens = ubatch.n_tokens;
2037+
const auto n_seqs = ubatch.n_seqs;
2038+
const auto n_embd = hparams.n_embd;
2039+
const auto head_size = hparams.wkv_head_size;
2040+
const auto n_head = n_embd / head_size;
2041+
const auto n_head_kv = hparams.n_head_kv(il);
2042+
2043+
const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
2044+
2045+
const auto layer = &model.layers[il];
2046+
2047+
bool is_qrwkv = layer->time_mix_first == nullptr;
2048+
2049+
struct ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
2050+
struct ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->time_mix_lerp_x), cur);
2051+
2052+
xxx = ggml_reshape_4d(
2053+
ctx0,
2054+
ggml_tanh(
2055+
ctx0,
2056+
ggml_mul_mat(ctx0, layer->time_mix_w1, xxx)
2057+
),
2058+
layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
2059+
);
2060+
2061+
xxx = ggml_cont(ctx0, ggml_permute(ctx0, xxx, 0, 1, 3, 2));
2062+
2063+
xxx = ggml_mul_mat(
2064+
ctx0,
2065+
ggml_reshape_4d(
2066+
ctx0,
2067+
layer->time_mix_w2,
2068+
layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5
2069+
),
2070+
xxx
2071+
);
2072+
2073+
struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
2074+
if (layer->time_mix_lerp_fused) {
2075+
// fusing these weights makes some performance improvement
2076+
sx = ggml_reshape_3d(ctx0, sx, n_embd, 1, n_tokens);
2077+
cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
2078+
xxx = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xxx, layer->time_mix_lerp_fused), sx), cur);
2079+
xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
2080+
xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
2081+
xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
2082+
xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
2083+
xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
2084+
} else {
2085+
// for backward compatibility
2086+
xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
2087+
xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
2088+
xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
2089+
xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
2090+
xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
2091+
2092+
xw = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xw, layer->time_mix_lerp_w), sx), cur);
2093+
xk = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xk, layer->time_mix_lerp_k), sx), cur);
2094+
xv = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xv, layer->time_mix_lerp_v), sx), cur);
2095+
xr = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xr, layer->time_mix_lerp_r), sx), cur);
2096+
xg = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xg, layer->time_mix_lerp_g), sx), cur);
2097+
}
2098+
2099+
struct ggml_tensor * r = build_lora_mm(ctx0, layer->time_mix_receptance, xr);
2100+
struct ggml_tensor * k = build_lora_mm(ctx0, layer->time_mix_key, xk);
2101+
struct ggml_tensor * v = build_lora_mm(ctx0, layer->time_mix_value, xv);
2102+
if (layer->time_mix_receptance_b) {
2103+
r = ggml_add(ctx0, r, layer->time_mix_receptance_b);
2104+
}
2105+
if (layer->time_mix_key_b) {
2106+
k = ggml_add(ctx0, k, layer->time_mix_key_b);
2107+
}
2108+
if (layer->time_mix_value_b) {
2109+
v = ggml_add(ctx0, v, layer->time_mix_value_b);
2110+
}
2111+
2112+
struct ggml_tensor * g = build_lora_mm(ctx0, layer->time_mix_gate, xg);
2113+
if (is_qrwkv) {
2114+
g = ggml_sigmoid(ctx0, g);
2115+
} else {
2116+
g = ggml_silu(ctx0, g);
2117+
}
2118+
2119+
if (n_head_kv != 0 && n_head_kv != n_head) {
2120+
GGML_ASSERT(n_head % n_head_kv == 0);
2121+
k = ggml_reshape_4d(ctx0, k, head_size, 1, n_head_kv, n_tokens);
2122+
v = ggml_reshape_4d(ctx0, v, head_size, 1, n_head_kv, n_tokens);
2123+
struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_size, n_head / n_head_kv, n_head_kv, n_tokens);
2124+
k = ggml_repeat(ctx0, k, tmp);
2125+
v = ggml_repeat(ctx0, v, tmp);
2126+
}
2127+
2128+
k = ggml_reshape_3d(ctx0, k, head_size, n_head, n_tokens);
2129+
v = ggml_reshape_3d(ctx0, v, head_size, n_head, n_tokens);
2130+
r = ggml_reshape_3d(ctx0, r, head_size, n_head, n_tokens);
2131+
2132+
struct ggml_tensor * w = ggml_mul_mat(
2133+
ctx0,
2134+
layer->time_mix_decay_w2,
2135+
ggml_tanh(
2136+
ctx0,
2137+
ggml_mul_mat(ctx0, layer->time_mix_decay_w1, xw)
2138+
)
2139+
);
2140+
2141+
w = ggml_add(ctx0, w, layer->time_mix_decay);
2142+
w = ggml_exp(ctx0, ggml_neg(ctx0, ggml_exp(ctx0, w)));
2143+
w = ggml_reshape_3d(ctx0, w, head_size, n_head, n_tokens);
2144+
2145+
if (is_qrwkv) {
2146+
// k = k * (1 - w)
2147+
k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
2148+
}
2149+
2150+
struct ggml_tensor * wkv_state = build_copy_mask_state(
2151+
ctx0, graph, kv_self.v_l[il], state_copy, state_mask,
2152+
n_tokens, hparams.n_embd_v_s(), n_seqs, worst_case);
2153+
2154+
struct ggml_tensor * wkv_output;
2155+
if (is_qrwkv) {
2156+
wkv_output = ggml_gated_linear_attn(ctx0, k, v, r, w, wkv_state, pow(head_size, -0.5f));
2157+
} else {
2158+
wkv_output = ggml_rwkv_wkv6(ctx0, k, v, r, layer->time_mix_first, w, wkv_state);
2159+
}
2160+
cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
2161+
wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
2162+
2163+
ggml_build_forward_expand(
2164+
graph,
2165+
ggml_cpy(
2166+
ctx0,
2167+
wkv_state,
2168+
ggml_view_1d(
2169+
ctx0,
2170+
kv_self.v_l[il],
2171+
hparams.n_embd_v_s() * n_seqs,
2172+
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
2173+
)
2174+
)
2175+
);
2176+
2177+
if (!is_qrwkv) {
2178+
// group norm with head_count groups
2179+
cur = ggml_reshape_3d(ctx0, cur, n_embd / n_head, n_head, n_tokens);
2180+
cur = ggml_norm(ctx0, cur, 64e-5f);
2181+
2182+
// Convert back to regular vectors.
2183+
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
2184+
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer->time_mix_ln), layer->time_mix_ln_b);
2185+
} else {
2186+
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
2187+
}
2188+
2189+
cur = ggml_mul(ctx0, cur, g);
2190+
cur = build_lora_mm(ctx0, layer->time_mix_output, cur);
2191+
2192+
return cur;
2193+
}
2194+
19732195
// llama output
19742196

19752197
size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {

src/llama-context.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,33 @@ struct llama_context {
248248
int il,
249249
bool worst_case);
250250

251+
ggml_tensor * build_rwkv_token_shift_load(
252+
ggml_context * ctx0,
253+
ggml_cgraph * graph,
254+
ggml_tensor * state_copy,
255+
ggml_tensor * state_mask,
256+
const llama_ubatch & ubatch,
257+
int il,
258+
bool worst_case);
259+
260+
ggml_tensor * build_rwkv_token_shift_store(
261+
ggml_context * ctx0,
262+
ggml_tensor * token_shift,
263+
const llama_ubatch & ubatch,
264+
int il,
265+
bool worst_case);
266+
267+
ggml_tensor * build_rwkv6_time_mix(
268+
ggml_context * ctx0,
269+
ggml_cgraph * graph,
270+
ggml_tensor * cur,
271+
ggml_tensor * x_prev,
272+
ggml_tensor * state_copy,
273+
ggml_tensor * state_mask,
274+
const llama_ubatch & ubatch,
275+
int il,
276+
bool worst_case);
277+
251278
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
252279
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
253280

0 commit comments

Comments
 (0)