@@ -6404,6 +6404,10 @@ struct llm_build_context {
  // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
  struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
+ // whether to use n_tokens or n_head as the matrix dimension during multiplication
+ // n_tokens is larger during prompt processing, which allows optimizing for that case
+ bool pp_opt = n_tokens > n_head;
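+ // when pp_opt is false (token generation, small n_tokens), the mul_mat operands below are permuted so that n_head takes the place of n_tokens as the matrix dimension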
+
  for (int il = 0; il < n_layer; ++il) {
  struct ggml_tensor * inpSA = inpL;
 
@@ -6472,33 +6476,33 @@ struct llm_build_context {
          LLM_NORM_RMS, cb, il);
  cb(kv_compressed, "kv_compressed", il);
 
- // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
- struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
- cb(kv, "kv", il);
+ struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head);
+ cb(kv_cache_view, "kv_cache_view", il);

- // split into {n_head * n_embd_head_qk_nope, n_tokens}
- struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
-         ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
-         ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-         0);
- cb(k_nope, "k_nope", il);
+ // note: storing c^KV in the KV cache
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view));

- // and {n_head * n_embd_head_v, n_tokens}
- struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
-         ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-         ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
-         ggml_row_size(kv->type, (n_embd_head_qk_nope)));
- cb(v_states, "v_states", il);
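+ // note: a second, transposed copy of c^KV is kept so that it can be used directly as V when computing the attention output below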
+ struct ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), ggml_row_size(kv_self.kv_l[il]->type, kv_head));
+ cb(kv_cache_trans_view, "kv_cache_trans_view", il);

- v_states = ggml_cont(ctx0, v_states);
- cb(v_states, "v_states", il);
+ // note: storing transposed c^KV in the transposed KV cache
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view));

- v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
-         ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
-         0);
- cb(v_states, "v_states", il);
+ struct ggml_tensor * kv_cache =
+         ggml_view_2d(ctx0, kv_self.kv_l[il],
+                 kv_lora_rank, n_kv,
+                 ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank),
+                 0);
+ cb(kv_cache, "kv_cache", il);

- q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
+ struct ggml_tensor * kv_cache_trans =
+         ggml_view_2d(ctx0, kv_self.kvt_l[il],
+                 n_kv, kv_lora_rank,
+                 ggml_row_size(kv_self.kv_l[il]->type, kv_self.size),
+                 0);
+ cb(kv_cache_trans, "kv_cache_trans", il);
+
+ q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
  q_pe = ggml_rope_ext(
          ctx0, q_pe, inp_pos, nullptr,
          n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -6515,15 +6519,91 @@ struct llm_build_context {
  );
  cb(k_pe, "k_pe", il);

- struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
- cb(q_states, "q_states", il);
+ struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope, ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head);
+ cb(kr_cache_view, "kr_cache_view", il);

- struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
- cb(k_states, "k_states", il);
+ // note: storing RoPE-ed version of K^R in the KV cache
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view));

- cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-         model.layers[il].wo, NULL,
-         k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+ struct ggml_tensor * kr_cache =
+         ggml_view_2d(ctx0, kv_self.kr_l[il],
+                 n_embd_head_qk_rope, n_kv,
+                 ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope),
+                 0);
+ cb(kr_cache, "kr_cache", il);
+
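+ // wk_b is the K up-projection (W^UK) laid out per head as {n_embd_head_qk_nope, kv_lora_rank, n_head};
+ // multiplying it into q_nope "absorbs" W^UK into the query, so attention scores can be computed directly against the compressed cache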
+ struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0);
+ cb(wk_b, "wk_b", il);
+
+ q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
+ cb(q_nope, "q_nope_perm", il);
+
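+ // project the queries into the latent space: {kv_lora_rank, n_tokens, n_head}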
+ struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope);
+ cb(q_nope2, "q_nope2", il);
+
+ if (!pp_opt) {
+     q_nope2 = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
+     cb(q_nope2, "q_nope2_perm", il);
+ }
+
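+ // scores from the non-RoPE ("nope") part of the heads, computed directly in the latent space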
+ struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2);
+ cb(kq_nope, "kq_nope", il);
+
+ if (!pp_opt) {
+     kq_nope = ggml_permute(ctx0, kq_nope, 0, 2, 1, 3);
+     cb(kq_nope, "kq_nope_perm", il);
+ }
+
+ if (pp_opt) {
+     q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3);
+     cb(q_pe, "q_pe_perm", il);
+ }
+
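+ // scores from the RoPE-ed part of the heads, using the shared K^R cache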
+ struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
+ cb(kq_pe, "kq_pe", il);
+
+ if (!pp_opt) {
+     kq_pe = ggml_permute(ctx0, kq_pe, 0, 2, 1, 3);
+     cb(kq_pe, "kq_pe_perm", il);
+ }
+
+ struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
+ cb(kq, "kq", il);
+
+ kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+ cb(kq, "kq_soft_max_ext", il);
+
+ if (!pp_opt) {
+     kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
+     cb(kq, "kq_soft_max_ext_perm", il);
+ }
+
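+ // weighted sum of the cached c^KV entries: {kv_lora_rank, n_tokens, n_head} during prompt processing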
+ struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
+ cb(kqv_compressed, "kqv_compressed", il);
+
+ if (!pp_opt) {
+     kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 3, 1);
+     cb(kqv_compressed, "kqv_compressed_perm", il);
+ }
+
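+ // wv_b is the V up-projection (W^UV) laid out per head as {kv_lora_rank, n_embd_head_v, n_head}; it maps the latent attention output back to per-head values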
+ struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0);
+ cb(wv_b, "wv_b", il);
+
+ struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed);
+ cb(kqv, "kqv", il);
+
+ if (pp_opt) {
+     kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
+     cb(kqv, "kqv_perm", il);
+ }
+
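+ // concatenate the heads into a {n_embd_head_v*n_head, n_tokens} activation for the output projection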
+ cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0);
+ cb(cur, "kqv_2d", il);
+
+ ggml_build_forward_expand(gf, cur);
+
+ cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
+ cb(cur, "kqv_out", il);
  }

  if (il == n_layer - 1) {
@@ -9768,6 +9848,24 @@ struct llama_context * llama_init_from_model(
          ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
  }
 
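+ // report the memory used by the MLA caches: RoPE-ed K^R and compressed c^KV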
+ {
+     size_t memory_size_kr = 0;
+     size_t memory_size_kv = 0;
+
+     for (auto & kr : ctx->kv_self.kr_l) {
+         memory_size_kr += ggml_nbytes(kr);
+     }
+
+     for (auto & kv : ctx->kv_self.kv_l) {
+         memory_size_kv += ggml_nbytes(kv);
+     }
+
+     LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__,
+             (float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f),
+             ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f),
+             ggml_type_name(type_k), (float)memory_size_kv / (1024.0f * 1024.0f));
+ }
+
  // graph outputs buffer
  {
      // resized during inference when a batch uses more outputs