@@ -2782,3 +2782,240 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
27822782
27832783 return true ;
27842784}
2785+
//
// llama_kv_cache_hybrid
//

// Construct a hybrid KV cache that fans out to a set of child caches, each
// child owning a disjoint subset of the model's layers.
llama_kv_cache_hybrid::llama_kv_cache_hybrid(
    const llama_hparams & hparams,
    std::vector<child_cache> children) :
    m_hparams(hparams),
    m_children(
        // Take ownership of the children in a deterministic order (sorted by
        // the lowest layer ID each child covers) so that iteration over the
        // set is repeatable — state_write/state_read rely on this ordering.
        [](std::vector<child_cache>& caches) -> std::set<std::unique_ptr<llama_kv_cache>> {
            // Sort the caches by the lowest layer ID so the order is repeatable
            for (auto & cache : caches) {
                GGML_ASSERT(cache.layer_ids.size() > 0);
                std::sort(cache.layer_ids.begin(), cache.layer_ids.end());
            }
            std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) {
                return a.layer_ids[0] < b.layer_ids[0];
            });
            std::set<std::unique_ptr<llama_kv_cache>> unique_caches;
            for (auto & cache : caches) {
                // Transfer ownership out of the child_cache wrapper; the
                // wrapper's layer_ids vector is left intact for the
                // validation pass in the constructor body below.
                unique_caches.emplace(cache.child.release());
            }
            return unique_caches;
        }(children)
    ),
    m_has_recurrent(
        // Cache whether any child is recurrent; this selects the batch
        // splitting strategy in sbatch_init/ubatch_next.
        // NOTE(review): reads m_children, so this relies on m_children being
        // declared before m_has_recurrent in the class — members initialize
        // in declaration order, not initializer-list order. Confirm in header.
        [](const std::set<std::unique_ptr<llama_kv_cache>> & caches) -> bool {
            for (const auto & cache : caches) {
                if (dynamic_cast<llama_kv_cache_recurrent *>(cache.get())) {
                    return true;
                }
            }
            return false;
        }(m_children)
    )
{
    // Ensure at least one child
    GGML_ASSERT(m_children.size() > 0);

    // Ensure layers are not overlapping and are contiguous: every layer ID
    // appears in exactly one child and the IDs cover the range
    // [0, max_layer] with no gaps (checked by the final assert).
    std::set<size_t> seen_layers;
    size_t max_layer = 0;
    for (const auto & cache : children) {
        for (const auto & layer_id : cache.layer_ids) {
            GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end());
            seen_layers.insert(layer_id);
            if (layer_id > max_layer) {
                max_layer = layer_id;
            }
        }
    }
    LLAMA_LOG_DEBUG("max_layer=%zu, seen_layers.size()=%zu\n", max_layer, seen_layers.size());
    GGML_ASSERT(max_layer + 1 == seen_layers.size());
}
2839+
2840+ void llama_kv_cache_hybrid::clear () {
2841+ for (const auto & cache : m_children) {
2842+ cache->clear ();
2843+ }
2844+ }
2845+
2846+ bool llama_kv_cache_hybrid::seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
2847+ // First check if we can do this removal. This checks all children so that
2848+ // no mutation happens before we know if it's possible
2849+ if (!can_seq_rm (seq_id, p0, p1)) {
2850+ return false ;
2851+ }
2852+
2853+ // Do the removal from each child which should never fail
2854+ for (const auto & cache : m_children) {
2855+ const bool failed = cache->seq_rm (seq_id, p0, p1);
2856+ GGML_ASSERT (!failed);
2857+ }
2858+ return true ;
2859+ }
2860+
2861+ void llama_kv_cache_hybrid::seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
2862+ for (const auto & cache : m_children) {
2863+ cache->seq_cp (seq_id_src, seq_id_dst, p0, p1);
2864+ }
2865+ }
2866+
2867+ void llama_kv_cache_hybrid::seq_keep (llama_seq_id seq_id) {
2868+ for (const auto & cache : m_children) {
2869+ cache->seq_keep (seq_id);
2870+ }
2871+ }
2872+
2873+ void llama_kv_cache_hybrid::seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
2874+ for (const auto & cache : m_children) {
2875+ cache->seq_add (seq_id, p0, p1, delta);
2876+ }
2877+ }
2878+
2879+ void llama_kv_cache_hybrid::seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
2880+ for (const auto & cache : m_children) {
2881+ cache->seq_div (seq_id, p0, p1, d);
2882+ }
2883+ }
2884+
2885+ llama_pos llama_kv_cache_hybrid::seq_pos_min (llama_seq_id seq_id) const {
2886+ llama_pos min_pos = -1 ;
2887+ for (const auto & cache : m_children) {
2888+ const auto child_min_pos = cache->seq_pos_min (seq_id);
2889+ min_pos = min_pos == -1 ? child_min_pos : std::min (min_pos, child_min_pos);
2890+ }
2891+ return min_pos;
2892+ }
2893+
2894+ llama_pos llama_kv_cache_hybrid::seq_pos_max (llama_seq_id seq_id) const {
2895+ llama_pos max_pos = 0 ;
2896+ for (const auto & cache : m_children) {
2897+ max_pos = std::max (max_pos, cache->seq_pos_max (seq_id));
2898+ }
2899+ return max_pos;
2900+ }
2901+
2902+ void llama_kv_cache_hybrid::restore () {
2903+ for (const auto & cache : m_children) {
2904+ cache->restore ();
2905+ }
2906+ }
2907+
2908+ void llama_kv_cache_hybrid::commit () {
2909+ for (const auto & cache : m_children) {
2910+ cache->commit ();
2911+ }
2912+ }
2913+
2914+ bool llama_kv_cache_hybrid::update (llama_context & ctx) {
2915+ bool updated = false ;
2916+ for (const auto & cache : m_children) {
2917+ updated = cache->update (ctx) || updated;
2918+ }
2919+ return updated;
2920+ }
2921+
2922+ void llama_kv_cache_hybrid::defrag_sched (float thold) {
2923+ for (const auto & cache : m_children) {
2924+ cache->defrag_sched (thold);
2925+ }
2926+ }
2927+
2928+ void llama_kv_cache_hybrid::set_full () {
2929+ for (const auto & cache : m_children) {
2930+ cache->set_full ();
2931+ }
2932+ }
2933+
2934+ bool llama_kv_cache_hybrid::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
2935+ for (const auto & cache : m_children) {
2936+ if (!cache->can_seq_rm (seq_id, p0, p1)) {
2937+ return false ;
2938+ }
2939+ }
2940+ return true ;
2941+ }
2942+
2943+ llama_sbatch llama_kv_cache_hybrid::sbatch_init (const llama_batch & batch, bool logits_all) {
2944+ // If any of the caches are recurrent, require equal split
2945+ return llama_sbatch (batch, m_hparams.n_embd , !m_has_recurrent, logits_all);
2946+ }
2947+
2948+ llama_ubatch llama_kv_cache_hybrid::ubatch_next (llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
2949+ if (embd_pooled) {
2950+ // Pooled embeddings cannot be split across ubatches (yet)
2951+ return sbatch.split_seq (n_ubatch);
2952+ }
2953+ if (m_has_recurrent) {
2954+ return sbatch.split_equal (n_ubatch);
2955+ }
2956+ return sbatch.split_simple (n_ubatch);
2957+ }
2958+
2959+ bool llama_kv_cache_hybrid::find_slot (const llama_ubatch & batch) {
2960+ bool found = true ;
2961+ for (const auto & cache : m_children) {
2962+ found = cache->find_slot (batch) && found;
2963+ }
2964+ return found;
2965+ }
2966+
2967+ int32_t llama_kv_cache_hybrid::get_n_tokens () const {
2968+ // The number of tokens should be the same across all child caches
2969+ int32_t n_tokens = -1 ;
2970+ for (const auto & cache : m_children) {
2971+ const auto cache_n_tokens = cache->get_n_tokens ();
2972+ GGML_ASSERT (n_tokens == -1 || cache_n_tokens == n_tokens);
2973+ n_tokens = cache_n_tokens;
2974+ }
2975+ return n_tokens;
2976+ }
2977+
2978+ int32_t llama_kv_cache_hybrid::get_used_cells () const {
2979+ // TODO: Is this correct?
2980+ // Return the largest number of used cells
2981+ int32_t used_cells = -1 ;
2982+ for (const auto & cache : m_children) {
2983+ used_cells = std::max (used_cells, cache->get_used_cells ());
2984+ }
2985+ return used_cells;
2986+ }
2987+
2988+ llama_pos llama_kv_cache_hybrid::get_pos_max () const {
2989+ llama_pos pos_max = -1 ;
2990+ for (const auto & cache : m_children) {
2991+ pos_max = std::max (pos_max, cache->get_pos_max ());
2992+ }
2993+ return pos_max;
2994+ }
2995+
2996+ bool llama_kv_cache_hybrid::get_can_shift () const {
2997+ // TODO: Is this correct?
2998+ // If any children can shift, return true
2999+ for (const auto & cache : m_children) {
3000+ if (cache->get_can_shift ()) {
3001+ return true ;
3002+ }
3003+ }
3004+ return false ;
3005+ }
3006+
3007+ void llama_kv_cache_hybrid::state_write (llama_io_write_i & io, llama_seq_id seq_id) const {
3008+ // Write each cache state in order. Note that order is guaranteed at
3009+ // initialization by using an ordered set sorted by lowest layer ID
3010+ for (const auto & cache : m_children) {
3011+ cache->state_write (io, seq_id);
3012+ }
3013+ }
3014+
3015+ void llama_kv_cache_hybrid::state_read (llama_io_read_i & io, llama_seq_id seq_id) {
3016+ // Read each cache state in order. Note that order is guaranteed at
3017+ // initialization by using an ordered set sorted by lowest layer ID
3018+ for (const auto & cache : m_children) {
3019+ cache->state_read (io, seq_id);
3020+ }
3021+ }
0 commit comments