@@ -2567,63 +2567,56 @@ void llama_context_kv_self::build_attn_inp(
     }
 }
 
-void llama_context_kv_self::build_attn_kv_store(
+ggml_tensor * llama_context_kv_self::build_attn(
         ggml_context * ctx0,
         ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
+        ggml_tensor * q_cur,
         int32_t n_tokens,
-        int64_t il,
+        float kq_scale,
+        int il,
         bool worst_case) {
     const auto & hparams = model.hparams;
 
     const auto & n_ctx = cparams.n_ctx;
 
-    const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
-
     const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
     const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
-    GGML_ASSERT(kv_self.size == n_ctx);
+    // store to KV cache
+    {
+        const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
 
-    struct ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)*kv_head);
-    // cb(k_cache_view, "k_cache_view", il);
+        GGML_ASSERT(kv_self.size == n_ctx);
 
-    // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
+        struct ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)*kv_head);
+        // cb(k_cache_view, "k_cache_view", il);
 
-    assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
+        // note: storing RoPE-ed version of K in the KV cache
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
 
-    struct ggml_tensor * v_cache_view = nullptr;
+        assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
 
-    if (cparams.flash_attn) {
-        v_cache_view = ggml_view_1d(ctx0, kv_self.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa)*kv_head);
-    } else {
-        // note: the V cache is transposed when not using flash attention
-        v_cache_view = ggml_view_2d(ctx0, kv_self.v_l[il], n_tokens, n_embd_v_gqa,
-                (  n_ctx)*ggml_element_size(kv_self.v_l[il]),
-                (kv_head)*ggml_element_size(kv_self.v_l[il]));
+        struct ggml_tensor * v_cache_view = nullptr;
 
-        v_cur = ggml_transpose(ctx0, v_cur);
-    }
-    // cb(v_cache_view, "v_cache_view", il);
+        if (cparams.flash_attn) {
+            v_cache_view = ggml_view_1d(ctx0, kv_self.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa)*kv_head);
+        } else {
+            // note: the V cache is transposed when not using flash attention
+            v_cache_view = ggml_view_2d(ctx0, kv_self.v_l[il], n_tokens, n_embd_v_gqa,
+                    (  n_ctx)*ggml_element_size(kv_self.v_l[il]),
+                    (kv_head)*ggml_element_size(kv_self.v_l[il]));
 
-    ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
-}
+            v_cur = ggml_transpose(ctx0, v_cur);
+        }
+        // cb(v_cache_view, "v_cache_view", il);
 
-ggml_tensor * llama_context_kv_self::build_attn_qkv(
-        ggml_context * ctx0,
-        ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
-        ggml_tensor * q_cur,
-        int32_t n_tokens,
-        float kq_scale,
-        int il,
-        bool worst_case) {
-    const auto & hparams = model.hparams;
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
+    }
 
-    const auto & n_ctx = cparams.n_ctx;
 
     const auto & n_embd_head_k = hparams.n_embd_head_k;
     const auto & n_embd_head_v = hparams.n_embd_head_v;
@@ -2657,8 +2650,6 @@ ggml_tensor * llama_context_kv_self::build_attn_qkv(
 
     const int64_t n_head    = hparams.n_head(il);
     const int64_t n_head_kv = hparams.n_head_kv(il);
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
     struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
     // cb(q, "q", il);