
Commit f62858b

feat: Initial implementation of llama_kv_cache_hybrid
Condensed from the initial version at https://github.com/gabe-l-hart/llama.cpp/tree/ec08571. The only difference is the removal of m_layer_cache_map, which was unused and unnecessary now that child caches are instantiated with their own filters.

Branch: HybridCache

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 7a0fe25 commit f62858b

File tree

3 files changed: +403 -0 lines changed

src/llama-kv-cache.cpp

Lines changed: 237 additions & 0 deletions
@@ -2870,3 +2870,240 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count)

    return true;
}

//
// llama_kv_cache_hybrid
//

llama_kv_cache_hybrid::llama_kv_cache_hybrid(
    const llama_hparams & hparams,
    std::vector<child_cache> children) :
    m_hparams(hparams),
    m_children(
        [](std::vector<child_cache> & caches) -> std::vector<std::unique_ptr<llama_kv_cache>> {
            // Sort the caches by the lowest layer ID so the order is repeatable
            for (auto & cache : caches) {
                GGML_ASSERT(cache.layer_ids.size() > 0);
                std::sort(cache.layer_ids.begin(), cache.layer_ids.end());
            }
            std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) {
                return a.layer_ids[0] < b.layer_ids[0];
            });
            // Keep the sorted order in a vector: a std::set of unique_ptr
            // would order by pointer value, not by layer ID
            std::vector<std::unique_ptr<llama_kv_cache>> ordered_caches;
            ordered_caches.reserve(caches.size());
            for (auto & cache : caches) {
                ordered_caches.emplace_back(std::move(cache.child));
            }
            return ordered_caches;
        }(children)
    ),
    m_has_recurrent(
        [](const std::vector<std::unique_ptr<llama_kv_cache>> & caches) -> bool {
            for (const auto & cache : caches) {
                if (dynamic_cast<llama_kv_cache_recurrent *>(cache.get())) {
                    return true;
                }
            }
            return false;
        }(m_children)
    )
{
    // Ensure at least one child
    GGML_ASSERT(m_children.size() > 0);

    // Ensure layers are not overlapping and are contiguous
    std::set<size_t> seen_layers;
    size_t max_layer = 0;
    for (const auto & cache : children) {
        for (const auto & layer_id : cache.layer_ids) {
            GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end());
            seen_layers.insert(layer_id);
            if (layer_id > max_layer) {
                max_layer = layer_id;
            }
        }
    }
    LLAMA_LOG_DEBUG("max_layer=%zu, seen_layers.size()=%zu\n", max_layer, seen_layers.size());
    GGML_ASSERT(max_layer + 1 == seen_layers.size());
}
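Aside: the const members above are initialized with immediately-invoked lambdas, which lets a const member be computed from constructor arguments. A minimal self-contained sketch of that idiom (the names here are illustrative, not from this commit):

#include <algorithm>
#include <cstdio>
#include <vector>

// Illustrative only: compute a const member from a constructor argument
// via an immediately-invoked lambda, as m_children and m_has_recurrent do
struct sorted_ids {
    const std::vector<int> ids;

    explicit sorted_ids(std::vector<int> in) :
        ids([](std::vector<int> & v) -> std::vector<int> {
            std::sort(v.begin(), v.end());
            return v; // runs exactly once, at member-init time
        }(in)) {}
};

int main() {
    const sorted_ids s({3, 1, 2});
    for (const int id : s.ids) {
        printf("%d ", id); // prints: 1 2 3
    }
    printf("\n");
    return 0;
}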
void llama_kv_cache_hybrid::clear() {
    for (const auto & cache : m_children) {
        cache->clear();
    }
}

bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    // First check if we can do this removal. This checks all children so that
    // no mutation happens before we know it's possible
    if (!can_seq_rm(seq_id, p0, p1)) {
        return false;
    }

    // Do the removal from each child; this should never fail after the check above
    for (const auto & cache : m_children) {
        const bool removed = cache->seq_rm(seq_id, p0, p1);
        GGML_ASSERT(removed);
    }
    return true;
}
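The check-before-mutate pattern above is what keeps the children consistent: if any child would reject the removal, none of them is touched. A standalone sketch of the same all-or-nothing idea (the counter type is illustrative, not from this commit):

#include <cassert>
#include <vector>

// Illustrative only: validate an operation against every element before
// applying it to any of them, mirroring can_seq_rm() / seq_rm() above
struct counter {
    int value = 0;
    bool can_sub(int n) const { return value >= n; }
    void sub(int n) { value -= n; }
};

bool sub_all(std::vector<counter> & cs, int n) {
    for (const auto & c : cs) {
        if (!c.can_sub(n)) {
            return false; // reject before any mutation happens
        }
    }
    for (auto & c : cs) {
        c.sub(n); // cannot fail after the check above
    }
    return true;
}

int main() {
    std::vector<counter> cs = {{5}, {2}};

    const bool rejected = !sub_all(cs, 3); // second counter would underflow
    assert(rejected && cs[0].value == 5);  // rejected: nothing changed

    const bool applied = sub_all(cs, 2);
    assert(applied && cs[1].value == 0);   // accepted: all changed
    return 0;
}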
void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    for (const auto & cache : m_children) {
        cache->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    }
}

void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) {
    for (const auto & cache : m_children) {
        cache->seq_keep(seq_id);
    }
}

void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
    for (const auto & cache : m_children) {
        cache->seq_add(seq_id, p0, p1, delta);
    }
}

void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    for (const auto & cache : m_children) {
        cache->seq_div(seq_id, p0, p1, d);
    }
}

llama_pos llama_kv_cache_hybrid::seq_pos_min(llama_seq_id seq_id) const {
    llama_pos min_pos = -1;
    for (const auto & cache : m_children) {
        const auto child_min_pos = cache->seq_pos_min(seq_id);
        min_pos = min_pos == -1 ? child_min_pos : std::min(min_pos, child_min_pos);
    }
    return min_pos;
}

llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const {
    llama_pos max_pos = 0;
    for (const auto & cache : m_children) {
        max_pos = std::max(max_pos, cache->seq_pos_max(seq_id));
    }
    return max_pos;
}

void llama_kv_cache_hybrid::restore() {
    for (const auto & cache : m_children) {
        cache->restore();
    }
}

void llama_kv_cache_hybrid::commit() {
    for (const auto & cache : m_children) {
        cache->commit();
    }
}

bool llama_kv_cache_hybrid::update(llama_context & ctx) {
    bool updated = false;
    for (const auto & cache : m_children) {
        updated = cache->update(ctx) || updated;
    }
    return updated;
}

void llama_kv_cache_hybrid::defrag_sched(float thold) {
    for (const auto & cache : m_children) {
        cache->defrag_sched(thold);
    }
}

void llama_kv_cache_hybrid::set_full() {
    for (const auto & cache : m_children) {
        cache->set_full();
    }
}

bool llama_kv_cache_hybrid::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
    for (const auto & cache : m_children) {
        if (!cache->can_seq_rm(seq_id, p0, p1)) {
            return false;
        }
    }
    return true;
}
llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
    // If any of the caches are recurrent, require equal split
    return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all);
}

llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
    if (embd_pooled) {
        // Pooled embeddings cannot be split across ubatches (yet)
        return sbatch.split_seq(n_ubatch);
    }
    if (m_has_recurrent) {
        return sbatch.split_equal(n_ubatch);
    }
    return sbatch.split_simple(n_ubatch);
}
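To make the precedence explicit: pooled embeddings force a sequence split, a recurrent child forces an equal split, and simple splitting is the default. A distilled, self-contained sketch of that decision (the enum and function name are illustrative, not from this commit):

#include <cstdio>

// Illustrative only: the split-selection precedence ubatch_next() applies
enum class split_mode { seq, equal, simple };

static split_mode pick_split(bool embd_pooled, bool has_recurrent) {
    if (embd_pooled) {
        return split_mode::seq;   // pooled embeddings: keep each sequence whole
    }
    if (has_recurrent) {
        return split_mode::equal; // recurrent children require equal splits
    }
    return split_mode::simple;    // attention-only: simple splitting suffices
}

int main() {
    // pooled embeddings take precedence even when a recurrent child exists
    printf("%d\n", pick_split(true, true) == split_mode::seq); // prints 1
    return 0;
}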
bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) {
    bool found = true;
    for (const auto & cache : m_children) {
        found = cache->find_slot(batch) && found;
    }
    return found;
}

int32_t llama_kv_cache_hybrid::get_n_tokens() const {
    // The number of tokens should be the same across all child caches
    int32_t n_tokens = -1;
    for (const auto & cache : m_children) {
        const auto cache_n_tokens = cache->get_n_tokens();
        GGML_ASSERT(n_tokens == -1 || cache_n_tokens == n_tokens);
        n_tokens = cache_n_tokens;
    }
    return n_tokens;
}

int32_t llama_kv_cache_hybrid::get_used_cells() const {
    // TODO: Is this correct?
    // Return the largest number of used cells
    int32_t used_cells = -1;
    for (const auto & cache : m_children) {
        used_cells = std::max(used_cells, cache->get_used_cells());
    }
    return used_cells;
}

llama_pos llama_kv_cache_hybrid::get_pos_max() const {
    llama_pos pos_max = -1;
    for (const auto & cache : m_children) {
        pos_max = std::max(pos_max, cache->get_pos_max());
    }
    return pos_max;
}

bool llama_kv_cache_hybrid::get_can_shift() const {
    // TODO: Is this correct?
    // If any children can shift, return true
    for (const auto & cache : m_children) {
        if (cache->get_can_shift()) {
            return true;
        }
    }
    return false;
}

void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
    // Write each cache state in order. Note that order is guaranteed at
    // initialization by sorting the children by their lowest layer ID
    for (const auto & cache : m_children) {
        cache->state_write(io, seq_id);
    }
}

void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
    // Read each cache state in order. Note that order is guaranteed at
    // initialization by sorting the children by their lowest layer ID
    for (const auto & cache : m_children) {
        cache->state_read(io, seq_id);
    }
}

src/llama-kv-cache.h

Lines changed: 98 additions & 0 deletions
@@ -528,3 +528,101 @@ class llama_kv_cache_recurrent : public llama_kv_cache {

    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};

//
// llama_kv_cache_hybrid
//

// utilizes multiple different cache types, with each layer assigned to exactly
// one cache. This is typically used for hybrid attention / recurrent caching.
class llama_kv_cache_hybrid : public llama_kv_cache {
public:

    struct child_cache {
        std::unique_ptr<llama_kv_cache> child;
        std::vector<size_t>             layer_ids;

        child_cache(std::unique_ptr<llama_kv_cache> child_, std::vector<size_t> layer_ids_)
            : child(std::move(child_)), layer_ids(std::move(layer_ids_)) {}
    };

    llama_kv_cache_hybrid(
        const llama_hparams & hparams,
        std::vector<child_cache> children);

    virtual ~llama_kv_cache_hybrid() = default;
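    // For illustration, a hypothetical construction for a 4-layer model whose
    // even layers are recurrent and odd layers use attention; the
    // make_recurrent_cache() and make_attn_cache() factories are stand-ins,
    // not APIs from this commit:
    //
    //     std::vector<llama_kv_cache_hybrid::child_cache> children;
    //     children.emplace_back(make_recurrent_cache(), std::vector<size_t>{0, 2});
    //     children.emplace_back(make_attn_cache(),      std::vector<size_t>{1, 3});
    //     llama_kv_cache_hybrid kv(hparams, std::move(children));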
    // getter for a specific child cache type
    // NOTE: This will assert if there are multiple children of the given type
    template<typename child_t>
    const child_t * get_child_cache() const {
        const child_t * child = nullptr;
        for (const auto & child_cache : m_children) {
            const child_t * child_cast = dynamic_cast<const child_t *>(child_cache.get());
            if (child_cast) {
                GGML_ASSERT(!child);
                child = child_cast;
            }
        }
        return child;
    }
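    // Example (hypothetical kv object from the construction sketch above);
    // nullptr means no child of the requested type exists:
    //
    //     const auto * recurrent = kv.get_child_cache<llama_kv_cache_recurrent>();
    //     if (recurrent != nullptr) {
    //         // recurrent-specific handling
    //     }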
    //
    // llama_memory_i
    //

    void clear() override;

    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    //
    // llama_kv_cache
    //

    void restore() override;
    void commit() override;

    bool update(llama_context & ctx) override;

    void defrag_sched(float thold) override;

    void set_full() override;

    bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;

    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;

    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;

    // updates the cache head
    // Note: On success, it's important that cache.head points
    // to the first cell of the slot.
    bool find_slot(const llama_ubatch & batch) override;

    int32_t get_n_tokens()   const override;
    int32_t get_used_cells() const override;

    // TODO: better data structures to reduce the cost of this operation
    llama_pos get_pos_max() const override;

    bool get_can_shift() const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;

private:

    const llama_hparams & m_hparams;
    const std::vector<std::unique_ptr<llama_kv_cache>> m_children; // sorted by lowest layer ID for deterministic state IO
    const bool m_has_recurrent;
};
