 #include "llama-model.h"
 
+#include "gguf.h"
 #include "llama-impl.h"
 #include "llama-mmap.h"
 #include "llama-batch.h"
@@ -2428,6 +2429,99 @@ bool llama_model::load_tensors(llama_model_loader & ml) { |
         return ml.create_tensor(ctx, tn, ne, flags);
     };
 
+    struct tensor_def {
+        LLM_TN_IMPL tn;
+        std::vector<int64_t> ne;
+        int flags;
+        ggml_tensor ** out;
+    };
+
+    auto create_contiguous = [&](const LLM_TN_IMPL & fused_tn,
+                                 std::initializer_list<int64_t> ne,
+                                 std::initializer_list<tensor_def> reqs) -> ggml_tensor * {
+        ggml_backend_buffer_type_t fused_buft = nullptr;
+
+        for (size_t i = 0; i < reqs.size(); ++i) {
+            const tensor_def & req = reqs.begin()[i];
+            const bool required = (req.flags & llama_model_loader::TENSOR_NOT_REQUIRED) == 0;
+            const ggml_tensor * tensor_meta = ml.check_tensor_dims(req.tn.str(), req.ne, required);
+
+            *req.out = const_cast<ggml_tensor*>(tensor_meta);
+
+            if (!*req.out) {
+                return nullptr;
+            }
+
+            llm_tensor tn_tensor = req.tn.tensor;
+            if (tn_tensor == LLM_TENSOR_TOKEN_EMBD && (req.flags & llama_model_loader::TENSOR_DUPLICATED)) {
+                tn_tensor = LLM_TENSOR_OUTPUT;
+            }
+
+            llm_tensor_info info;
+            try {
+                info = llm_tensor_info_for(tn_tensor);
+            } catch (const std::out_of_range &) {
+                throw std::runtime_error(format("missing tensor info mapping for %s", req.tn.str().c_str()));
+            }
+
+            bool bias = req.tn.suffix != nullptr && strcmp(req.tn.suffix, "bias") == 0;
+            ggml_op op = bias ? (info.op == GGML_OP_MUL_MAT_ID ? GGML_OP_ADD_ID : GGML_OP_ADD) : info.op;
+
+            buft_list_t * buft_list = nullptr;
+            switch (info.layer) {
+                case LLM_TENSOR_LAYER_INPUT:
+                    buft_list = pimpl->dev_input.buft_list;
+                    break;
+                case LLM_TENSOR_LAYER_OUTPUT:
+                    buft_list = pimpl->dev_output.buft_list;
+                    break;
+                case LLM_TENSOR_LAYER_REPEATING:
+                    buft_list = pimpl->dev_layer.at(req.tn.bid).buft_list;
+                    break;
+                default:
+                    GGML_ABORT("invalid layer %d for tensor %s", info.layer, req.tn.str().c_str());
+            }
+
+            ggml_backend_buffer_type_t buft = select_weight_buft(hparams, *req.out, op, *buft_list);
+            if (!buft) {
+                return nullptr;
+            }
+
+            auto * buft_dev = ggml_backend_buft_get_device(buft);
+            if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
+                auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+                if (!cpu_dev) {
+                    throw std::runtime_error("no CPU backend found");
+                }
+                buft = ggml_backend_dev_buffer_type(cpu_dev);
+            }
+
+            //TODO: check buft overrides
+
+            if (!fused_buft) {
+                fused_buft = buft;
+            } else if (fused_buft != buft) {
+                return nullptr;
+            }
+        }
+
+        if (!fused_buft) {
+            return nullptr;
+        }
+
+        ggml_context * ctx = ctx_for_buft(fused_buft);
+
+        std::vector<ggml_tensor**> tensor_req(reqs.size());
+        for (size_t i = 0; i < reqs.size(); ++i) {
+            const auto & req = reqs.begin()[i];
+            tensor_req[i] = req.out;
+        }
+
+        ggml_tensor * fused = ml.create_contiguous_tensor(ctx, fused_tn.str(), ne, tensor_req, 0);
+
+        return fused;
+    };
+
     layers.resize(n_layer);
 
     // TODO: move to a separate function
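
`create_contiguous` only resolves the per-tensor metadata and agrees on a single buffer type; the actual fusing is delegated to `ml.create_contiguous_tensor()`, whose implementation is not part of this diff. As a rough mental model (a sketch under that assumption, not the loader's real code; `fused_qkv_views`, `n_q`, `n_k` and `n_v` are made-up names), the sub-tensors end up as row-block views into one contiguous allocation:

```cpp
#include "ggml.h"

// Sketch only: the kind of aliasing a fused QKV weight implies. Q, K and V are
// row blocks of one contiguous 2D tensor, offset by whole rows (nb[1] is the
// row stride in bytes), so together they cover exactly the fused allocation.
static void fused_qkv_views(ggml_context * ctx, int64_t n_embd, int64_t n_q, int64_t n_k, int64_t n_v) {
    ggml_tensor * fused = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_q + n_k + n_v);

    ggml_tensor * wq = ggml_view_2d(ctx, fused, n_embd, n_q, fused->nb[1], 0);
    ggml_tensor * wk = ggml_view_2d(ctx, fused, n_embd, n_k, fused->nb[1],  n_q        * fused->nb[1]);
    ggml_tensor * wv = ggml_view_2d(ctx, fused, n_embd, n_v, fused->nb[1], (n_q + n_k) * fused->nb[1]);

    GGML_ASSERT(ggml_nbytes(wq) + ggml_nbytes(wk) + ggml_nbytes(wv) == ggml_nbytes(fused));
}
```

If any sub-tensor is missing, has unexpected dimensions, or would be placed on a different buffer type, the helper returns `nullptr` and the call sites below fall back to creating the three tensors separately.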
@@ -3297,9 +3391,19 @@ bool llama_model::load_tensors(llama_model_loader & ml) { |
 
                     layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
-                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                    layer.wqkv = create_contiguous(
+                        tn(LLM_TENSOR_ATTN_QKV, "weight", i),
+                        {n_embd, n_embd_head_k * n_head + n_embd_gqa * 2},
+                        {
+                            { tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0, &layer.wq },
+                            { tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0, &layer.wk },
+                            { tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0, &layer.wv },
+                        });
+                    if (!layer.wqkv) {
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                    }
                     layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
 
                     layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
@@ -3328,9 +3432,19 @@ bool llama_model::load_tensors(llama_model_loader & ml) { |
 
                     layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
-                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                    layer.wqkv = create_contiguous(
+                        tn(LLM_TENSOR_ATTN_QKV, "weight", i),
+                        {n_embd, n_embd_head_k * n_head + n_embd_gqa * 2},
+                        {
+                            { tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0, &layer.wq },
+                            { tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0, &layer.wk },
+                            { tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0, &layer.wv },
+                        });
+                    if (!layer.wqkv) {
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                    }
                     layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
 
                     layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
@@ -9388,18 +9502,15 @@ struct llm_build_qwen3 : public llm_graph_context { |
             // self-attention
             {
                 // compute Q and K and RoPE them
-                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
 
-                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
+                ggml_tensor * Qcur = nullptr;
+                ggml_tensor * Kcur = nullptr;
+                ggml_tensor * Vcur = nullptr;
 
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                build_qkv(model.layers[il], cur, n_embd_head,
+                        n_embd_head_k, n_embd_head_v, n_head, n_head_kv,
+                        &Qcur, &Kcur, &Vcur, il
+                );
 
                 Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
                 cb(Qcur, "Qcur_normed", il);
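
`build_qkv()` is introduced elsewhere in this change and its definition does not appear in this diff. The following is a hypothetical sketch of what such a helper can look like, written as if it were a member of `llm_graph_context` (so `ctx0`, `n_tokens`, `build_lora_mm()` and `cb()` resolve as in the other graph builders) and following the fused-QKV view pattern llama.cpp already uses for models that ship a single `wqkv` weight; it is not the PR's actual implementation:

```cpp
// Hypothetical sketch, not the code added by this change: one fused matmul when
// the loader managed to build wqkv, otherwise the previous three projections.
// The builders assert n_embd_head == n_embd_head_k == n_embd_head_v, so the
// reshapes below are consistent with the row-block slicing.
void build_qkv(const llama_layer & layer, ggml_tensor * cur, int64_t n_embd_head,
               int64_t n_embd_head_k, int64_t n_embd_head_v, int64_t n_head, int64_t n_head_kv,
               ggml_tensor ** Qcur, ggml_tensor ** Kcur, ggml_tensor ** Vcur, int il) {
    if (layer.wqkv) {
        ggml_tensor * qkv = build_lora_mm(layer.wqkv, cur); // [n_q + n_k + n_v, n_tokens]
        cb(qkv, "qkv", il);

        const int64_t n_q = n_embd_head_k * n_head;
        const int64_t n_k = n_embd_head_k * n_head_kv;
        const int64_t n_v = n_embd_head_v * n_head_kv;

        // slice the fused result back into Q, K and V (offsets are in bytes)
        *Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_q, n_tokens, qkv->nb[1], 0));
        *Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_k, n_tokens, qkv->nb[1],  n_q        * ggml_element_size(qkv)));
        *Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_v, n_tokens, qkv->nb[1], (n_q + n_k) * ggml_element_size(qkv)));
    } else {
        // fallback: the separate projections that were used before this change
        *Qcur = build_lora_mm(layer.wq, cur);
        *Kcur = build_lora_mm(layer.wk, cur);
        *Vcur = build_lora_mm(layer.wv, cur);
    }

    *Qcur = ggml_reshape_3d(ctx0, *Qcur, n_embd_head,   n_head,    n_tokens);
    *Kcur = ggml_reshape_3d(ctx0, *Kcur, n_embd_head_k, n_head_kv, n_tokens);
    *Vcur = ggml_reshape_3d(ctx0, *Vcur, n_embd_head_v, n_head_kv, n_tokens);
}
```

The `llm_build_qwen3moe` hunk below makes the same call; only the surrounding MoE feed-forward differs.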
@@ -9509,18 +9620,15 @@ struct llm_build_qwen3moe : public llm_graph_context { |
             // self_attention
             {
                 // compute Q and K and RoPE them
-                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
 
-                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
+                ggml_tensor * Qcur = nullptr;
+                ggml_tensor * Kcur = nullptr;
+                ggml_tensor * Vcur = nullptr;
 
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                build_qkv(model.layers[il], cur, n_embd_head,
+                        n_embd_head_k, n_embd_head_v, n_head, n_head_kv,
+                        &Qcur, &Kcur, &Vcur, il
+                );
 
                 Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
                 cb(Qcur, "Qcur_normed", il);