@@ -6,9 +6,10 @@
 #include "llama-model.h"
 #include "llama-kv-cache.h"
 
+#include <cinttypes>
 #include <cstring>
+#include <limits>
 #include <stdexcept>
-#include <cinttypes>
 
 //
 // llama_context
@@ -632,6 +633,49 @@ bool llama_context::apply_adapter_cvec( |
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
+llm_graph_result_ptr llama_context::process(const llama_ubatch & ubatch, llm_graph_type gtype, ggml_status * ret) {
+    auto * gf = graph_init();
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+        if (ret) {
+            *ret = GGML_STATUS_FAILED;
+        }
+        return nullptr;
+    }
+
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype);
+    if (!res) {
+        LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
+        if (ret) {
+            *ret = GGML_STATUS_FAILED;
+        }
+        return nullptr;
+    }
+
+    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+
+    if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+        if (ret) {
+            *ret = GGML_STATUS_ALLOC_FAILED;
+        }
+        return nullptr;
+    }
+
+    res->set_inputs(&ubatch);
+
+    const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
+        if (ret) {
+            *ret = status;
+        }
+        return nullptr;
+    }
+
+    return res;
+}
+
 int llama_context::encode(llama_batch & inp_batch) {
     if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
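The new `process()` helper consolidates graph init, build, allocation, and compute, and reports failures through the optional `ggml_status` out-parameter; on success the out-parameter is left untouched, so callers should read it only when the returned result is null. Below is a minimal, self-contained sketch of that calling contract; the `sketch_*` types and functions are simplified stand-ins, not the real llama.cpp/ggml definitions.

```cpp
#include <cstdio>
#include <memory>

// Stand-in status enum; the real code uses ggml_status from ggml.
enum sketch_status { STATUS_SUCCESS, STATUS_FAILED, STATUS_ALLOC_FAILED, STATUS_ABORTED };

struct sketch_result { int n_tokens; };
using sketch_result_ptr = std::unique_ptr<sketch_result>;

// Mirrors the contract of llama_context::process(): on failure it returns null
// and reports why through *ret; on success *ret is never written.
static sketch_result_ptr sketch_process(bool fail, sketch_status * ret) {
    if (fail) {
        if (ret) {
            *ret = STATUS_ALLOC_FAILED;
        }
        return nullptr;
    }
    return std::make_unique<sketch_result>(sketch_result{ 42 });
}

// Caller pattern: `status` is deliberately left uninitialized and is consulted
// only inside the `!res` branch, like the updated encode()/decode() paths.
static int run(bool fail) {
    sketch_status status;
    auto res = sketch_process(fail, &status);

    if (!res) {
        switch (status) {
            case STATUS_ABORTED:      return 2;
            case STATUS_ALLOC_FAILED: return -2;
            default:                  return -3;
        }
    }

    return res->n_tokens;
}

int main() {
    std::printf("failing call -> %d\n", run(true));  // prints -2
    std::printf("working call -> %d\n", run(false)); // prints 42
    return 0;
}
```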
@@ -703,26 +747,18 @@ int llama_context::encode(llama_batch & inp_batch) { |
     // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
     cparams.causal_attn = false;
 
-    auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
-
-    ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-    res->set_inputs(&ubatch);
+    ggml_status status;
+    auto res = process(ubatch, LLM_GRAPH_TYPE_ENCODER, &status);
 
     cparams.causal_attn = causal_attn_org;
 
-    const auto compute_status = graph_compute(gf, n_tokens > 1);
-    switch (compute_status) {
-        case GGML_STATUS_SUCCESS:
-            break;
-        case GGML_STATUS_ABORTED:
-            return 2;
-        case GGML_STATUS_ALLOC_FAILED:
-            return -2;
-        case GGML_STATUS_FAILED:
-        default:
-            return -3;
+    if (!res) {
+        switch (status) {
+            case GGML_STATUS_ABORTED:      return 2;
+            case GGML_STATUS_ALLOC_FAILED: return -2;
+            case GGML_STATUS_FAILED:       return -3;
+            case GGML_STATUS_SUCCESS:      GGML_ABORT("should not happen");
+        }
     }
 
     auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
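Both `encode()` and `decode()` now map a failed `process()` call to the same integer return codes: 2 for an aborted run, -2 for an allocation failure, -3 for any other failure, and an abort if `GGML_STATUS_SUCCESS` ever shows up alongside a null result. A hypothetical helper that centralizes this mapping (not part of the commit; the enum is a local stand-in for `ggml_status`) could look like:

```cpp
#include <cassert>

// Local stand-in for ggml_status; enumerator values are omitted on purpose
// since only the mapping matters for this sketch.
enum class sketch_status { success, failed, alloc_failed, aborted };

// Hypothetical helper: converts the status reported by a failed process()
// call into the return codes that encode()/decode() hand back to callers.
static int status_to_ret(sketch_status status) {
    switch (status) {
        case sketch_status::aborted:      return  2; // abort callback fired
        case sketch_status::alloc_failed: return -2; // graph allocation failed
        case sketch_status::failed:       return -3; // build or compute failed
        case sketch_status::success:      break;     // never passed when res == nullptr
    }
    assert(!"status_to_ret() called with a success status");
    return -3;
}
```

The commit instead inlines the switch at both call sites, which keeps the "should not happen" invariant visible right where each status is consumed.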
@@ -942,25 +978,34 @@ int llama_context::decode(llama_batch & inp_batch) { |
         ggml_backend_sched_reset(sched.get());
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-        auto * gf = graph_init();
-        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
+        ggml_status status;
+        auto res = process(ubatch, LLM_GRAPH_TYPE_DECODER, &status);
+
+        if (!res) {
+            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+            llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
+
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                const auto & seq_id = ubatch.seq_id[i][0];
 
-        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+                pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
+            }
 
-        ggml_backend_sched_alloc_graph(sched.get(), gf);
+            for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+                if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
+                    continue;
+                }
 
-        res->set_inputs(&ubatch);
+                LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
+
+                llama_kv_self_seq_rm(this, s, pos_min[s], -1);
+            }
 
-        const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
-        if (compute_status != GGML_STATUS_SUCCESS) {
-            switch (compute_status) {
-                case GGML_STATUS_ABORTED:
-                    return 2;
-                case GGML_STATUS_ALLOC_FAILED:
-                    return -2;
-                case GGML_STATUS_FAILED:
-                default:
-                    return -3;
+            switch (status) {
+                case GGML_STATUS_ABORTED:      return 2;
+                case GGML_STATUS_ALLOC_FAILED: return -2;
+                case GGML_STATUS_FAILED:       return -3;
+                case GGML_STATUS_SUCCESS:      GGML_ABORT("should not happen");
             }
         }
 
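One C++ detail worth noting in the recovery block above: `llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };` sets only the first element to the sentinel, while the remaining elements are value-initialized to 0. A sequence that never appears in the failed ubatch therefore keeps `pos_min[s] == 0`, is not skipped by the sentinel check, and would have its KV cache entries removed from position 0. The self-contained sketch below marks every slot before the scan; the constant and the `llama_pos` alias are stand-ins for the llama.cpp definitions.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <iterator>
#include <limits>

using llama_pos = int32_t;                  // stand-in for the llama.cpp typedef
constexpr int MAX_PARALLEL_SEQUENCES = 64;  // stand-in for LLAMA_MAX_PARALLEL_SEQUENCES

int main() {
    // mark every sequence slot as "not touched by the failed ubatch"
    llama_pos pos_min[MAX_PARALLEL_SEQUENCES];
    std::fill(std::begin(pos_min), std::end(pos_min), std::numeric_limits<llama_pos>::max());

    // pretend the failed ubatch contained tokens of sequence 2 at positions 7 and 5
    const int       n_tokens  = 2;
    const int       seq_ids[] = { 2, 2 };
    const llama_pos poss[]    = { 7, 5 };

    for (int i = 0; i < n_tokens; ++i) {
        pos_min[seq_ids[i]] = std::min(pos_min[seq_ids[i]], poss[i]);
    }

    for (int s = 0; s < MAX_PARALLEL_SEQUENCES; ++s) {
        if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
            continue; // untouched sequence -> leave its cache entries alone
        }
        std::printf("would remove KV cache entries for seq_id = %d, pos = [%d, +inf)\n", s, pos_min[s]);
    }
    return 0;
}
```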