@@ -1970,6 +1970,229 @@ ggml_tensor * llama_context::build_mamba_layer(
}


+ggml_tensor * llama_context::build_rwkv_token_shift_load(
+        ggml_context * ctx0,
+        ggml_cgraph * graph,
+        ggml_tensor * state_copy,
+        ggml_tensor * state_mask,
+        const llama_ubatch & ubatch,
+        int il,
+        bool worst_case) {
+    const auto & hparams = model.hparams;
+
+    const auto token_shift_count = hparams.token_shift_count;
+
+    const auto & n_tokens = ubatch.n_tokens;
+    const int64_t n_seqs  = ubatch.n_seqs;
+
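+    // per-layer token shift states are kept in the K tensors of the recurrent KV cache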
+    struct ggml_tensor * token_shift_all = kv_self.k_l[il];
+
+    struct ggml_tensor * token_shift = build_copy_mask_state(
+            ctx0, graph, token_shift_all, state_copy, state_mask,
+            n_tokens, hparams.n_embd_k_s(), n_seqs, worst_case);
+
+    token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
+
+    return token_shift;
+}
+
+
+ggml_tensor * llama_context::build_rwkv_token_shift_store(
+        ggml_context * ctx0,
+        ggml_tensor * token_shift,
+        const llama_ubatch & ubatch,
+        int il,
+        bool worst_case) {
+    const auto & hparams = model.hparams;
+
+    const auto token_shift_count = hparams.token_shift_count;
+    const auto n_embd = hparams.n_embd;
+
+    const auto & n_tokens = ubatch.n_tokens;
+    const int64_t n_seqs  = ubatch.n_seqs;
+
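+    // during worst-case graph reservation the cache head is not known yet:
+    // recurrent caches use slot 0, otherwise assume the last n_tokens slots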
+    const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
+
+    return ggml_cpy(
+        ctx0,
+        ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
+        ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
+    );
+}
+
+
+ggml_tensor * llama_context::build_rwkv6_time_mix(
+        ggml_context * ctx0,
+        ggml_cgraph * graph,
+        ggml_tensor * cur,
+        ggml_tensor * x_prev,
+        ggml_tensor * state_copy,
+        ggml_tensor * state_mask,
+        const llama_ubatch & ubatch,
+        int il,
+        bool worst_case) {
+    const auto & hparams = model.hparams;
+
+    const auto n_tokens      = ubatch.n_tokens;
+    const auto n_seqs        = ubatch.n_seqs;
+    const auto n_seq_tokens  = ubatch.n_seq_tokens;
+    const auto n_embd        = hparams.n_embd;
+    const auto head_size     = hparams.wkv_head_size;
+    const auto n_head        = n_embd / head_size;
+    const auto n_head_kv     = hparams.n_head_kv(il);
+
+    const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
+
+    const auto layer = &model.layers[il];
+
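+    // QRWKV models have no time_mix_first (bonus) tensor and take the gated linear attention path below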
+    bool is_qrwkv = layer->time_mix_first == nullptr;
+
+    struct ggml_tensor * sx  = ggml_sub(ctx0, x_prev, cur);
+    struct ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->time_mix_lerp_x), cur);
+
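+    // low-rank projection through time_mix_w1/time_mix_w2 produces 5 sets of per-channel
+    // mixing coefficients, one each for w, k, v, r and g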
+    xxx = ggml_reshape_4d(
+        ctx0,
+        ggml_tanh(
+            ctx0,
+            ggml_mul_mat(ctx0, layer->time_mix_w1, xxx)
+        ),
+        layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
+    );
+
+    xxx = ggml_cont(ctx0, ggml_permute(ctx0, xxx, 0, 1, 3, 2));
+
+    xxx = ggml_mul_mat(
+        ctx0,
+        ggml_reshape_4d(
+            ctx0,
+            layer->time_mix_w2,
+            layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5
+        ),
+        xxx
+    );
+
+    struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
+    if (layer->time_mix_lerp_fused) {
+        // fusing these weights makes some performance improvement
+        sx  = ggml_reshape_3d(ctx0, sx,  n_embd, 1, n_tokens);
+        cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
+        xxx = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xxx, layer->time_mix_lerp_fused), sx), cur);
+        xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+    } else {
+        // for backward compatibility
+        xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+
+        xw = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xw, layer->time_mix_lerp_w), sx), cur);
+        xk = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xk, layer->time_mix_lerp_k), sx), cur);
+        xv = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xv, layer->time_mix_lerp_v), sx), cur);
+        xr = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xr, layer->time_mix_lerp_r), sx), cur);
+        xg = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xg, layer->time_mix_lerp_g), sx), cur);
+    }
+
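+    // receptance, key and value projections (with optional LoRA adapters and biases)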
+    struct ggml_tensor * r = build_lora_mm(ctx0, layer->time_mix_receptance, xr);
+    struct ggml_tensor * k = build_lora_mm(ctx0, layer->time_mix_key,        xk);
+    struct ggml_tensor * v = build_lora_mm(ctx0, layer->time_mix_value,      xv);
+    if (layer->time_mix_receptance_b) {
+        r = ggml_add(ctx0, r, layer->time_mix_receptance_b);
+    }
+    if (layer->time_mix_key_b) {
+        k = ggml_add(ctx0, k, layer->time_mix_key_b);
+    }
+    if (layer->time_mix_value_b) {
+        v = ggml_add(ctx0, v, layer->time_mix_value_b);
+    }
+
+    struct ggml_tensor * g = build_lora_mm(ctx0, layer->time_mix_gate, xg);
+    if (is_qrwkv) {
+        g = ggml_sigmoid(ctx0, g);
+    } else {
+        g = ggml_silu(ctx0, g);
+    }
+
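+    // if K/V use fewer heads than the query (GQA-style), repeat them to match n_head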
+    if (n_head_kv != 0 && n_head_kv != n_head) {
+        GGML_ASSERT(n_head % n_head_kv == 0);
+        k = ggml_reshape_4d(ctx0, k, head_size, 1, n_head_kv, n_tokens);
+        v = ggml_reshape_4d(ctx0, v, head_size, 1, n_head_kv, n_tokens);
+        struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_size, n_head / n_head_kv, n_head_kv, n_tokens);
+        k = ggml_repeat(ctx0, k, tmp);
+        v = ggml_repeat(ctx0, v, tmp);
+    }
+
+    k = ggml_reshape_3d(ctx0, k, head_size, n_head, n_tokens);
+    v = ggml_reshape_3d(ctx0, v, head_size, n_head, n_tokens);
+    r = ggml_reshape_3d(ctx0, r, head_size, n_head, n_tokens);
+
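+    // data-dependent decay through a low-rank projection; exp(-exp(.)) keeps it in (0, 1)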
+    struct ggml_tensor * w = ggml_mul_mat(
+        ctx0,
+        layer->time_mix_decay_w2,
+        ggml_tanh(
+            ctx0,
+            ggml_mul_mat(ctx0, layer->time_mix_decay_w1, xw)
+        )
+    );
+
+    w = ggml_add(ctx0, w, layer->time_mix_decay);
+    w = ggml_exp(ctx0, ggml_neg(ctx0, ggml_exp(ctx0, w)));
+    w = ggml_reshape_3d(ctx0, w, head_size, n_head, n_tokens);
+
+    if (is_qrwkv) {
+        // k = k * (1 - w)
+        k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
+    }
+
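+    // load the per-sequence wkv states for this layer from the V tensors of the recurrent cache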
+    struct ggml_tensor * wkv_state = build_copy_mask_state(
+            ctx0, graph, kv_self.v_l[il], state_copy, state_mask,
+            n_tokens, hparams.n_embd_v_s(), n_seqs, worst_case);
+
+    struct ggml_tensor * wkv_output;
+    if (is_qrwkv) {
+        wkv_output = ggml_gated_linear_attn(ctx0, k, v, r, w, wkv_state, pow(head_size, -0.5f));
+    } else {
+        wkv_output = ggml_rwkv_wkv6(ctx0, k, v, r, layer->time_mix_first, w, wkv_state);
+    }
+    cur       = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
+    wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
+
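+    // store the updated wkv states back into the recurrent cache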
+    ggml_build_forward_expand(
+        graph,
+        ggml_cpy(
+            ctx0,
+            wkv_state,
+            ggml_view_1d(
+                ctx0,
+                kv_self.v_l[il],
+                hparams.n_embd_v_s() * n_seqs,
+                hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
+            )
+        )
+    );
+
+    if (!is_qrwkv) {
+        // group norm with head_count groups
+        cur = ggml_reshape_3d(ctx0, cur, n_embd / n_head, n_head, n_tokens);
+        cur = ggml_norm(ctx0, cur, 64e-5f);
+
+        // Convert back to regular vectors.
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+    } else {
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+    }
+
+    cur = ggml_mul(ctx0, cur, g);
+    cur = build_lora_mm(ctx0, layer->time_mix_output, cur);
+
+    return cur;
+}
+
// llama output

size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {