@@ -2384,6 +2384,231 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     return true;
 }
 
+//
+// llama_kv_cache_hybrid
+//
+llama_kv_cache_hybrid::llama_kv_cache_hybrid(
+    const llama_hparams & hparams,
+    const std::vector<child_cache> & children) :
+    m_hparams(hparams),
+    m_layer_cache_map(
+        [](const std::vector<child_cache> & caches) -> std::unordered_map<size_t, llama_kv_cache *> {
+            std::unordered_map<size_t, llama_kv_cache *> map;
+            for (const auto & cache : caches) {
+                for (size_t layer_id : cache.layer_ids) {
+                    map[layer_id] = cache.child;
+                }
+            }
+
+            return map;
+        }(children)
+    ),
+    m_children(
+        [](std::vector<child_cache> caches) -> std::vector<llama_kv_cache *> {
+            // Sort the caches by the lowest layer ID so the order is repeatable
+            for (auto & cache : caches) {
+                GGML_ASSERT(cache.layer_ids.size() > 0);
+                std::sort(cache.layer_ids.begin(), cache.layer_ids.end());
+            }
+            std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) {
+                return a.layer_ids[0] < b.layer_ids[0];
+            });
+            // Collect the unique children while preserving the sorted order
+            // (a std::set of pointers would order by address, which is not
+            // repeatable across runs)
+            std::vector<llama_kv_cache *> unique_caches;
+            for (const auto & cache : caches) {
+                if (std::find(unique_caches.begin(), unique_caches.end(), cache.child) == unique_caches.end()) {
+                    unique_caches.push_back(cache.child);
+                }
+            }
+            return unique_caches;
+        }(children)
+    ),
+    m_has_recurrent(
+        [](const std::vector<child_cache> & caches) -> bool {
+            for (const auto & cache : caches) {
+                if (dynamic_cast<llama_kv_cache_recurrent *>(cache.child)) {
+                    return true;
+                }
+            }
+            return false;
+        }(children)
+    )
+{
+    // Ensure at least one child
+    GGML_ASSERT(m_children.size() > 0);
+
+    // Ensure layers are not overlapping and are contiguous
+    std::set<size_t> seen_layers;
+    size_t max_layer = 0;
+    for (const auto & cache : children) {
+        for (const auto & layer_id : cache.layer_ids) {
+            GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end());
+            seen_layers.insert(layer_id);
+            if (layer_id > max_layer) {
+                max_layer = layer_id;
+            }
+        }
+    }
+    // Layer IDs are 0-based, so a contiguous range 0..max_layer has
+    // max_layer + 1 distinct entries
+    GGML_ASSERT(max_layer + 1 == seen_layers.size());
+}
+
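+// Illustrative wiring (a hypothetical sketch, assuming an aggregate
+// child_cache{child, layer_ids}; neither the cache objects nor the
+// layer split below come from a real model): a hybrid model with
+// recurrent layers 0-3 and attention layers 4-7 could be set up as
+//
+//     std::vector<child_cache> children = {
+//         { &cache_recurrent, {0, 1, 2, 3} },
+//         { &cache_attn,      {4, 5, 6, 7} },
+//     };
+//     llama_kv_cache_hybrid kv(hparams, children);
+//
+// after which m_layer_cache_map resolves layer 2 to &cache_recurrent
+// and layer 5 to &cache_attn.
+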
+void llama_kv_cache_hybrid::clear() {
+    for (const auto & cache : m_children) {
+        cache->clear();
+    }
+}
+
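+// The sequence operations below fan out to every child: each child only
+// holds the state for its own layers, so applying a mutation to only
+// some children would leave the per-layer views of a sequence out of
+// sync with each other.
+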
+bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // TODO: Will it cause problems if some caches are able to remove the seq
+    // but others aren't?
+    bool removed = true;
+    for (const auto & cache : m_children) {
+        removed = cache->seq_rm(seq_id, p0, p1) && removed;
+    }
+    return removed;
+}
+
+void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    for (const auto & cache : m_children) {
+        cache->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) {
+    for (const auto & cache : m_children) {
+        cache->seq_keep(seq_id);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+    for (const auto & cache : m_children) {
+        cache->seq_add(seq_id, p0, p1, delta);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    for (const auto & cache : m_children) {
+        cache->seq_div(seq_id, p0, p1, d);
+    }
+}
+
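+// Report the largest position any child has seen for this sequence. The
+// children are expected to stay in sync, since every ubatch is applied
+// to all of them (see find_slot below).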
+llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const {
+    llama_pos max_pos = 0;
+    for (const auto & cache : m_children) {
+        max_pos = std::max(max_pos, cache->seq_pos_max(seq_id));
+    }
+    return max_pos;
+}
+
+void llama_kv_cache_hybrid::restore() {
+    for (const auto & cache : m_children) {
+        cache->restore();
+    }
+}
+
+void llama_kv_cache_hybrid::commit() {
+    for (const auto & cache : m_children) {
+        cache->commit();
+    }
+}
+
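+// Note that cache->update(ctx) is the left operand of || so that
+// short-circuit evaluation never skips a child once an earlier one has
+// already reported an update.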
+bool llama_kv_cache_hybrid::update(llama_context & ctx) {
+    bool updated = false;
+    for (const auto & cache : m_children) {
+        updated = cache->update(ctx) || updated;
+    }
+    return updated;
+}
+
+void llama_kv_cache_hybrid::defrag_sched(float thold) {
+    for (const auto & cache : m_children) {
+        cache->defrag_sched(thold);
+    }
+}
+
+void llama_kv_cache_hybrid::set_full() {
+    for (const auto & cache : m_children) {
+        cache->set_full();
+    }
+}
+
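+// Batch splitting: the hybrid cache has to pick the most restrictive
+// strategy any child needs. With a recurrent child, every ubatch must
+// be a simple contiguous slice of the batch; otherwise the usual
+// equal-length/per-sequence splits apply.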
+llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
+    // If any of the caches are recurrent, require a simple split
+    return llama_sbatch(batch, m_hparams.n_embd, m_has_recurrent, logits_all);
+}
+
+llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
+    if (m_has_recurrent) {
+        return sbatch.split_simple(n_ubatch);
+    }
+    if (embd_pooled) {
+        // Pooled embeddings cannot be split across ubatches (yet)
+        return sbatch.split_seq(n_ubatch);
+    }
+    return sbatch.split_equal(n_ubatch);
+}
+
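+// A ubatch spans all layers, so it must be placed in every child; if
+// any child fails to find a slot, the batch cannot be processed as a
+// whole (presumably leaving it to the caller to roll back via restore()
+// above).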
+bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) {
+    bool found = true;
+    for (const auto & cache : m_children) {
+        found = cache->find_slot(batch) && found;
+    }
+    return found;
+}
+
+int32_t llama_kv_cache_hybrid::get_n_tokens() const {
+    // The number of tokens should be the same across all child caches
+    int32_t n_tokens = -1;
+    for (const auto & cache : m_children) {
+        const auto cache_n_tokens = cache->get_n_tokens();
+        GGML_ASSERT(n_tokens == -1 || cache_n_tokens == n_tokens);
+        n_tokens = cache_n_tokens;
+    }
+    return n_tokens;
+}
+
+int32_t llama_kv_cache_hybrid::get_used_cells() const {
+    // TODO: Is this correct?
+    // Return the largest number of used cells
+    int32_t used_cells = -1;
+    for (const auto & cache : m_children) {
+        used_cells = std::max(used_cells, cache->get_used_cells());
+    }
+    return used_cells;
+}
+
+llama_pos llama_kv_cache_hybrid::get_pos_max() const {
+    llama_pos pos_max = -1;
+    for (const auto & cache : m_children) {
+        pos_max = std::max(pos_max, cache->get_pos_max());
+    }
+    return pos_max;
+}
+
+bool llama_kv_cache_hybrid::get_can_shift() const {
+    // TODO: Is this correct?
+    // If any children can shift, return true
+    for (const auto & cache : m_children) {
+        if (cache->get_can_shift()) {
+            return true;
+        }
+    }
+    return false;
+}
+
+void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    // Write each cache state in order. Note that the order is made
+    // deterministic at initialization by sorting the children by their
+    // lowest layer ID
+    for (const auto & cache : m_children) {
+        cache->state_write(io, seq_id);
+    }
+}
+
+void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    // Read each cache state in order. Note that the order is made
+    // deterministic at initialization by sorting the children by their
+    // lowest layer ID
+    for (const auto & cache : m_children) {
+        cache->state_read(io, seq_id);
+    }
+}
+
 //
 // kv cache view
 //