@@ -2567,63 +2567,56 @@ void llama_context_kv_self::build_attn_inp(
25672567 }
25682568}
25692569
2570- void llama_context_kv_self::build_attn_kv_store (
2570+ ggml_tensor * llama_context_kv_self::build_attn (
25712571 ggml_context * ctx0,
25722572 ggml_cgraph * gf,
2573+ ggml_tensor * wo,
2574+ ggml_tensor * wo_b,
25732575 ggml_tensor * k_cur,
25742576 ggml_tensor * v_cur,
2577+ ggml_tensor * q_cur,
25752578 int32_t n_tokens,
2576- int64_t il,
2579+ float kq_scale,
2580+ int il,
25772581 bool worst_case) {
25782582 const auto & hparams = model.hparams ;
25792583
25802584 const auto & n_ctx = cparams.n_ctx ;
25812585
2582- const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head ;
2583-
25842586 const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa (il);
25852587 const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa (il);
25862588
2587- GGML_ASSERT (kv_self.size == n_ctx);
2589+ // store to KV cache
2590+ {
2591+ const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head ;
25882592
2589- struct ggml_tensor * k_cache_view = ggml_view_1d (ctx0, kv_self.k_l [il], n_tokens*n_embd_k_gqa, ggml_row_size (kv_self.k_l [il]->type , n_embd_k_gqa)*kv_head);
2590- // cb(k_cache_view, "k_cache_view", il);
2593+ GGML_ASSERT (kv_self.size == n_ctx);
25912594
2592- // note: storing RoPE-ed version of K in the KV cache
2593- ggml_build_forward_expand (gf, ggml_cpy (ctx0, k_cur, k_cache_view) );
2595+ struct ggml_tensor * k_cache_view = ggml_view_1d (ctx0, kv_self. k_l [il], n_tokens*n_embd_k_gqa, ggml_row_size (kv_self. k_l [il]-> type , n_embd_k_gqa)*kv_head);
2596+ // cb(k_cache_view, "k_cache_view", il );
25942597
2595- assert (v_cur->ne [0 ] == n_embd_v_gqa && v_cur->ne [1 ] == n_tokens);
2598+ // note: storing RoPE-ed version of K in the KV cache
2599+ ggml_build_forward_expand (gf, ggml_cpy (ctx0, k_cur, k_cache_view));
25962600
2597- struct ggml_tensor * v_cache_view = nullptr ;
2601+ assert (v_cur-> ne [ 0 ] == n_embd_v_gqa && v_cur-> ne [ 1 ] == n_tokens) ;
25982602
2599- if (cparams.flash_attn ) {
2600- v_cache_view = ggml_view_1d (ctx0, kv_self.v_l [il], n_tokens*n_embd_v_gqa, ggml_row_size (kv_self.v_l [il]->type , n_embd_v_gqa)*kv_head);
2601- } else {
2602- // note: the V cache is transposed when not using flash attention
2603- v_cache_view = ggml_view_2d (ctx0, kv_self.v_l [il], n_tokens, n_embd_v_gqa,
2604- ( n_ctx)*ggml_element_size (kv_self.v_l [il]),
2605- (kv_head)*ggml_element_size (kv_self.v_l [il]));
2603+ struct ggml_tensor * v_cache_view = nullptr ;
26062604
2607- v_cur = ggml_transpose (ctx0, v_cur);
2608- }
2609- // cb(v_cache_view, "v_cache_view", il);
2605+ if (cparams.flash_attn ) {
2606+ v_cache_view = ggml_view_1d (ctx0, kv_self.v_l [il], n_tokens*n_embd_v_gqa, ggml_row_size (kv_self.v_l [il]->type , n_embd_v_gqa)*kv_head);
2607+ } else {
2608+ // note: the V cache is transposed when not using flash attention
2609+ v_cache_view = ggml_view_2d (ctx0, kv_self.v_l [il], n_tokens, n_embd_v_gqa,
2610+ ( n_ctx)*ggml_element_size (kv_self.v_l [il]),
2611+ (kv_head)*ggml_element_size (kv_self.v_l [il]));
26102612
2611- ggml_build_forward_expand (gf, ggml_cpy (ctx0, v_cur, v_cache_view));
2612- }
2613+ v_cur = ggml_transpose (ctx0, v_cur);
2614+ }
2615+ // cb(v_cache_view, "v_cache_view", il);
26132616
2614- ggml_tensor * llama_context_kv_self::build_attn_qkv (
2615- ggml_context * ctx0,
2616- ggml_cgraph * gf,
2617- ggml_tensor * wo,
2618- ggml_tensor * wo_b,
2619- ggml_tensor * q_cur,
2620- int32_t n_tokens,
2621- float kq_scale,
2622- int il,
2623- bool worst_case) {
2624- const auto & hparams = model.hparams ;
2617+ ggml_build_forward_expand (gf, ggml_cpy (ctx0, v_cur, v_cache_view));
2618+ }
26252619
2626- const auto & n_ctx = cparams.n_ctx ;
26272620 const auto & n_embd_head_k = hparams.n_embd_head_k ;
26282621 const auto & n_embd_head_v = hparams.n_embd_head_v ;
26292622
@@ -2657,8 +2650,6 @@ ggml_tensor * llama_context_kv_self::build_attn_qkv(
26572650
26582651 const int64_t n_head = hparams.n_head (il);
26592652 const int64_t n_head_kv = hparams.n_head_kv (il);
2660- const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa (il);
2661- const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa (il);
26622653
26632654 struct ggml_tensor * q = ggml_permute (ctx0, q_cur, 0 , 2 , 1 , 3 );
26642655 // cb(q, "q", il);
0 commit comments