
Commit 08dbb7f

Merge pull request #29 from ikawrakow/main
Merge Q, K, V (ikawrakow#878)
2 parents ef49b83 + 14760aa commit 08dbb7f

File tree

10 files changed: +260 additions, -119 deletions


common/common.cpp

Lines changed: 8 additions & 1 deletion
@@ -1272,6 +1272,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.validate_quants = true;
         return true;
     }
+    if (arg == "-mqkv" || arg == "--merge-qkv") {
+        params.merge_qkv = true;
+        return true;
+    }
     if (arg == "--numa") {
         CHECK_ARG
         std::string value(argv[i]);
@@ -1911,6 +1915,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" });
     options.push_back({ "*", "-no-mmad, --no-fused-mul-multiadd", "disaable fused mul-multi_add (default: %s)", params.fused_mmad? "enabled" : "disabled" });
     options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
+    options.push_back({ "*", "-mqkv, --merge-qkv,", "merge Q,K,V (default: %d)", params.merge_qkv});
     options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
                                                     "in conversation mode, this will be used as system prompt\n"
                                                     "(default: '%s')", params.prompt.c_str() });
@@ -2778,7 +2783,7 @@ void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lor
 
 struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
     auto mparams = llama_model_default_params();
-    mparams.devices = params.devices.c_str();
+    mparams.devices = params.devices.c_str();
 
     if (params.n_gpu_layers != -1) {
         mparams.n_gpu_layers = params.n_gpu_layers;
@@ -2794,6 +2799,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
     mparams.repack_tensors = params.repack_tensors;
     mparams.use_thp = params.use_thp;
     mparams.validate_quants = params.validate_quants;
+    mparams.merge_qkv = params.merge_qkv;
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;
     } else {
@@ -3965,6 +3971,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "repack: %s # default: false\n", params.repack_tensors ? "true" : "false");
     fprintf(stream, "use_thp: %s # default: false\n", params.use_thp ? "true" : "false");
     fprintf(stream, "validate_quants: %s # default: false\n", params.validate_quants ? "true" : "false");
+    fprintf(stream, "merge_qkv: %s # default: false\n", params.merge_qkv ? "true" : "false");
     fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ? "true" : "false");
     fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
     fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);

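The changes above only plumb a new boolean through the common helpers. As a minimal sketch, the programmatic equivalent of passing -mqkv / --merge-qkv looks like the following; only gpt_params::merge_qkv and llama_model_params_from_gpt_params() come from this diff, the wrapper function is illustrative:

#include "common.h"
#include "llama.h"

// Sketch: build llama_model_params with Q/K/V merging enabled, as -mqkv would.
static llama_model_params params_with_merged_qkv() {
    gpt_params params;            // merge_qkv defaults to false (see common/common.h below)
    params.merge_qkv = true;      // same effect as -mqkv / --merge-qkv on the command line
    // llama_model_params_from_gpt_params() copies this into mparams.merge_qkv
    return llama_model_params_from_gpt_params(params);
}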
common/common.h

Lines changed: 1 addition & 0 deletions
@@ -269,6 +269,7 @@ struct gpt_params {
     bool use_thp = false; // use transparent huge pages (linux only)
     bool validate_quants = false; // if true, check for NaNs while loading the model
     bool only_active_exps = true; // if true, offload only active experts (relevant only for hybrid CPU/GPU)
+    bool merge_qkv = false; // if true, merge separate Q, K, V tensors into a single, contiguous tensor
 
     std::string cache_type_k = "f16"; // KV cache data type for the K
     std::string cache_type_v = "f16"; // KV cache data type for the V

include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -382,6 +382,7 @@ extern "C" {
         bool repack_tensors; // repack if available
         bool use_thp;        // use transparent huge pages (linux only)
         bool validate_quants;// if true, check for NaNs while loading the model
+        bool merge_qkv;      // if true, merge separate Q, K, V tensors into a single, contiguous tensor
     };
 
     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations

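Since merge_qkv is now part of llama_model_params, it can also be set directly through the C API. A hedged sketch follows; only the merge_qkv field is taken from this diff, the surrounding calls are the usual llama.h entry points and are assumed here:

#include "llama.h"

int main(int argc, char ** argv) {
    llama_backend_init();
    llama_model_params mparams = llama_model_default_params();   // merge_qkv defaults to false
    mparams.merge_qkv = true;   // merge separate Q, K, V tensors into one contiguous tensor at load time
    llama_model * model = llama_load_model_from_file(argc > 1 ? argv[1] : "model.gguf", mparams);
    if (model == NULL) {
        return 1;
    }
    // ... create a context and run inference as usual ...
    llama_free_model(model);
    llama_backend_free();
    return 0;
}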
src/llama-build-context.cpp

Lines changed: 126 additions & 82 deletions
Large diffs are not rendered by default.

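The graph-building changes are not rendered above, so the following is only a hedged sketch of the general idea behind a fused QKV projection in ggml terms, not the commit's actual llm_build_mul_mat_qkv implementation: one matrix multiplication against the merged wqkv tensor, followed by views that split the result back into Q, K and V. All names and shapes below are assumptions for illustration.

#include "ggml.h"

// Assumed shapes: cur is [n_embd, n_tokens]; wqkv is [n_embd, n_embd_head_k*(n_head + 2*n_head_kv)].
static void split_fused_qkv(struct ggml_context * ctx0, struct ggml_tensor * cur,
                            struct ggml_tensor * wqkv, struct ggml_tensor * bqkv,
                            int64_t n_embd_head_k, int64_t n_head, int64_t n_head_kv,
                            struct ggml_tensor ** Qcur, struct ggml_tensor ** Kcur, struct ggml_tensor ** Vcur) {
    struct ggml_tensor * qkv = ggml_mul_mat(ctx0, wqkv, cur);   // one GEMM instead of three
    if (bqkv) {
        qkv = ggml_add(ctx0, qkv, bqkv);                        // fused bias, when the model has Q/K/V biases
    }
    const int64_t n_tokens = cur->ne[1];
    const int64_t n_q      = n_embd_head_k * n_head;
    const int64_t n_kv     = n_embd_head_k * n_head_kv;
    const size_t  es       = ggml_element_size(qkv);
    // Q, K, V recovered as views into the fused result (offsets are in bytes)
    *Qcur = ggml_view_2d(ctx0, qkv, n_q,  n_tokens, qkv->nb[1], 0);
    *Kcur = ggml_view_2d(ctx0, qkv, n_kv, n_tokens, qkv->nb[1],  n_q         * es);
    *Vcur = ggml_view_2d(ctx0, qkv, n_kv, n_tokens, qkv->nb[1], (n_q + n_kv) * es);
}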
src/llama-build-context.h

Lines changed: 7 additions & 0 deletions
@@ -149,6 +149,13 @@ struct llm_build_context {
                   ggml_tensor * wv, ggml_tensor * bv,
                   float attention_scale, int il);
 
+    std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_mul_mat_qkv(ggml_cgraph * gf, ggml_tensor * cur,
+                  ggml_tensor * wqkv, ggml_tensor * bqkv,
+                  ggml_tensor * wq, ggml_tensor * bq,
+                  ggml_tensor * wk, ggml_tensor * bk,
+                  ggml_tensor * wv, ggml_tensor * bv,
+                  ggml_tensor * q_norm, ggml_tensor * k_norm, float attention_scale, int il);
+
     ggml_cgraph * build_llama();
 
     ggml_cgraph * build_deci();

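Because the helper declared above returns all three projections as a std::tuple, a build_*() graph function can call it the same way whether the layer was loaded with a merged wqkv or with separate Q, K, V weights. A hypothetical call site (the signature comes from this diff; the argument values and surrounding context are illustrative):

// Inside a graph builder such as build_llama(), for layer il; q_norm/k_norm are
// passed as nullptr for architectures without per-head QK normalization.
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
        model.layers[il].wqkv, model.layers[il].bqkv,
        model.layers[il].wq,   model.layers[il].bq,
        model.layers[il].wk,   model.layers[il].bk,
        model.layers[il].wv,   model.layers[il].bv,
        nullptr, nullptr, 0.0f, il);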
src/llama-load-tensors.cpp

Lines changed: 105 additions & 29 deletions
@@ -28,6 +28,8 @@ struct create_tensors_helper : public create_tensors_helper_interface {
 
     virtual size_t get_ctx_size() const override { return ctx_size; }
 
+    bool merge_qkv(const LLM_TN & tn, int i, int bias);
+
     bool create_tensors() override;
 
     bool create_llama_tensors(const LLM_TN & tn);
@@ -284,15 +286,11 @@ bool create_tensors_helper::create_llama_tensors(const LLM_TN & tn) {
 
         layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
 
-        layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head});
-        layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
-        layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
+        use_mmap_buffer &= !merge_qkv(tn, i, 1);
+
         layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
 
         // optional bias tensors
-        layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-        layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-        layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
         layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
         layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
@@ -418,9 +416,8 @@ bool create_tensors_helper::create_llama4_tensors(const LLM_TN & tn) {
 
         layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-        layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-        layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
-        layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+        use_mmap_buffer &= !merge_qkv(tn, i, 0);
+
         layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
 
         layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
@@ -1018,9 +1015,8 @@ bool create_tensors_helper::create_qwen3_tensors(const LLM_TN & tn) {
 
         layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
 
-        layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head});
-        layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
-        layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+        use_mmap_buffer &= !merge_qkv(tn, i, 0);
+
         layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
 
         layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k});
@@ -1044,9 +1040,8 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) {
 
         layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
 
-        layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head});
-        layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
-        layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+        use_mmap_buffer &= !merge_qkv(tn, i, 0);
+
         layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
 
         layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k});
@@ -1700,12 +1695,16 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) {
         layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags);
 
         // GLM-style attention with bias terms
-        layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags);
-        layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags);
-        layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags);
-        layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, flags);
-        layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, flags);
-        layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, flags);
+        if (!flags) {
+            use_mmap_buffer &= !merge_qkv(tn, i, 2);
+        } else {
+            layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags);
+            layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags);
+            layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags);
+            layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, flags);
+            layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, flags);
+            layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, flags);
+        }
 
         layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags);
 
@@ -2380,10 +2379,10 @@ bool create_tensors_helper::create_openai_moe_tensors(const LLM_TN & tn) {
         layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
         layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
 
-        layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0);
-        layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
-        layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
+        use_mmap_buffer &= !merge_qkv(tn, i, 2);
+
         layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0);
+        layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
 
         layer.attn_sinks = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0);
 
@@ -2394,11 +2393,6 @@ bool create_tensors_helper::create_openai_moe_tensors(const LLM_TN & tn) {
         layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0, &ctx_ffn_up);
 
         // bias
-        layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_head * n_rot}, 0);
-        layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_head_kv * n_rot}, 0);
-        layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_head_kv * n_rot}, 0);
-        layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
-
         ggml_context *ctx_ffn_gate_b, *ctx_ffn_up_b, *ctx_ffn_down_b;
         layer.ffn_gate_inp_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0);
         layer.ffn_gate_exps_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0, &ctx_ffn_gate_b);
@@ -2421,6 +2415,88 @@ bool create_tensors_helper::create_openai_moe_tensors(const LLM_TN & tn) {
     return use_mmap_buffer;
 }
 
+bool create_tensors_helper::merge_qkv(const LLM_TN & tn, int i, int bias) {
+    auto& hparams = model.hparams;
+    const int64_t n_head = hparams.n_head();
+    const int64_t n_head_kv = hparams.n_head_kv();
+    const int64_t n_embd = hparams.n_embd;
+    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+    const int64_t n_embd_head_k = hparams.n_embd_head_k;
+    const int64_t n_embd_gqa = n_embd_v_gqa;
+
+    ggml_context * ctx_layer = ctx_for_layer(i);
+    ggml_context * ctx_split = ctx_for_layer_split(i);
+
+    auto & layer = model.layers[i];
+
+    auto wq_name = tn(LLM_TENSOR_ATTN_Q, "weight", i);
+    auto wk_name = tn(LLM_TENSOR_ATTN_K, "weight", i);
+    auto wv_name = tn(LLM_TENSOR_ATTN_V, "weight", i);
+    auto wq = ml.require_tensor_meta(wq_name.c_str());
+    auto wk = ml.require_tensor_meta(wk_name.c_str());
+    auto wv = ml.require_tensor_meta(wv_name.c_str());
+    GGML_ASSERT(wq && wk && wv);
+
+    bool fused_qkv = false;
+    if (ml.merge_qkv && wq->type == wk->type && wq->type == wv->type && hparams.f_attention_scale == 0.0f) {
+        GGML_ASSERT(wq->ne[0] == n_embd && wq->ne[1] == n_head * n_embd_head_k);
+        GGML_ASSERT(wk->ne[0] == n_embd && wk->ne[1] == n_embd_gqa);
+        GGML_ASSERT(wv->ne[0] == n_embd && wv->ne[1] == n_embd_gqa);
+        layer.wqkv = ggml_new_tensor_2d(ctx_split, wq->type, n_embd, n_embd_head_k * (n_head + n_head_kv + n_head_kv));
+        snprintf(layer.wqkv->name, GGML_MAX_NAME, "blk.%d.attn_qkv.weight", i);
+        // This does not work. If we are doing this merge manually, it basically means that the arch does not have
+        // an LLM_TENSOR_ATTN_QKV entry, so we will get __missing__ as the tensor name.
+        //ggml_set_name(layer.wqkv, tn(LLM_TENSOR_ATTN_QKV, "weight", i).c_str());
+        layer.wq = ml.create_tensor_as_view(ctx_split, layer.wqkv, wq_name.c_str(), { wq->ne[0], wq->ne[1] }, 0);
+        layer.wk = ml.create_tensor_as_view(ctx_split, layer.wqkv, wk_name.c_str(), { wk->ne[0], wk->ne[1] }, wq->ne[1]*wq->nb[1]);
+        layer.wv = ml.create_tensor_as_view(ctx_split, layer.wqkv, wv_name.c_str(), { wv->ne[0], wv->ne[1] }, wq->ne[1]*wq->nb[1] + wk->ne[1]*wk->nb[1] );
+        fused_qkv = true;
+        printf("================================== Created merged qkv %s\n", layer.wqkv->name);
+        if (bias) {
+            auto bq_name = tn(LLM_TENSOR_ATTN_Q, "bias", i);
+            auto bk_name = tn(LLM_TENSOR_ATTN_K, "bias", i);
+            auto bv_name = tn(LLM_TENSOR_ATTN_V, "bias", i);
+            auto bq = ml.get_tensor_meta(bq_name.c_str());
+            auto bk = ml.get_tensor_meta(bk_name.c_str());
+            auto bv = ml.get_tensor_meta(bv_name.c_str());
+            if (bias == 2) {
+                GGML_ASSERT(bq && bk && bv);
+            } else {
+                GGML_ASSERT(!bq && !bk && !bv);
+            }
+            if (bq && bk && bv) {
+                GGML_ASSERT(bq->type == GGML_TYPE_F32 && bk->type == GGML_TYPE_F32 && bv->type == GGML_TYPE_F32);
+                GGML_ASSERT(ggml_nrows(bq) == 1 && bq->ne[0] == wq->ne[1]);
+                GGML_ASSERT(ggml_nrows(bk) == 1 && bk->ne[0] == wk->ne[1]);
+                GGML_ASSERT(ggml_nrows(bv) == 1 && bv->ne[0] == wv->ne[1]);
+                layer.bqkv = ggml_new_tensor_1d(ctx_layer, bq->type, n_embd_head_k * (n_head + n_head_kv + n_head_kv));
+                snprintf(layer.bqkv->name, GGML_MAX_NAME, "blk.%d.attn_qkv.bias", i);
+                layer.bq = ml.create_tensor_as_view(ctx_layer, layer.bqkv, bq_name.c_str(), { bq->ne[0] }, 0);
+                layer.bk = ml.create_tensor_as_view(ctx_layer, layer.bqkv, bk_name.c_str(), { bk->ne[0] }, bq->ne[0]*bq->nb[0]);
+                layer.bv = ml.create_tensor_as_view(ctx_layer, layer.bqkv, bv_name.c_str(), { bv->ne[0] }, bq->ne[0]*bq->nb[0] + bk->ne[0]*bk->nb[0] );
+            }
+        }
+    }
+
+    if (!fused_qkv) {
+        if (ml.merge_qkv) {
+            printf("%s: did not merge Q, K, V in layer %d because %d, %d, %d\n", __func__, i,
+                    wq->type == wk->type, wq->type == wv->type, hparams.f_attention_scale == 0.0f);
+        }
+        layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head});
+        layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
+        layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+        if (bias) {
+            auto flags = bias == 1 ? llama_model_loader::TENSOR_NOT_REQUIRED : 0;
+            layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {layer.wq->ne[1]}, flags);
+            layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {layer.wk->ne[1]}, flags);
+            layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {layer.wv->ne[1]}, flags);
+        }
+    }
+
+    return fused_qkv;
+}
+
 bool create_tensors_helper::create_tensors() {
     const auto tn = LLM_TN(model.arch);
     bool use_mmap_buffer = true;

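As a concrete illustration of the layout that the new merge_qkv() helper builds, here is the arithmetic for a hypothetical GQA model (the dimensions are illustrative and not taken from this diff):

// Assumed: n_embd = 4096, n_head = 32, n_head_kv = 8, n_embd_head_k = 128.
// Merged weight: ggml_new_tensor_2d(..., n_embd, n_embd_head_k * (n_head + 2*n_head_kv))
//              -> shape {4096, 128 * (32 + 8 + 8)} = {4096, 6144}.
//
// Views created into it (offsets in bytes, nb1 = byte size of one 4096-wide row):
//   attn_q.weight: {4096, 4096}, offset 0
//   attn_k.weight: {4096, 1024}, offset 4096 * nb1           (= wq->ne[1]*wq->nb[1])
//   attn_v.weight: {4096, 1024}, offset (4096 + 1024) * nb1  (= wq->ne[1]*wq->nb[1] + wk->ne[1]*wk->nb[1])
//
// The three views are disjoint, contiguous slices of the same buffer, so the loader can still
// populate Q, K and V by their original GGUF names while the graph can treat them as one fused tensor.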
src/llama-model-loader.cpp

Lines changed: 5 additions & 3 deletions
@@ -203,9 +203,10 @@ namespace GGUFMeta {
     };
 }
 
-llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp,
-                                       const llama_model_kv_override * param_overrides_p,
-                                       const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
+llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors,
+                                       bool repack_tensors, bool use_thp, bool merge_qkv,
+                                       const llama_model_kv_override * param_overrides_p,
+                                       const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
     int trace = 0;
     if (getenv("LLAMA_TRACE")) {
         trace = atoi(getenv("LLAMA_TRACE"));
@@ -495,6 +496,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
     this->check_tensors = check_tensors;
     this->repack_tensors = repack_tensors;
     this->use_thp = use_thp;
+    this->merge_qkv = merge_qkv;
 }
 
 llama_model_loader::~llama_model_loader() {

src/llama-model-loader.h

Lines changed: 2 additions & 1 deletion
@@ -44,6 +44,7 @@ struct llama_model_loader {
     bool check_tensors;
     bool repack_tensors = false;
     bool use_thp = false;
+    bool merge_qkv = false;
 
     llama_files files;
     llama_ftype ftype;
@@ -78,7 +79,7 @@ struct llama_model_loader {
     std::string arch_name;
     LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
 
-    llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp,
+    llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp, bool merge_qkv,
                        const llama_model_kv_override * param_overrides_p,
                        const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);

src/llama-quantize.cpp

Lines changed: 2 additions & 1 deletion
@@ -1007,7 +1007,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
         kv_overrides = v->data();
     }
-    llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, /* repack_tensors */ false, /* use_thp */ false, kv_overrides, nullptr);
+    llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, /* repack_tensors */ false,
+                          /* use_thp */ false, /* merge_qkv */ false, kv_overrides, nullptr);
     ml.init_mappings(false); // no prefetching
 
     llama_model model;
