@@ -2870,3 +2870,240 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
28702870
28712871 return true ;
28722872}
2873+
2874+ //
2875+ // llama_kv_cache_hybrid
2876+ //
llama_kv_cache_hybrid::llama_kv_cache_hybrid(
    const llama_hparams & hparams,
    std::vector<child_cache> children) :
    m_hparams(hparams),
    m_children(
        // immediately-invoked lambda: take ownership of all child caches.
        // NOTE(review): std::set<std::unique_ptr<...>> orders its elements by
        // pointer value, so the pre-insertion sort below does NOT determine
        // the iteration order of m_children — confirm that no caller
        // (state_write / state_read) relies on layer order
        [](std::vector<child_cache>& caches) -> std::set<std::unique_ptr<llama_kv_cache>> {
            // Sort the caches by the lowest layer ID so the order is repeatable
            for (auto & cache : caches) {
                GGML_ASSERT(cache.layer_ids.size() > 0);
                std::sort(cache.layer_ids.begin(), cache.layer_ids.end());
            }
            std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) {
                return a.layer_ids[0] < b.layer_ids[0];
            });
            std::set<std::unique_ptr<llama_kv_cache>> unique_caches;
            for (auto & cache : caches) {
                // transfer ownership out of the child_cache; its layer_ids
                // stay behind in `children` and are reused for validation below
                unique_caches.emplace(cache.child.release());
            }
            return unique_caches;
        }(children)
    ),
    m_has_recurrent(
        // true when any child is a recurrent cache; used to pick the batch
        // splitting mode in sbatch_init / ubatch_next
        [](const std::set<std::unique_ptr<llama_kv_cache>> & caches) -> bool {
            for (const auto & cache : caches) {
                if (dynamic_cast<llama_kv_cache_recurrent *>(cache.get())) {
                    return true;
                }
            }
            return false;
        }(m_children)
    )
{
    // Ensure at least one child
    GGML_ASSERT(m_children.size() > 0);

    // Ensure layers are not overlapping and cover 0..max_layer contiguously
    std::set<size_t> seen_layers;
    size_t max_layer = 0;
    for (const auto & cache : children) {
        for (const auto & layer_id : cache.layer_ids) {
            // each layer may be assigned to exactly one child cache
            GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end());
            seen_layers.insert(layer_id);
            if (layer_id > max_layer) {
                max_layer = layer_id;
            }
        }
    }
    LLAMA_LOG_DEBUG("max_layer=%zu, seen_layers.size()=%zu\n", max_layer, seen_layers.size());
    GGML_ASSERT(max_layer + 1 == seen_layers.size());
}
2927+
2928+ void llama_kv_cache_hybrid::clear () {
2929+ for (const auto & cache : m_children) {
2930+ cache->clear ();
2931+ }
2932+ }
2933+
2934+ bool llama_kv_cache_hybrid::seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
2935+ // First check if we can do this removal. This checks all children so that
2936+ // no mutation happens before we know if it's possible
2937+ if (!can_seq_rm (seq_id, p0, p1)) {
2938+ return false ;
2939+ }
2940+
2941+ // Do the removal from each child which should never fail
2942+ for (const auto & cache : m_children) {
2943+ const bool failed = cache->seq_rm (seq_id, p0, p1);
2944+ GGML_ASSERT (!failed);
2945+ }
2946+ return true ;
2947+ }
2948+
2949+ void llama_kv_cache_hybrid::seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
2950+ for (const auto & cache : m_children) {
2951+ cache->seq_cp (seq_id_src, seq_id_dst, p0, p1);
2952+ }
2953+ }
2954+
2955+ void llama_kv_cache_hybrid::seq_keep (llama_seq_id seq_id) {
2956+ for (const auto & cache : m_children) {
2957+ cache->seq_keep (seq_id);
2958+ }
2959+ }
2960+
2961+ void llama_kv_cache_hybrid::seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
2962+ for (const auto & cache : m_children) {
2963+ cache->seq_add (seq_id, p0, p1, delta);
2964+ }
2965+ }
2966+
2967+ void llama_kv_cache_hybrid::seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
2968+ for (const auto & cache : m_children) {
2969+ cache->seq_div (seq_id, p0, p1, d);
2970+ }
2971+ }
2972+
2973+ llama_pos llama_kv_cache_hybrid::seq_pos_min (llama_seq_id seq_id) const {
2974+ llama_pos min_pos = -1 ;
2975+ for (const auto & cache : m_children) {
2976+ const auto child_min_pos = cache->seq_pos_min (seq_id);
2977+ min_pos = min_pos == -1 ? child_min_pos : std::min (min_pos, child_min_pos);
2978+ }
2979+ return min_pos;
2980+ }
2981+
2982+ llama_pos llama_kv_cache_hybrid::seq_pos_max (llama_seq_id seq_id) const {
2983+ llama_pos max_pos = 0 ;
2984+ for (const auto & cache : m_children) {
2985+ max_pos = std::max (max_pos, cache->seq_pos_max (seq_id));
2986+ }
2987+ return max_pos;
2988+ }
2989+
2990+ void llama_kv_cache_hybrid::restore () {
2991+ for (const auto & cache : m_children) {
2992+ cache->restore ();
2993+ }
2994+ }
2995+
2996+ void llama_kv_cache_hybrid::commit () {
2997+ for (const auto & cache : m_children) {
2998+ cache->commit ();
2999+ }
3000+ }
3001+
3002+ bool llama_kv_cache_hybrid::update (llama_context & ctx) {
3003+ bool updated = false ;
3004+ for (const auto & cache : m_children) {
3005+ updated = cache->update (ctx) || updated;
3006+ }
3007+ return updated;
3008+ }
3009+
3010+ void llama_kv_cache_hybrid::defrag_sched (float thold) {
3011+ for (const auto & cache : m_children) {
3012+ cache->defrag_sched (thold);
3013+ }
3014+ }
3015+
3016+ void llama_kv_cache_hybrid::set_full () {
3017+ for (const auto & cache : m_children) {
3018+ cache->set_full ();
3019+ }
3020+ }
3021+
3022+ bool llama_kv_cache_hybrid::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
3023+ for (const auto & cache : m_children) {
3024+ if (!cache->can_seq_rm (seq_id, p0, p1)) {
3025+ return false ;
3026+ }
3027+ }
3028+ return true ;
3029+ }
3030+
3031+ llama_sbatch llama_kv_cache_hybrid::sbatch_init (const llama_batch & batch, bool logits_all) {
3032+ // If any of the caches are recurrent, require equal split
3033+ return llama_sbatch (batch, m_hparams.n_embd , !m_has_recurrent, logits_all);
3034+ }
3035+
3036+ llama_ubatch llama_kv_cache_hybrid::ubatch_next (llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
3037+ if (embd_pooled) {
3038+ // Pooled embeddings cannot be split across ubatches (yet)
3039+ return sbatch.split_seq (n_ubatch);
3040+ }
3041+ if (m_has_recurrent) {
3042+ return sbatch.split_equal (n_ubatch);
3043+ }
3044+ return sbatch.split_simple (n_ubatch);
3045+ }
3046+
3047+ bool llama_kv_cache_hybrid::find_slot (const llama_ubatch & batch) {
3048+ bool found = true ;
3049+ for (const auto & cache : m_children) {
3050+ found = cache->find_slot (batch) && found;
3051+ }
3052+ return found;
3053+ }
3054+
3055+ int32_t llama_kv_cache_hybrid::get_n_tokens () const {
3056+ // The number of tokens should be the same across all child caches
3057+ int32_t n_tokens = -1 ;
3058+ for (const auto & cache : m_children) {
3059+ const auto cache_n_tokens = cache->get_n_tokens ();
3060+ GGML_ASSERT (n_tokens == -1 || cache_n_tokens == n_tokens);
3061+ n_tokens = cache_n_tokens;
3062+ }
3063+ return n_tokens;
3064+ }
3065+
3066+ int32_t llama_kv_cache_hybrid::get_used_cells () const {
3067+ // TODO: Is this correct?
3068+ // Return the largest number of used cells
3069+ int32_t used_cells = -1 ;
3070+ for (const auto & cache : m_children) {
3071+ used_cells = std::max (used_cells, cache->get_used_cells ());
3072+ }
3073+ return used_cells;
3074+ }
3075+
3076+ llama_pos llama_kv_cache_hybrid::get_pos_max () const {
3077+ llama_pos pos_max = -1 ;
3078+ for (const auto & cache : m_children) {
3079+ pos_max = std::max (pos_max, cache->get_pos_max ());
3080+ }
3081+ return pos_max;
3082+ }
3083+
3084+ bool llama_kv_cache_hybrid::get_can_shift () const {
3085+ // TODO: Is this correct?
3086+ // If any children can shift, return true
3087+ for (const auto & cache : m_children) {
3088+ if (cache->get_can_shift ()) {
3089+ return true ;
3090+ }
3091+ }
3092+ return false ;
3093+ }
3094+
void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
    // Write each child cache state in the iteration order of m_children.
    // NOTE(review): m_children is a std::set<std::unique_ptr<...>>, which
    // orders elements by pointer value — the constructor's pre-insertion sort
    // by lowest layer ID does not carry over, so this order is not guaranteed
    // to be stable across runs; confirm it matches what state_read expects
    for (const auto & cache : m_children) {
        cache->state_write(io, seq_id);
    }
}
3102+
void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
    // Read each child cache state in the iteration order of m_children.
    // NOTE(review): m_children is a std::set<std::unique_ptr<...>>, ordered by
    // pointer value rather than layer ID, so the read order is not guaranteed
    // to match the order used by state_write in a different process/run —
    // verify before relying on cross-run state restore
    for (const auto & cache : m_children) {
        cache->state_read(io, seq_id);
    }
}
0 commit comments