Commit 488b649

Fused QKV multiplication

This PR adds fused QKV multiplication: the separate Q, K and V projection weights are loaded into a single contiguous tensor, applied with one matrix multiplication, and the Q, K and V activations are taken as views into the fused result.
1 parent 75cbdd3 commit 488b649
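
Background sketch (not part of the diff): the fusion is valid because multiplying the input by the row-wise concatenation of Wq, Wk and Wv yields the concatenation of the three separate products, so three matrix multiplications collapse into one. All names and sizes below are illustrative assumptions:

// Toy demonstration (illustrative, not llama.cpp code): one multiplication by
// the concatenated [Wq; Wk; Wv] equals the three separate Q/K/V multiplications.
#include <cstdio>
#include <vector>

// y = W * x with W stored row-major as [rows x cols]
static std::vector<float> matvec(const std::vector<float> & W, const std::vector<float> & x, int rows, int cols) {
    std::vector<float> y(rows, 0.0f);
    for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
            y[r] += W[r*cols + c] * x[c];
        }
    }
    return y;
}

int main() {
    const int n_embd = 4, d_q = 4, d_kv = 2;                       // toy dimensions
    std::vector<float> x(n_embd), Wq(d_q*n_embd), Wk(d_kv*n_embd), Wv(d_kv*n_embd);
    for (size_t i = 0; i < x.size();  ++i) x[i]  = 0.10f * i;
    for (size_t i = 0; i < Wq.size(); ++i) Wq[i] = 0.01f * i;
    for (size_t i = 0; i < Wk.size(); ++i) Wk[i] = 0.02f * i;
    for (size_t i = 0; i < Wv.size(); ++i) Wv[i] = 0.03f * i;

    // fused weight: the rows of Wq, then Wk, then Wv, back to back
    std::vector<float> Wqkv;
    Wqkv.insert(Wqkv.end(), Wq.begin(), Wq.end());
    Wqkv.insert(Wqkv.end(), Wk.begin(), Wk.end());
    Wqkv.insert(Wqkv.end(), Wv.begin(), Wv.end());

    const auto q   = matvec(Wq,   x, d_q,          n_embd);
    const auto k   = matvec(Wk,   x, d_kv,         n_embd);
    const auto v   = matvec(Wv,   x, d_kv,         n_embd);
    const auto qkv = matvec(Wqkv, x, d_q + 2*d_kv, n_embd);        // single multiplication

    // the fused result is just [q | k | v]
    printf("q[0]=%f  qkv[0]=%f\n",        q[0], qkv[0]);
    printf("k[0]=%f  qkv[d_q]=%f\n",      k[0], qkv[d_q]);
    printf("v[0]=%f  qkv[d_q+d_kv]=%f\n", v[0], qkv[d_q + d_kv]);
    return 0;
}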

File tree

    src/llama-arch.cpp
    src/llama-graph.cpp
    src/llama-graph.h
    src/llama-model-loader.cpp
    src/llama-model-loader.h
    src/llama-model.cpp

6 files changed, +256 -26 lines changed

src/llama-arch.cpp

Lines changed: 2 additions & 0 deletions

@@ -757,6 +757,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
         { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
         { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
         { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
     },
 },
 {
@@ -777,6 +778,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
         { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
         { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
         { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+        { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
     },
 },
 {
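
As an aside, a minimal sketch of how the "blk.%d.attn_qkv" format string registered above expands into concrete per-layer tensor names; the layer count and the ".weight" suffix here are illustrative assumptions:

// Illustration only: expanding the per-layer format string into tensor names.
#include <cstdio>

int main() {
    char name[64];
    for (int il = 0; il < 3; ++il) {                       // hypothetical layer count
        snprintf(name, sizeof(name), "blk.%d.attn_qkv", il);
        printf("%s.weight\n", name);                       // e.g. blk.0.attn_qkv.weight
    }
    return 0;
}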

src/llama-graph.cpp

Lines changed: 69 additions & 0 deletions

@@ -3,6 +3,7 @@
 #include "llama-impl.h"
 #include "llama-batch.h"
 #include "llama-cparams.h"
+#include "llama-model.h"

 #include "llama-kv-cache.h"
 #include "llama-kv-cache-iswa.h"
@@ -632,6 +633,74 @@ ggml_tensor * llm_graph_context::build_lora_mm(
     return res;
 }

+static bool disable_fusion() {
+    const char * disable_fusion = getenv("LLAMA_GRAPH_DISABLE_FUSION");
+    return disable_fusion != nullptr && atoi(disable_fusion) != 0;
+}
+
+void llm_graph_context::build_qkv(const llama_layer & layer,
+        ggml_tensor * cur,
+        int64_t n_embd_head_q,
+        int64_t n_embd_head_k,
+        int64_t n_embd_head_v,
+        int32_t n_head,
+        int32_t n_head_kv,
+        ggml_tensor ** q_out,
+        ggml_tensor ** k_out,
+        ggml_tensor ** v_out,
+        int il) const {
+    if (disable_fusion() || !layer.wqkv || (loras && !loras->empty())) {
+        *q_out = build_lora_mm(layer.wq, cur);
+        cb(*q_out, "Qcur", il);
+
+        *k_out = build_lora_mm(layer.wk, cur);
+        cb(*k_out, "Kcur", il);
+
+        *v_out = build_lora_mm(layer.wv, cur);
+        cb(*v_out, "Vcur", il);
+
+        *q_out = ggml_reshape_3d(ctx0, *q_out, n_embd_head_q, n_head,    n_tokens);
+        *k_out = ggml_reshape_3d(ctx0, *k_out, n_embd_head_k, n_head_kv, n_tokens);
+        *v_out = ggml_reshape_3d(ctx0, *v_out, n_embd_head_v, n_head_kv, n_tokens);
+
+        return;
+    }
+
+    ggml_tensor * qkv = ggml_mul_mat(ctx0, layer.wqkv, cur);
+    cb(qkv, "wqkv", il);
+
+    const int64_t q_offset = 0;
+    const int64_t k_offset = n_embd_head_q * n_head;
+    const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv;
+    const size_t  elt_size = ggml_element_size(qkv);
+
+    ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv,
+            n_embd_head_q, n_head, n_tokens,
+            n_embd_head_q * elt_size, qkv->nb[1],
+            q_offset * elt_size);
+    ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv,
+            n_embd_head_k, n_head_kv, n_tokens,
+            n_embd_head_k * elt_size, qkv->nb[1],
+            k_offset * elt_size);
+    ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv,
+            n_embd_head_v, n_head_kv, n_tokens,
+            n_embd_head_v * elt_size, qkv->nb[1],
+            v_offset * elt_size);
+
+    cb(Qcur, "Qcur", il);
+    cb(Kcur, "Kcur", il);
+    cb(Vcur, "Vcur", il);
+
+    *q_out = Qcur;
+    *k_out = Kcur;
+    *v_out = Vcur;
+}
+
 ggml_tensor * llm_graph_context::build_lora_mm_id(
     ggml_tensor * w,   // ggml_tensor * as
     ggml_tensor * cur, // ggml_tensor * b
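
For intuition about the ggml_view_3d calls above, here is a small self-contained sketch (plain C++, no ggml) of how one row of the fused matmul output splits into Q, K and V sections. The head sizes and counts are hypothetical Qwen3-like values, and the element size assumes an F32 result:

// Sketch (no ggml): element ranges and byte offsets of the Q/K/V views inside
// one row of the fused matmul output, mirroring the q_offset/k_offset/v_offset
// arithmetic above. All values are hypothetical assumptions.
#include <cstdio>
#include <cstdint>

int main() {
    const int64_t n_embd_head_q = 128, n_embd_head_k = 128, n_embd_head_v = 128; // assumed
    const int32_t n_head = 32, n_head_kv = 8;                                    // assumed
    const size_t  elt_size = sizeof(float);                                      // assumes an F32 result

    const int64_t q_offset = 0;
    const int64_t k_offset = n_embd_head_q * n_head;                // Q section: [0, 4096)
    const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv;  // K section: [4096, 5120)
    const int64_t row_len  = v_offset + n_embd_head_v * n_head_kv;  // V section: [5120, 6144)

    printf("Q: elems [%lld, %lld), byte offset %zu\n", (long long)q_offset, (long long)k_offset, (size_t)(q_offset * elt_size));
    printf("K: elems [%lld, %lld), byte offset %zu\n", (long long)k_offset, (long long)v_offset, (size_t)(k_offset * elt_size));
    printf("V: elems [%lld, %lld), byte offset %zu\n", (long long)v_offset, (long long)row_len,  (size_t)(v_offset * elt_size));
    return 0;
}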

src/llama-graph.h

Lines changed: 14 additions & 0 deletions

@@ -24,6 +24,8 @@ class llama_kv_cache_iswa_context;
 class llama_memory_recurrent_context;
 class llama_memory_hybrid_context;

+struct llama_layer;
+
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
     LLM_GRAPH_TYPE_DEFAULT,
@@ -604,6 +606,18 @@ struct llm_graph_context {
             ggml_tensor * w,
             ggml_tensor * cur) const;

+    void build_qkv(const llama_layer & layer,
+            ggml_tensor * cur,
+            int64_t n_embd_head_q,
+            int64_t n_embd_head_k,
+            int64_t n_embd_head_v,
+            int32_t n_head,
+            int32_t n_head_kv,
+            ggml_tensor ** q_out,
+            ggml_tensor ** k_out,
+            ggml_tensor ** v_out,
+            int il) const;
+
     // do mat_mul_id, while optionally apply lora
     ggml_tensor * build_lora_mm_id(
             ggml_tensor * w,   // ggml_tensor * as
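
One hypothetical way to exercise the non-fused fallback for an A/B comparison is to set the LLAMA_GRAPH_DISABLE_FUSION variable that the new disable_fusion() helper in llama-graph.cpp reads; setting it in the shell before launching works the same way (setenv is POSIX-only):

// Hypothetical snippet: force the non-fused Q/K/V path for comparison.
#include <cstdlib>

int main() {
    setenv("LLAMA_GRAPH_DISABLE_FUSION", "1", /*overwrite=*/1); // build_qkv() then uses separate Q/K/V matmuls
    // ... load the model and build graphs as usual ...
    return 0;
}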

src/llama-model-loader.cpp

Lines changed: 34 additions & 0 deletions

@@ -840,6 +840,40 @@ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_conte
     return tensor;
 }

+struct ggml_tensor * llama_model_loader::create_contiguous_tensor(struct ggml_context * ctx, const std::string & fused_name, const std::initializer_list<int64_t> & ne,
+        std::vector<ggml_tensor**> tensors, int flags) {
+    (void)flags;
+
+    if (weights_map.find(fused_name) != weights_map.end()) {
+        return nullptr;
+    }
+
+    if (ggml_get_tensor(ctx, fused_name.c_str()) != nullptr) {
+        return nullptr;
+    }
+
+    const ggml_type type = (*tensors[0])->type;
+
+    struct ggml_tensor * fused = ggml_new_tensor(ctx, type, ne.size(), ne.begin());
+    if (!fused) {
+        return nullptr;
+    }
+
+    ggml_set_name(fused, fused_name.c_str());
+
+    size_t offset = 0;
+    for (ggml_tensor ** tensor : tensors) {
+        std::initializer_list<int64_t> ne = { (*tensor)->ne[0], (*tensor)->ne[1], (*tensor)->ne[2], (*tensor)->ne[3] };
+        struct ggml_tensor * view = create_tensor_as_view(ctx, fused, ggml_get_name(*tensor), ne, offset, false);
+        *tensor = view;
+        offset += ggml_nbytes(*tensor);
+    }
+
+    return fused;
+}
+
 void llama_model_loader::done_getting_tensors() const {
     if (n_created != n_tensors) {
         throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
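
Conceptually, create_contiguous_tensor packs several weights back to back into one allocation and turns each original tensor into a view at a running offset. A ggml-free sketch of that idea, with made-up sizes:

// Conceptual analogue of create_contiguous_tensor without ggml: pack blobs
// into one contiguous buffer and remember each part's (offset, size) view.
#include <cstdio>
#include <cstring>
#include <vector>

struct view { size_t offset_bytes, nbytes; };

int main() {
    // stand-ins for the Q, K and V weight blobs (made-up sizes)
    std::vector<std::vector<float>> parts = {
        std::vector<float>(8*8, 1.0f),
        std::vector<float>(8*2, 2.0f),
        std::vector<float>(8*2, 3.0f),
    };

    size_t total = 0;
    for (const auto & p : parts) total += p.size();

    std::vector<float> fused(total);
    std::vector<view>  views;

    size_t offset = 0;                                     // running offset, as in the loader
    for (const auto & p : parts) {
        std::memcpy(fused.data() + offset, p.data(), p.size()*sizeof(float));
        views.push_back({offset*sizeof(float), p.size()*sizeof(float)});
        offset += p.size();
    }

    for (size_t i = 0; i < views.size(); ++i) {
        printf("part %zu -> offset %zu bytes, size %zu bytes\n", i, views[i].offset_bytes, views[i].nbytes);
    }
    return 0;
}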

src/llama-model-loader.h

Lines changed: 3 additions & 0 deletions

@@ -147,6 +147,9 @@ struct llama_model_loader {

     struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list<int64_t> & ne, size_t offset, bool required = true);

+    struct ggml_tensor * create_contiguous_tensor(struct ggml_context * ctx, const std::string & fused_name, const std::initializer_list<int64_t> & ne,
+            std::vector<ggml_tensor**> tensors, int flags = 0);
+
     void done_getting_tensors() const;

     void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr);

src/llama-model.cpp

Lines changed: 134 additions & 26 deletions

@@ -1,5 +1,6 @@
 #include "llama-model.h"

+#include "gguf.h"
 #include "llama-impl.h"
 #include "llama-mmap.h"
 #include "llama-batch.h"
@@ -2428,6 +2429,99 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         return ml.create_tensor(ctx, tn, ne, flags);
     };

+    struct tensor_def {
+        LLM_TN_IMPL tn;
+        std::vector<int64_t> ne;
+        int flags;
+        ggml_tensor ** out;
+    };
+
+    auto create_contiguous = [&](const LLM_TN_IMPL & fused_tn,
+                                 std::initializer_list<int64_t> ne,
+                                 std::initializer_list<tensor_def> reqs) -> ggml_tensor * {
+        ggml_backend_buffer_type_t fused_buft = nullptr;
+
+        for (size_t i = 0; i < reqs.size(); ++i) {
+            const tensor_def & req = reqs.begin()[i];
+            const bool required = (req.flags & llama_model_loader::TENSOR_NOT_REQUIRED) == 0;
+            const ggml_tensor * tensor_meta = ml.check_tensor_dims(req.tn.str(), req.ne, required);
+
+            *req.out = const_cast<ggml_tensor*>(tensor_meta);
+
+            if (!*req.out) {
+                return nullptr;
+            }
+
+            llm_tensor tn_tensor = req.tn.tensor;
+            if (tn_tensor == LLM_TENSOR_TOKEN_EMBD && (req.flags & llama_model_loader::TENSOR_DUPLICATED)) {
+                tn_tensor = LLM_TENSOR_OUTPUT;
+            }
+
+            llm_tensor_info info;
+            try {
+                info = llm_tensor_info_for(tn_tensor);
+            } catch (const std::out_of_range &) {
+                throw std::runtime_error(format("missing tensor info mapping for %s", req.tn.str().c_str()));
+            }
+
+            bool bias = req.tn.suffix != nullptr && strcmp(req.tn.suffix, "bias") == 0;
+            ggml_op op = bias ? (info.op == GGML_OP_MUL_MAT_ID ? GGML_OP_ADD_ID : GGML_OP_ADD) : info.op;
+
+            buft_list_t * buft_list = nullptr;
+            switch (info.layer) {
+                case LLM_TENSOR_LAYER_INPUT:
+                    buft_list = pimpl->dev_input.buft_list;
+                    break;
+                case LLM_TENSOR_LAYER_OUTPUT:
+                    buft_list = pimpl->dev_output.buft_list;
+                    break;
+                case LLM_TENSOR_LAYER_REPEATING:
+                    buft_list = pimpl->dev_layer.at(req.tn.bid).buft_list;
+                    break;
+                default:
+                    GGML_ABORT("invalid layer %d for tensor %s", info.layer, req.tn.str().c_str());
+            }
+
+            ggml_backend_buffer_type_t buft = select_weight_buft(hparams, *req.out, op, *buft_list);
+            if (!buft) {
+                return nullptr;
+            }
+
+            auto * buft_dev = ggml_backend_buft_get_device(buft);
+            if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
+                auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+                if (!cpu_dev) {
+                    throw std::runtime_error("no CPU backend found");
+                }
+                buft = ggml_backend_dev_buffer_type(cpu_dev);
+            }
+
+            //TODO: check buft overrides
+
+            if (!fused_buft) {
+                fused_buft = buft;
+            } else if (fused_buft != buft) {
+                return nullptr;
+            }
+        }
+
+        if (!fused_buft) {
+            return nullptr;
+        }
+
+        ggml_context * ctx = ctx_for_buft(fused_buft);
+
+        std::vector<ggml_tensor**> tensor_req{reqs.size()};
+        for (size_t i = 0; i < reqs.size(); ++i) {
+            const auto & req = reqs.begin()[i];
+            tensor_req[i] = req.out;
+        }
+
+        ggml_tensor * fused = ml.create_contiguous_tensor(ctx, fused_tn.str(), ne, tensor_req, 0);
+
+        return fused;
+    };
+
     layers.resize(n_layer);

     // TODO: move to a separate function
@@ -3297,9 +3391,19 @@ bool llama_model::load_tensors(llama_model_loader & ml) {

             layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

-            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
-            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+            layer.wqkv = create_contiguous(
+                tn(LLM_TENSOR_ATTN_QKV, "weight", i),
+                {n_embd, n_embd_head_k * n_head + n_embd_gqa * 2},
+                {
+                    { tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0, &layer.wq },
+                    { tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa},             0, &layer.wk },
+                    { tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa},             0, &layer.wv },
+                });
+            if (!layer.wqkv) {
+                layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+            }
             layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);

             layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
@@ -3328,9 +3432,19 @@ bool llama_model::load_tensors(llama_model_loader & ml) {

             layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

-            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
-            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+            layer.wqkv = create_contiguous(
+                tn(LLM_TENSOR_ATTN_QKV, "weight", i),
+                {n_embd, n_embd_head_k * n_head + n_embd_gqa * 2},
+                {
+                    { tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0, &layer.wq },
+                    { tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa},             0, &layer.wk },
+                    { tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa},             0, &layer.wv },
+                });
+            if (!layer.wqkv) {
+                layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+            }
             layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);

             layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
@@ -9388,18 +9502,15 @@ struct llm_build_qwen3 : public llm_graph_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
+                ggml_tensor * Qcur = nullptr;
+                ggml_tensor * Kcur = nullptr;
+                ggml_tensor * Vcur = nullptr;

-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                build_qkv(model.layers[il], cur, n_embd_head,
+                        n_embd_head_k, n_embd_head_v, n_head, n_head_kv,
+                        &Qcur, &Kcur, &Vcur, il);

                 Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
                 cb(Qcur, "Qcur_normed", il);
@@ -9509,18 +9620,15 @@ struct llm_build_qwen3moe : public llm_graph_context {
             // self_attention
            {
                 // compute Q and K and RoPE them
-                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
+                ggml_tensor * Qcur = nullptr;
+                ggml_tensor * Kcur = nullptr;
+                ggml_tensor * Vcur = nullptr;

-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                build_qkv(model.layers[il], cur, n_embd_head,
+                        n_embd_head_k, n_embd_head_v, n_head, n_head_kv,
+                        &Qcur, &Kcur, &Vcur, il);

                 Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
                 cb(Qcur, "Qcur_normed", il);
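
As a quick sanity check of the fused tensor's second dimension used in the loader hunks above, n_embd_head_k * n_head + n_embd_gqa * 2, here is the arithmetic with hypothetical Qwen3-like hyperparameters (all values below are assumptions, not taken from the PR):

// Fused Q/K/V output width for assumed hyperparameters.
#include <cstdio>
#include <cstdint>

int main() {
    const int64_t n_embd        = 4096;                                   // assumed
    const int64_t n_embd_head_k = 128;                                    // assumed
    const int64_t n_head        = 32;                                     // assumed
    const int64_t n_head_kv     = 8;                                      // assumed
    const int64_t n_embd_gqa    = n_embd_head_k * n_head_kv;              // 1024

    const int64_t fused_cols = n_embd_head_k * n_head + n_embd_gqa * 2;   // 4096 + 2048 = 6144
    printf("wqkv ne = {%lld, %lld}\n", (long long)n_embd, (long long)fused_cols);
    return 0;
}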
