@@ -1970,6 +1970,228 @@ ggml_tensor * llama_context::build_mamba_layer(
 }


+ggml_tensor * llama_context::build_rwkv_token_shift_load(
+        ggml_context * ctx0,
+        ggml_cgraph * graph,
+        ggml_tensor * state_copy,
+        ggml_tensor * state_mask,
+        const llama_ubatch & ubatch,
+        int il,
+        bool worst_case) {
+    const auto & hparams = model.hparams;
+
+    const auto token_shift_count = hparams.token_shift_count;
+
+    const auto & n_tokens = ubatch.n_tokens;
+    const int64_t n_seqs  = ubatch.n_seqs;
+
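+    // the token shift states of the recurrent (RWKV) layers are kept in the K tensors of the KV cache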
+    struct ggml_tensor * token_shift_all = kv_self.k_l[il];
+
+    struct ggml_tensor * token_shift = build_copy_mask_state(
+            ctx0, graph, token_shift_all, state_copy, state_mask,
+            n_tokens, hparams.n_embd_k_s(), n_seqs, worst_case);
+
+    token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
+
+    return token_shift;
+}
+
+
+ggml_tensor * llama_context::build_rwkv_token_shift_store(
+        ggml_context * ctx0,
+        ggml_tensor * token_shift,
+        const llama_ubatch & ubatch,
+        int il,
+        bool worst_case) {
+    const auto & hparams = model.hparams;
+
+    const auto token_shift_count = hparams.token_shift_count;
+    const auto n_embd = hparams.n_embd;
+
+    const auto & n_tokens = ubatch.n_tokens;
+    const int64_t n_seqs  = ubatch.n_seqs;
+
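+    // when reserving the worst-case graph the real head is not known yet:
+    // recurrent caches write at slot 0, non-recurrent ones at the end of the cache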
+    const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
+
+    return ggml_cpy(
+        ctx0,
+        ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
+        ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
+    );
+}
+
+
+ggml_tensor * llama_context::build_rwkv6_time_mix(
+        ggml_context * ctx0,
+        ggml_cgraph * graph,
+        ggml_tensor * cur,
+        ggml_tensor * x_prev,
+        ggml_tensor * state_copy,
+        ggml_tensor * state_mask,
+        const llama_ubatch & ubatch,
+        int il,
+        bool worst_case) {
+    const auto & hparams = model.hparams;
+
+    const auto n_tokens   = ubatch.n_tokens;
+    const auto n_seqs     = ubatch.n_seqs;
+    const auto n_embd     = hparams.n_embd;
+    const auto head_size  = hparams.wkv_head_size;
+    const auto n_head     = n_embd / head_size;
+    const auto n_head_kv  = hparams.n_head_kv(il);
+
+    const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
+
+    const auto layer = &model.layers[il];
+
+    bool is_qrwkv = layer->time_mix_first == nullptr;
+
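+    // token shift: sx is the difference between the previous token's activations (x_prev) and the current ones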
+    struct ggml_tensor * sx  = ggml_sub(ctx0, x_prev, cur);
+    struct ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->time_mix_lerp_x), cur);
+
+    xxx = ggml_reshape_4d(
+        ctx0,
+        ggml_tanh(
+            ctx0,
+            ggml_mul_mat(ctx0, layer->time_mix_w1, xxx)
+        ),
+        layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
+    );
+
+    xxx = ggml_cont(ctx0, ggml_permute(ctx0, xxx, 0, 1, 3, 2));
+
+    xxx = ggml_mul_mat(
+        ctx0,
+        ggml_reshape_4d(
+            ctx0,
+            layer->time_mix_w2,
+            layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5
+        ),
+        xxx
+    );
+
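+    // the low-rank time_mix_w1/time_mix_w2 projection above yields 5 per-token interpolation vectors, one each for w, k, v, r and g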
+    struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
+    if (layer->time_mix_lerp_fused) {
+        // fusing these weights gives a small performance improvement
+        sx  = ggml_reshape_3d(ctx0, sx, n_embd, 1, n_tokens);
+        cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
+        xxx = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xxx, layer->time_mix_lerp_fused), sx), cur);
+        xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+    } else {
+        // for backward compatibility
+        xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+
+        xw = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xw, layer->time_mix_lerp_w), sx), cur);
+        xk = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xk, layer->time_mix_lerp_k), sx), cur);
+        xv = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xv, layer->time_mix_lerp_v), sx), cur);
+        xr = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xr, layer->time_mix_lerp_r), sx), cur);
+        xg = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xg, layer->time_mix_lerp_g), sx), cur);
+    }
+
+    struct ggml_tensor * r = build_lora_mm(ctx0, layer->time_mix_receptance, xr);
+    struct ggml_tensor * k = build_lora_mm(ctx0, layer->time_mix_key,        xk);
+    struct ggml_tensor * v = build_lora_mm(ctx0, layer->time_mix_value,      xv);
+    if (layer->time_mix_receptance_b) {
+        r = ggml_add(ctx0, r, layer->time_mix_receptance_b);
+    }
+    if (layer->time_mix_key_b) {
+        k = ggml_add(ctx0, k, layer->time_mix_key_b);
+    }
+    if (layer->time_mix_value_b) {
+        v = ggml_add(ctx0, v, layer->time_mix_value_b);
+    }
+
+    struct ggml_tensor * g = build_lora_mm(ctx0, layer->time_mix_gate, xg);
+    if (is_qrwkv) {
+        g = ggml_sigmoid(ctx0, g);
+    } else {
+        g = ggml_silu(ctx0, g);
+    }
+
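+    // if the model has fewer K/V heads than output heads, broadcast them so that each output head gets a copy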
+    if (n_head_kv != 0 && n_head_kv != n_head) {
+        GGML_ASSERT(n_head % n_head_kv == 0);
+        k = ggml_reshape_4d(ctx0, k, head_size, 1, n_head_kv, n_tokens);
+        v = ggml_reshape_4d(ctx0, v, head_size, 1, n_head_kv, n_tokens);
+        struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_size, n_head / n_head_kv, n_head_kv, n_tokens);
+        k = ggml_repeat(ctx0, k, tmp);
+        v = ggml_repeat(ctx0, v, tmp);
+    }
+
+    k = ggml_reshape_3d(ctx0, k, head_size, n_head, n_tokens);
+    v = ggml_reshape_3d(ctx0, v, head_size, n_head, n_tokens);
+    r = ggml_reshape_3d(ctx0, r, head_size, n_head, n_tokens);
+
+    struct ggml_tensor * w = ggml_mul_mat(
+        ctx0,
+        layer->time_mix_decay_w2,
+        ggml_tanh(
+            ctx0,
+            ggml_mul_mat(ctx0, layer->time_mix_decay_w1, xw)
+        )
+    );
+
+    w = ggml_add(ctx0, w, layer->time_mix_decay);
+    w = ggml_exp(ctx0, ggml_neg(ctx0, ggml_exp(ctx0, w)));
+    w = ggml_reshape_3d(ctx0, w, head_size, n_head, n_tokens);
+
+    if (is_qrwkv) {
+        // k = k * (1 - w)
+        k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
+    }
+
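+    // load the wkv states for the sequences in this ubatch from the V tensors of the KV cache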
+    struct ggml_tensor * wkv_state = build_copy_mask_state(
+            ctx0, graph, kv_self.v_l[il], state_copy, state_mask,
+            n_tokens, hparams.n_embd_v_s(), n_seqs, worst_case);
+
+    struct ggml_tensor * wkv_output;
+    if (is_qrwkv) {
+        wkv_output = ggml_gated_linear_attn(ctx0, k, v, r, w, wkv_state, pow(head_size, -0.5f));
+    } else {
+        wkv_output = ggml_rwkv_wkv6(ctx0, k, v, r, layer->time_mix_first, w, wkv_state);
+    }
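+
+    // the wkv op returns the per-token outputs followed by the updated recurrent state in a single tensor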
+    cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
+    wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
+
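+    // copy the updated wkv state back into the V tensors of the KV cache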
+    ggml_build_forward_expand(
+        graph,
+        ggml_cpy(
+            ctx0,
+            wkv_state,
+            ggml_view_1d(
+                ctx0,
+                kv_self.v_l[il],
+                hparams.n_embd_v_s() * n_seqs,
+                hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
+            )
+        )
+    );
+
+    if (!is_qrwkv) {
+        // group norm with head_count groups
+        cur = ggml_reshape_3d(ctx0, cur, n_embd / n_head, n_head, n_tokens);
+        cur = ggml_norm(ctx0, cur, 64e-5f);
+
+        // Convert back to regular vectors.
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+    } else {
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+    }
+
+    cur = ggml_mul(ctx0, cur, g);
+    cur = build_lora_mm(ctx0, layer->time_mix_output, cur);
+
+    return cur;
+}
+
 // llama output

 size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {