@@ -6403,6 +6403,10 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
+        // whether to use n_tokens or n_head as the matrix dimension during multiplication
+        // n_tokens is larger during prompt processing, which allows optimizing for that case
+        bool pp_opt = n_tokens > n_head;
+
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
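The new `pp_opt` flag decides which axis ends up on `ne[1]`, the matrix-row dimension of the batched `ggml_mul_mat` calls below: during prompt processing `n_tokens > n_head`, so tokens stay on `ne[1]` and each per-head matmul is large; during decode (`n_tokens` is typically 1) the heads are permuted onto `ne[1]` instead, so the single-token multiplication still has `n_head` rows. A minimal sketch of that shape logic, assuming `ggml.h` from this tree; the sizes (`kv_lora_rank`, `n_kv`, `n_head`) are illustrative, not taken from a real model:

```cpp
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,   // we only inspect shapes, no tensor data needed
    };
    struct ggml_context * ctx0 = ggml_init(params);

    const int64_t kv_lora_rank = 512, n_kv = 4096, n_head = 16;

    for (int64_t n_tokens : {(int64_t) 1, (int64_t) 512}) { // decode vs prompt processing
        const bool pp_opt = n_tokens > n_head;

        // compressed KV cache: [kv_lora_rank, n_kv], broadcast over the batch dim of q
        struct ggml_tensor * kv_cache = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, kv_lora_rank, n_kv);
        // absorbed no-PE query: [kv_lora_rank, n_tokens, n_head]
        struct ggml_tensor * q = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, kv_lora_rank, n_tokens, n_head);

        if (!pp_opt) {
            // decode path: move n_head onto ne[1], mirroring the q_nope2 permute in the diff
            q = ggml_permute(ctx0, q, 0, 2, 1, 3);
        }

        struct ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
        printf("n_tokens=%4lld pp_opt=%d -> kq = [%lld, %lld, %lld]\n",
               (long long) n_tokens, pp_opt,
               (long long) kq->ne[0], (long long) kq->ne[1], (long long) kq->ne[2]);
    }

    ggml_free(ctx0);
    return 0;
}
```

With these sizes the sketch yields `kq = [4096, 16, 1]` for decode and `kq = [4096, 512, 16]` for prompt processing: in both cases the larger of `n_tokens`/`n_head` sits on the per-matmul row dimension, which is the whole point of the flag.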
@@ -6535,35 +6539,45 @@ struct llm_build_context {
                 struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm);
                 cb(q_nope2, "q_nope2", il);
 
-                struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
-                cb(q_nope2_perm, "q_nope2_perm", il);
+                if (!pp_opt) {
+                    q_nope2 = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
+                    cb(q_nope2, "q_nope2_perm", il);
+                }
 
-                struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm);
+                struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2);
                 cb(kq_nope, "kq_nope", il);
 
-                struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1);
-                cb(q_pe_perm, "q_pe_perm", il);
+                if (pp_opt) {
+                    q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3);
+                    cb(q_pe, "q_pe_perm", il);
+                }
 
                 struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
                 cb(kq_pe, "kq_pe", il);
 
                 struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
                 cb(kq, "kq", il);
 
-                kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
-                cb(kq, "kq_perm", il);
+                if (!pp_opt) {
+                    kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
+                    cb(kq, "kq_perm", il);
+                }
 
                 kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
                 cb(kq, "kq_soft_max_ext", il);
 
-                struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 1, 3);
-                cb(kq_perm, "kq_soft_max_ext_perm", il);
+                if (!pp_opt) {
+                    kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
+                    cb(kq, "kq_soft_max_ext_perm", il);
+                }
 
-                struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_perm);
+                struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
                 cb(kqv_compressed, "kqv_compressed", il);
 
-                kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
-                cb(kqv_compressed, "kqv_compressed_perm", il);
+                if (!pp_opt) {
+                    kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
+                    cb(kqv_compressed, "kqv_compressed_perm", il);
+                }
 
                 struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0);
                 cb(wv_b, "wv_b", il);
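The `wv_b` view at the end of the hunk reinterprets the fused value projection, stored as a `[kv_lora_rank, n_head * n_embd_head_v]` matrix, as one `[kv_lora_rank, n_embd_head_v]` block per head, presumably so the subsequent multiplication with `kqv_compressed` can batch over heads; the strides are computed with `ggml_row_size()` so the same arithmetic works for quantized weight types. A minimal sketch of that stride arithmetic, assuming a dense F32 weight and hypothetical sizes:

```cpp
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,   // shapes only, no tensor data needed
    };
    struct ggml_context * ctx0 = ggml_init(params);

    const int64_t kv_lora_rank = 512, n_embd_head_v = 128, n_head = 16;

    // fused value projection as stored in the model: [kv_lora_rank, n_head * n_embd_head_v]
    struct ggml_tensor * wv_b_2d = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32,
            kv_lora_rank, n_head*n_embd_head_v);

    // view as one [kv_lora_rank, n_embd_head_v] matrix per head:
    // nb1 = one row, nb2 = one head-sized block of n_embd_head_v rows
    struct ggml_tensor * wv_b = ggml_view_3d(ctx0, wv_b_2d,
            kv_lora_rank, n_embd_head_v, n_head,
            ggml_row_size(wv_b_2d->type, kv_lora_rank),
            ggml_row_size(wv_b_2d->type, kv_lora_rank*n_embd_head_v),
            0);

    printf("wv_b = [%lld, %lld, %lld]\n",
           (long long) wv_b->ne[0], (long long) wv_b->ne[1], (long long) wv_b->ne[2]);

    ggml_free(ctx0);
    return 0;
}
```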