@@ -2354,6 +2354,231 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     return true;
 }
 
+//
+// llama_kv_cache_hybrid
+//
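+// A composite cache that maps each layer to one of several child caches
+// (e.g. full-attention layers to a unified cache and recurrent layers to
+// a recurrent cache) and fans every cache-wide operation out to all children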
+llama_kv_cache_hybrid::llama_kv_cache_hybrid(
+    const llama_hparams & hparams,
+    const std::vector<child_cache> & children) :
+    m_hparams(hparams),
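+    // NOTE: the members below are initialized with immediately-invoked
+    //  lambdas so that they can be computed from `children` directly in the
+    //  initializer list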
+    m_layer_cache_map(
+        [](const std::vector<child_cache> & caches) -> std::unordered_map<size_t, llama_kv_cache*> {
+            std::unordered_map<size_t, llama_kv_cache*> map;
+            for (const auto & cache : caches) {
+                for (size_t layer_id : cache.layer_ids) {
+                    map[layer_id] = cache.child;
+                }
+            }
+
+            return map;
+        }(children)
+    ),
+    m_children(
+        [](std::vector<child_cache> caches) -> std::set<llama_kv_cache*> {
+            // Sort the caches by the lowest layer ID so the order is repeatable
+            for (auto & cache : caches) {
+                GGML_ASSERT(cache.layer_ids.size() > 0);
+                std::sort(cache.layer_ids.begin(), cache.layer_ids.end());
+            }
+            std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) {
+                return a.layer_ids[0] < b.layer_ids[0];
+            });
+            std::set<llama_kv_cache*> unique_caches;
+            for (const auto & cache : caches) {
+                unique_caches.insert(cache.child);
+            }
+            return unique_caches;
+        }(children)
+    ),
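+    // remember whether any child is recurrent; this drives the choice of
+    //  batch-splitting strategy in sbatch_init / ubatch_next below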
+    m_has_recurrent(
+        [](const std::vector<child_cache> & caches) -> bool {
+            for (const auto & cache : caches) {
+                if (dynamic_cast<llama_kv_cache_recurrent *>(cache.child)) {
+                    return true;
+                }
+            }
+            return false;
+        }(children)
+    )
+{
+    // Ensure at least one child
+    GGML_ASSERT(m_children.size() > 0);
+
+    // Ensure layers are not overlapping and are contiguous
+    std::set<size_t> seen_layers;
+    size_t max_layer = 0;
+    for (const auto & cache : children) {
+        for (const auto & layer_id : cache.layer_ids) {
+            GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end());
+            seen_layers.insert(layer_id);
+            if (layer_id > max_layer) {
+                max_layer = layer_id;
+            }
+        }
+    }
+    // with 0-based layer IDs, contiguous coverage means max_layer + 1 == count
+    GGML_ASSERT(max_layer + 1 == seen_layers.size());
+}
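+
+// Example wiring (illustrative sketch only; the variable names and the
+//  aggregate field order of child_cache are assumptions, not part of this
+//  change):
+//
+//     std::vector<llama_kv_cache_hybrid::child_cache> children;
+//     children.push_back({ /* child */ attn_cache,      /* layer_ids */ {0, 2, 4} });
+//     children.push_back({ /* child */ recurrent_cache, /* layer_ids */ {1, 3, 5} });
+//     llama_kv_cache_hybrid kv(hparams, children);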
+
+void llama_kv_cache_hybrid::clear() {
+    for (const auto & cache : m_children) {
+        cache->clear();
+    }
+}
+
+bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // TODO: Will it cause problems if some caches are able to remove the seq
+    //  but others aren't?
+    bool removed = true;
+    for (const auto & cache : m_children) {
+        removed = cache->seq_rm(seq_id, p0, p1) && removed;
+    }
+    return removed;
+}
+
+void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    for (const auto & cache : m_children) {
+        cache->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) {
+    for (const auto & cache : m_children) {
+        cache->seq_keep(seq_id);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+    for (const auto & cache : m_children) {
+        cache->seq_add(seq_id, p0, p1, delta);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    for (const auto & cache : m_children) {
+        cache->seq_div(seq_id, p0, p1, d);
+    }
+}
+
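+// Report the max position for the sequence across all child caches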
+llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const {
+    llama_pos max_pos = 0;
+    for (const auto & cache : m_children) {
+        max_pos = std::max(max_pos, cache->seq_pos_max(seq_id));
+    }
+    return max_pos;
+}
+
+void llama_kv_cache_hybrid::restore() {
+    for (const auto & cache : m_children) {
+        cache->restore();
+    }
+}
+
+void llama_kv_cache_hybrid::commit() {
+    for (const auto & cache : m_children) {
+        cache->commit();
+    }
+}
+
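+// Report an update if any child performed one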
+bool llama_kv_cache_hybrid::update(llama_context & ctx) {
+    bool updated = false;
+    for (const auto & cache : m_children) {
+        updated = cache->update(ctx) || updated;
+    }
+    return updated;
+}
+
+void llama_kv_cache_hybrid::defrag_sched(float thold) {
+    for (const auto & cache : m_children) {
+        cache->defrag_sched(thold);
+    }
+}
+
+void llama_kv_cache_hybrid::set_full() {
+    for (const auto & cache : m_children) {
+        cache->set_full();
+    }
+}
+
+llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
+    // If any of the caches are recurrent, require simple split
+    return llama_sbatch(batch, m_hparams.n_embd, m_has_recurrent, logits_all);
+}
+
+llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
+    if (m_has_recurrent) {
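+        // a recurrent child requires the simple split selected in sbatch_init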
+        return sbatch.split_simple(n_ubatch);
+    }
+    if (embd_pooled) {
+        // Pooled embeddings cannot be split across ubatches (yet)
+        return sbatch.split_seq(n_ubatch);
+    }
+    return sbatch.split_equal(n_ubatch);
+}
+
+bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) {
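+    // The slot is only usable if every child cache can place the ubatch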
+    bool found = true;
+    for (const auto & cache : m_children) {
+        found = cache->find_slot(batch) && found;
+    }
+    return found;
+}
+
+int32_t llama_kv_cache_hybrid::get_n_tokens() const {
+    // The number of tokens should be the same across all child caches
+    int32_t n_tokens = -1;
+    for (const auto & cache : m_children) {
+        const auto cache_n_tokens = cache->get_n_tokens();
+        GGML_ASSERT(n_tokens == -1 || cache_n_tokens == n_tokens);
+        n_tokens = cache_n_tokens;
+    }
+    return n_tokens;
+}
+
+int32_t llama_kv_cache_hybrid::get_used_cells() const {
+    // TODO: Is this correct?
+    // Return the largest number of used cells
+    int32_t used_cells = -1;
+    for (const auto & cache : m_children) {
+        used_cells = std::max(used_cells, cache->get_used_cells());
+    }
+    return used_cells;
+}
+
+llama_pos llama_kv_cache_hybrid::get_pos_max() const {
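+    // Return the highest position seen by any child cache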
+    llama_pos pos_max = -1;
+    for (const auto & cache : m_children) {
+        pos_max = std::max(pos_max, cache->get_pos_max());
+    }
+    return pos_max;
+}
+
+bool llama_kv_cache_hybrid::get_can_shift() const {
+    // TODO: Is this correct?
+    // If any children can shift, return true
+    for (const auto & cache : m_children) {
+        if (cache->get_can_shift()) {
+            return true;
+        }
+    }
+    return false;
+}
+
+void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    // Write each cache state in order. Note that order is guaranteed at
+    //  initialization by using an ordered set sorted by lowest layer ID
+    for (const auto & cache : m_children) {
+        cache->state_write(io, seq_id);
+    }
+}
+
+void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    // Read each cache state in order. Note that order is guaranteed at
+    //  initialization by using an ordered set sorted by lowest layer ID
+    for (const auto & cache : m_children) {
+        cache->state_read(io, seq_id);
+    }
+}
+
 //
 // kv cache view
 //