Skip to content

Commit d1aec07

Browse files
committed
feat: Initial implementation of llama_kv_cache_hybrid
Condensed from initial version https://github.com/gabe-l-hart/llama.cpp/tree/ec08571 The only difference is the removal of m_layer_cache_map which was unused and unnecessary now that child caches are instantiated with their own filters. Branch: HybridCache Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 52d7627 commit d1aec07

File tree

3 files changed

+403
-0
lines changed

3 files changed

+403
-0
lines changed

src/llama-kv-cache.cpp

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2782,3 +2782,240 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
27822782

27832783
return true;
27842784
}
2785+
2786+
//
// llama_kv_cache_hybrid
//

// Constructs a hybrid cache from a set of child caches, each owning a
// disjoint subset of the model's layers. Takes ownership of the children.
llama_kv_cache_hybrid::llama_kv_cache_hybrid(
    const llama_hparams & hparams,
    std::vector<child_cache> children) :
    m_hparams(hparams),
    m_children(
        // Move ownership of each child cache out of `children` into a set.
        [](std::vector<child_cache>& caches) -> std::set<std::unique_ptr<llama_kv_cache>> {
            // Sort the caches by the lowest layer ID so the order is repeatable
            // NOTE(review): std::set<std::unique_ptr<...>> with the default
            // comparator orders by pointer value, not insertion order, so the
            // sort below does NOT control the iteration order of m_children.
            // Verify whether a vector is needed if state IO must follow
            // layer order.
            for (auto & cache : caches) {
                GGML_ASSERT(cache.layer_ids.size() > 0);
                std::sort(cache.layer_ids.begin(), cache.layer_ids.end());
            }
            std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) {
                return a.layer_ids[0] < b.layer_ids[0];
            });
            std::set<std::unique_ptr<llama_kv_cache>> unique_caches;
            for (auto & cache : caches) {
                unique_caches.emplace(cache.child.release());
            }
            return unique_caches;
        }(children)
    ),
    m_has_recurrent(
        // True if any child is a recurrent cache; used to pick the batch
        // splitting strategy in sbatch_init/ubatch_next.
        [](const std::set<std::unique_ptr<llama_kv_cache>> & caches) -> bool {
            for (const auto & cache : caches) {
                if (dynamic_cast<llama_kv_cache_recurrent *>(cache.get())) {
                    return true;
                }
            }
            return false;
        }(m_children)
    )
{
    // Ensure at least one child
    GGML_ASSERT(m_children.size() > 0);

    // Ensure layers are not overlapping and are concurrent
    // (reading `children` here is safe: only the cache pointers were
    // released above, the layer_ids vectors are still intact)
    std::set<size_t> seen_layers;
    size_t max_layer = 0;
    for (const auto & cache : children) {
        for (const auto & layer_id : cache.layer_ids) {
            GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end());
            seen_layers.insert(layer_id);
            if (layer_id > max_layer) {
                max_layer = layer_id;
            }
        }
    }
    LLAMA_LOG_DEBUG("max_layer=%zu, seen_layers.size()=%zu\n", max_layer, seen_layers.size());
    // Layer IDs must form the contiguous range [0, max_layer]
    GGML_ASSERT(max_layer + 1 == seen_layers.size());
}
2839+
2840+
void llama_kv_cache_hybrid::clear() {
2841+
for (const auto & cache : m_children) {
2842+
cache->clear();
2843+
}
2844+
}
2845+
2846+
// Remove positions [p0, p1) of seq_id from all children, all-or-nothing:
// either every child can perform the removal, or nothing is mutated.
bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    // First check if we can do this removal. This checks all children so that
    // no mutation happens before we know if it's possible
    if (!can_seq_rm(seq_id, p0, p1)) {
        return false;
    }

    // Do the removal from each child which should never fail
    for (const auto & cache : m_children) {
        // seq_rm returns true on success. The original stored this in a
        // variable named `failed` and asserted !failed, which inverted the
        // check and fired exactly when the removal succeeded.
        const bool ok = cache->seq_rm(seq_id, p0, p1);
        GGML_ASSERT(ok);
    }
    return true;
}
2860+
2861+
// Copy positions [p0, p1) of seq_id_src into seq_id_dst in every child.
void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    for (auto & child : m_children) {
        child->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    }
}
2866+
2867+
// Drop all sequences except seq_id from every child.
void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) {
    for (auto & child : m_children) {
        child->seq_keep(seq_id);
    }
}
2872+
2873+
// Shift positions [p0, p1) of seq_id by delta in every child.
void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
    for (auto & child : m_children) {
        child->seq_add(seq_id, p0, p1, delta);
    }
}
2878+
2879+
// Divide positions [p0, p1) of seq_id by d in every child.
void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    for (auto & child : m_children) {
        child->seq_div(seq_id, p0, p1, d);
    }
}
2884+
2885+
// Smallest position of seq_id across all children, or -1 if no child
// holds the sequence.
llama_pos llama_kv_cache_hybrid::seq_pos_min(llama_seq_id seq_id) const {
    llama_pos min_pos = -1;
    for (const auto & cache : m_children) {
        const auto child_min_pos = cache->seq_pos_min(seq_id);
        // Skip children that don't hold this sequence (-1). The original
        // folded -1 through std::min, so the result depended on child
        // iteration order: a -1 seen after a valid minimum dragged the
        // result back to -1, while a -1 seen first was ignored.
        if (child_min_pos == -1) {
            continue;
        }
        min_pos = min_pos == -1 ? child_min_pos : std::min(min_pos, child_min_pos);
    }
    return min_pos;
}
2893+
2894+
// Largest position of seq_id across all children (0 if none report one).
llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const {
    llama_pos result = 0;
    for (auto & child : m_children) {
        const llama_pos child_max = child->seq_pos_max(seq_id);
        if (child_max > result) {
            result = child_max;
        }
    }
    return result;
}
2901+
2902+
void llama_kv_cache_hybrid::restore() {
2903+
for (const auto & cache : m_children) {
2904+
cache->restore();
2905+
}
2906+
}
2907+
2908+
void llama_kv_cache_hybrid::commit() {
2909+
for (const auto & cache : m_children) {
2910+
cache->commit();
2911+
}
2912+
}
2913+
2914+
// Run update on every child unconditionally (no short-circuit);
// returns true if any child reported an update.
bool llama_kv_cache_hybrid::update(llama_context & ctx) {
    bool any_updated = false;
    for (auto & child : m_children) {
        if (child->update(ctx)) {
            any_updated = true;
        }
    }
    return any_updated;
}
2921+
2922+
void llama_kv_cache_hybrid::defrag_sched(float thold) {
2923+
for (const auto & cache : m_children) {
2924+
cache->defrag_sched(thold);
2925+
}
2926+
}
2927+
2928+
void llama_kv_cache_hybrid::set_full() {
2929+
for (const auto & cache : m_children) {
2930+
cache->set_full();
2931+
}
2932+
}
2933+
2934+
// True only if every child can remove positions [p0, p1) of seq_id.
bool llama_kv_cache_hybrid::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
    return std::all_of(m_children.begin(), m_children.end(),
        [&](const std::unique_ptr<llama_kv_cache> & child) {
            return child->can_seq_rm(seq_id, p0, p1);
        });
}
2942+
2943+
// Build the sequence batch; recurrent children require equal split,
// otherwise a simple split is used.
llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
    const bool simple_split = !m_has_recurrent;
    return llama_sbatch(batch, m_hparams.n_embd, simple_split, logits_all);
}
2947+
2948+
// Pick the micro-batch splitting strategy: by sequence for pooled
// embeddings, equal-sized when any child is recurrent, simple otherwise.
llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
    if (embd_pooled) {
        // Pooled embeddings cannot be split across ubatches (yet)
        return sbatch.split_seq(n_ubatch);
    }
    return m_has_recurrent ? sbatch.split_equal(n_ubatch) : sbatch.split_simple(n_ubatch);
}
2958+
2959+
bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) {
2960+
bool found = true;
2961+
for (const auto & cache : m_children) {
2962+
found = cache->find_slot(batch) && found;
2963+
}
2964+
return found;
2965+
}
2966+
2967+
int32_t llama_kv_cache_hybrid::get_n_tokens() const {
2968+
// The number of tokens should be the same across all child caches
2969+
int32_t n_tokens = -1;
2970+
for (const auto & cache : m_children) {
2971+
const auto cache_n_tokens = cache->get_n_tokens();
2972+
GGML_ASSERT(n_tokens == -1 || cache_n_tokens == n_tokens);
2973+
n_tokens = cache_n_tokens;
2974+
}
2975+
return n_tokens;
2976+
}
2977+
2978+
int32_t llama_kv_cache_hybrid::get_used_cells() const {
2979+
// TODO: Is this correct?
2980+
// Return the largest number of used cells
2981+
int32_t used_cells = -1;
2982+
for (const auto & cache : m_children) {
2983+
used_cells = std::max(used_cells, cache->get_used_cells());
2984+
}
2985+
return used_cells;
2986+
}
2987+
2988+
// Largest position held by any child (-1 if none).
llama_pos llama_kv_cache_hybrid::get_pos_max() const {
    llama_pos result = -1;
    for (auto & child : m_children) {
        const llama_pos child_pos_max = child->get_pos_max();
        if (child_pos_max > result) {
            result = child_pos_max;
        }
    }
    return result;
}
2995+
2996+
bool llama_kv_cache_hybrid::get_can_shift() const {
2997+
// TODO: Is this correct?
2998+
// If any children can shift, return true
2999+
for (const auto & cache : m_children) {
3000+
if (cache->get_can_shift()) {
3001+
return true;
3002+
}
3003+
}
3004+
return false;
3005+
}
3006+
3007+
void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
    // Write each child cache's state in m_children iteration order.
    // NOTE(review): m_children is a std::set<std::unique_ptr<...>> whose
    // default comparator orders by pointer value, so this order is NOT
    // guaranteed to follow the lowest-layer-ID sort done in the
    // constructor — verify before relying on cross-run state layout.
    for (const auto & cache : m_children) {
        cache->state_write(io, seq_id);
    }
}
3014+
3015+
void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
    // Read each child cache's state in m_children iteration order; must
    // match the order used by state_write.
    // NOTE(review): iteration order of a std::set of unique_ptr is pointer
    // (address) order, not the layer-ID order the constructor sorts by —
    // verify that reads and writes see the same order across runs.
    for (const auto & cache : m_children) {
        cache->state_read(io, seq_id);
    }
}

src/llama-kv-cache.h

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,3 +515,101 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
515515
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
516516
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
517517
};
518+
519+
//
// llama_kv_cache_hybrid
//

// utilizes multiple different cache types with each layer assigned to exactly
// one cache. This is typically used for hybrid attention / recurrent caching

class llama_kv_cache_hybrid : public llama_kv_cache {
public:

    // a child cache paired with the layer IDs it is responsible for
    struct child_cache {
        std::unique_ptr<llama_kv_cache> child;
        std::vector<size_t> layer_ids;

        child_cache(std::unique_ptr<llama_kv_cache> child_, std::vector<size_t> layer_ids_)
            : child(std::move(child_)), layer_ids(std::move(layer_ids_)) {}
    };

    // takes ownership of the children; their layer IDs must be disjoint and
    // together cover a contiguous range starting at 0 (asserted)
    llama_kv_cache_hybrid(
        const llama_hparams & hparams,
        std::vector<child_cache> children);

    virtual ~llama_kv_cache_hybrid() = default;

    // getters for specific child cache type
    // NOTE: This will fail if there are multiple of the given type
    template<typename child_t>
    const child_t * get_child_cache() const {
        const child_t * child = nullptr;
        for (const auto & child_cache : m_children) {
            const child_t * child_cast = dynamic_cast<const child_t *>(child_cache.get());
            if (child_cast) {
                // at most one child of each concrete type is allowed
                GGML_ASSERT(!child);
                child = child_cast;
            }
        }
        return child;
    }

    //
    // llama_memory_i
    //

    void clear() override;

    bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    //
    // llama_kv_cache
    //

    void restore() override;
    void commit() override;

    bool update(llama_context & ctx) override;

    void defrag_sched(float thold) override;

    void set_full() override;

    bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;

    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;

    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;

    // updates the cache head
    // Note: On success, it's important that cache.head points
    // to the first cell of the slot.
    bool find_slot(const llama_ubatch & batch) override;

    int32_t get_n_tokens() const override;
    int32_t get_used_cells() const override;

    // TODO: better data structures to reduce the cost of this operation
    llama_pos get_pos_max() const override;

    bool get_can_shift() const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;

private:

    const llama_hparams & m_hparams;
    // Ordered for state IO
    // NOTE(review): a std::set of unique_ptr orders by pointer value, so the
    // iteration order is address order — it does not follow the layer-ID
    // sort performed in the constructor; verify before depending on it
    const std::set<std::unique_ptr<llama_kv_cache>> m_children;
    const bool m_has_recurrent;
};

0 commit comments

Comments
 (0)