@@ -2533,7 +2533,7 @@ void llama_context_kv_self::kv_self_update() {
 
             auto * gf = graph_init();
 
-            kv_self.build_shift(ctx_compute.get(), gf, this);
+            build_kv_self_shift(ctx_compute.get(), gf);
 
             ggml_backend_sched_alloc_graph(sched.get(), gf);
 
@@ -2559,7 +2559,7 @@ void llama_context_kv_self::kv_self_update() {
 
             auto * gf = graph_init();
 
-            kv_self.build_defrag(ctx_compute.get(), gf, max_nodes(), !cparams.flash_attn);
+            build_kv_self_defrag(ctx_compute.get(), gf);
 
             ggml_backend_sched_alloc_graph(sched.get(), gf);
 
@@ -2817,6 +2817,309 @@ ggml_tensor * llama_context_kv_self::build_attn_soft_max(
     return ggml_soft_max_ext(ctx0, kq, inp_KQ_mask_cnv, kq_scale, hparams.f_max_alibi_bias);
 }
 
+void llama_context_kv_self::build_kv_self_shift(
+        ggml_context * ctx0,
+        ggml_cgraph * gf) {
+    const auto & hparams = model.hparams;
+
+    const auto & n_layer = hparams.n_layer;
+
+    const auto & n_embd_head_k = hparams.n_embd_head_k;
+    //const auto & n_embd_head_v = hparams.n_embd_head_v;
+
+    //GGML_ASSERT(kv_self.size == n_ctx);
+
+    ggml_tensor * inp_k_shift = build_inp_k_shift(ctx0);
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const int64_t n_head_kv    = hparams.n_head_kv(il);
+        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+
+        struct ggml_tensor * rope_factors = build_rope_factors(il);
+
+        struct ggml_tensor * k =
+            ggml_view_3d(ctx0, kv_self.k_l[il],
+                n_embd_head_k, n_head_kv, kv_self.size,
+                ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
+                ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                0);
+
+        ggml_tensor * cur = build_rope_shift(ctx0, k, inp_k_shift, rope_factors, kv_self.k_l[il]->buffer);
+
+        ggml_build_forward_expand(gf, cur);
+    }
+}
+
+void llama_context_kv_self::build_kv_self_defrag(
+        ggml_context * ctx0,
+        ggml_cgraph * gf) {
+    const auto & hparams = model.hparams;
+
+    const uint32_t n_layer = hparams.n_layer;
+
+    const uint32_t n_kv   = kv_self.cell_max();
+    const uint32_t n_used = kv_self.used;
+
+    assert(n_used <= n_kv);
+
+    //const int64_t t_start = ggml_time_us();
+
+    // number of cells moved
+    uint32_t n_moves = 0;
+
+    // each move requires 6*n_layer tensors (see build_kv_self_defrag)
+    //   - source view, destination view, copy operation
+    //   - x2 for keys and values
+    //const uint32_t max_moves = max_nodes()/(6*n_layer);
+    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
+    const uint32_t max_moves = (max_nodes() - 2*n_layer)/(6*n_layer);
+
+    // determine which KV cells to move where
+    //
+    //  cell i moves to ids[i]
+    //
+    //  if ids[i] == i || ids[i] == n_kv, then cell i is not moved
+    //
+    std::vector<uint32_t> ids(n_kv, n_kv);
+
+    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
+        const auto & cell0 = kv_self.cells[i0];
+
+        if (!cell0.is_empty()) {
+            ids[i0] = i0;
+
+            continue;
+        }
+
+        // found a hole - fill it with data from the end of the cache
+
+        uint32_t nh = 1;
+
+        // determine the size of the hole
+        while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) {
+            nh++;
+        }
+
+        uint32_t nf = 0;
+        uint32_t is = n_kv - 1;
+
+        // starting from the end, find nh non-empty cells
+        for (; is > i0; --is) {
+            const auto & cell1 = kv_self.cells[is];
+
+            if (cell1.is_empty() || ids[is] != n_kv) {
+                continue;
+            }
+
+            // non-empty cell which is not yet moved
+            nf++;
+
+            if (nf == nh) {
+                break;
+            }
+        }
+
+        // this can only happen if `n_used` is not accurate, which would be a bug
+        GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
+
+        nf = 0;
+
+        uint32_t i1 = is;
+
+        // are we moving a continuous block of memory?
+        bool cont = false;
+
+        // should we stop searching for the next move?
+        bool stop = false;
+
+        // go back and move the nf cells to the hole
+        for (; i1 < n_kv; ++i1) {
+            auto & cell1 = kv_self.cells[i1];
+
+            if (cell1.is_empty() || ids[i1] != n_kv) {
+                if (n_moves == max_moves) {
+                    stop = true;
+                    break;
+                }
+
+                cont = false;
+                continue;
+            }
+
+            // this cell goes to (i0 + nf)
+            ids[i1] = i0 + nf;
+
+            // move the cell meta data
+            kv_self.cells[i0 + nf] = cell1;
+
+            // clear the old cell and move the head there
+            cell1 = llama_kv_cell();
+            kv_self.head = n_used;
+
+            if (!cont) {
+                n_moves++;
+                cont = true;
+            }
+
+            nf++;
+
+            if (nf == nh) {
+                break;
+            }
+        }
+
+        if (stop || n_moves == max_moves) {
+            break;
+        }
+
+        //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
+
+        i0 += nh - 1;
+    }
+
+    if (n_moves == 0) {
+        return;
+    }
+
+    //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
+
+    //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
+
+#if 0
+    // CPU defrag
+    //
+    // TODO: optimizations are possible:
+    //       - multiple threads
+    //       - avoid copying to the host memory when already there
+    //
+    // likely not worth the effort, as we have ggml_graph based defrag
+    //
+
+    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+    const uint32_t kv_size = size;
+
+    std::vector<uint8_t> buf_k;
+    std::vector<uint8_t> buf_v;
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        const size_t k_size     = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
+
+        const size_t v_size_el = ggml_type_size(v_l[il]->type);
+        const size_t v_size    = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
+
+        buf_k.resize(k_size);
+        buf_v.resize(v_size);
+
+        ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
+        ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
+
+        // batch move [i, i+nm) to [id, id+nm)
+        // note: cells can move only to a lower index
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            const uint32_t id = ids[i];
+
+            if (i == id || id == n_kv) {
+                continue;
+            }
+
+            uint32_t nm = 1;
+
+            while (i + nm < n_kv && ids[i + nm] == id + nm) {
+                nm++;
+            }
+
+            // move keys
+            {
+                const int64_t os =  i*k_size_row;
+                const int64_t od = id*k_size_row;
+
+                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
+            }
+
+            // move values (note: they are transposed)
+            {
+                const int64_t os =  i;
+                const int64_t od = id;
+
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
+                }
+            }
+
+            i += nm - 1;
+        }
+
+        ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
+        ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
+    }
+#else
+    for (uint32_t i = 0; i < ids.size(); ++i) {
+        const uint32_t id = ids[i];
+
+        if (i == id || id == ids.size()) {
+            continue;
+        }
+
+        uint32_t nm = 1;
+
+        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
+            nm++;
+        }
+
+        for (uint32_t il = 0; il < n_layer; ++il) { // NOLINT
+            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+            ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
+                    n_embd_k_gqa, nm,
+                    ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                    ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i));
+
+            ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
+                    n_embd_k_gqa, nm,
+                    ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                    ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
+
+            ggml_tensor * view_v_src;
+            ggml_tensor * view_v_dst;
+
+            if (cparams.flash_attn) {
+                // NOTE: the V cache is not transposed when using flash attention
+                view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
+                        n_embd_v_gqa, nm,
+                        ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
+                        ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
+
+                view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
+                        n_embd_v_gqa, nm,
+                        ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
+                        ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
+            } else {
+                view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
+                        nm, n_embd_v_gqa,
+                        ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
+                        ggml_row_size(kv_self.v_l[il]->type, i));
+
+                view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
+                        nm, n_embd_v_gqa,
+                        ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
+                        ggml_row_size(kv_self.v_l[il]->type, id));
+            }
+
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
+        }
+
+        i += nm - 1;
+    }
+
+    //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
+#endif
+}
+
 ggml_tensor * llama_context_kv_self::build_inp_embd_enc(
         ggml_context * ctx0,
         int32_t n_tokens,
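
The `ids` vector built in `build_kv_self_defrag` is the entire planning step of the graph-based defrag: it only records which cell moves where (filling holes with cells taken from the end of the cache), while the data movement itself is emitted as `ggml_cpy` nodes over K/V views. The standalone sketch below isolates that planning pass on a toy occupancy array. `plan_defrag`, the `occupied` flags and the `main` driver are illustrative stand-ins, not part of this change; the real function additionally moves cell metadata, honors the `max_moves` node budget and updates `kv_self.head`.

// Standalone sketch of the defrag planning pass on a toy cache.
// "occupied[i]" stands in for !kv_self.cells[i].is_empty().
#include <cstdint>
#include <cstdio>
#include <vector>

// cell i moves to ids[i]; ids[i] == i or ids[i] == n_kv means "not moved"
static std::vector<uint32_t> plan_defrag(const std::vector<bool> & occupied, uint32_t n_used) {
    const uint32_t n_kv = (uint32_t) occupied.size();

    std::vector<uint32_t> ids(n_kv, n_kv);

    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
        if (occupied[i0]) {
            ids[i0] = i0; // already in place
            continue;
        }

        // found a hole at i0 - measure its size (bounded by n_used)
        uint32_t nh = 1;
        while (i0 + nh < n_used && !occupied[i0 + nh]) {
            nh++;
        }

        // starting from the end, locate the nh-th not-yet-moved occupied cell
        // (the caller guarantees n_used is accurate, so nh donors always exist)
        uint32_t nf = 0;
        uint32_t is = n_kv - 1;
        for (; is > i0; --is) {
            if (!occupied[is] || ids[is] != n_kv) {
                continue;
            }
            if (++nf == nh) {
                break;
            }
        }

        // walk forward from there, mapping the donors into the hole in order,
        // which preserves their relative order (and hence contiguous blocks)
        nf = 0;
        for (uint32_t i1 = is; i1 < n_kv && nf < nh; ++i1) {
            if (occupied[i1] && ids[i1] == n_kv) {
                ids[i1] = i0 + nf++;
            }
        }

        i0 += nh - 1;
    }

    return ids;
}

int main() {
    // toy layout:  0  1  2  3  4  5  6  7
    // occupied:    X  .  .  X  X  .  X  .
    const std::vector<bool> occupied = { true, false, false, true, true, false, true, false };
    const uint32_t n_used = 4; // number of occupied cells

    const std::vector<uint32_t> ids = plan_defrag(occupied, n_used);

    for (uint32_t i = 0; i < (uint32_t) ids.size(); ++i) {
        if (ids[i] != i && ids[i] != (uint32_t) ids.size()) {
            printf("cell %u -> cell %u\n", i, ids[i]);
        }
    }
    return 0;
}

For this toy layout the plan is "cell 4 -> cell 1" and "cell 6 -> cell 2": the trailing occupied cells are pulled forward into the hole in order, after which all used cells fit in [0, n_used).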