@@ -399,7 +399,7 @@ std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ub
399399 break ;
400400 }
401401
402- // remeber the position that we found
402+ // remember the position that we found
403403 res.push_back (head_new);
404404
405405 // store the old state of the cells in the recovery stack
@@ -3037,11 +3037,93 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
30373037//
30383038// llama_kv_cache_hybrid
30393039//
3040- llama_kv_cache_hybrid::llama_kv_cache_hybrid (
3041- const llama_hparams & hparams,
3042- std::vector<child_cache> children) :
3043- m_hparams(hparams),
3044- m_children(
3040+
3041+
3042+ class llama_kv_cache_hybrid_decode_state_t : public llama_memory_decode_state_i {
3043+ public:
3044+ explicit llama_kv_cache_hybrid_decode_state_t (
3045+ std::vector<llama_memory_decode_state_ptr> decode_states) :
3046+ status([](const std::vector<llama_memory_decode_state_ptr> & decode_states) -> llama_memory_status {
3047+ for (const auto & decode_state : decode_states) {
3048+ if (!decode_state) {
3049+ return LLAMA_MEMORY_STATUS_FAILED_PREPARE;
3050+ }
3051+ const auto & status = decode_state->get_status ();
3052+ if (status != LLAMA_MEMORY_STATUS_SUCCESS) {
3053+ return status;
3054+ }
3055+ }
3056+ return LLAMA_MEMORY_STATUS_SUCCESS;
3057+ }(decode_states)),
3058+ decode_states (std::move(decode_states)) {
3059+
3060+ // make sure at least one decode state
3061+ assert (!decode_states.empty ());
3062+
3063+ // make sure all out_ids match across states
3064+ // TODO: This could be expensive, so maybe don't do it?
3065+ const auto & out_ids = decode_states[0 ]->out_ids ();
3066+ for (size_t i = 1 ; i < decode_states.size (); ++i) {
3067+ const auto & out_ids_i = decode_states[i]->out_ids ();
3068+ assert (out_ids.size () == out_ids_i.size ());
3069+ for (size_t j = 0 ; j < out_ids.size (); ++j) {
3070+ assert (out_ids[j] == out_ids_i[j]);
3071+ }
3072+ }
3073+ }
3074+
3075+ ~llama_kv_cache_hybrid_decode_state_t () = default ;
3076+
3077+ llama_ubatch * next () override {
3078+ assert (status == LLAMA_MEMORY_STATUS_SUCCESS);
3079+
3080+ // hit next on each child
3081+ std::vector<llama_ubatch *> next_ubatches;
3082+ for (const auto & decode_state : decode_states) {
3083+ next_ubatches.push_back (decode_state->next ());
3084+ }
3085+
3086+ // make sure they all match
3087+ // TODO: unnecessary safety?
3088+ llama_ubatch * res = next_ubatches[0 ];
3089+ assert (res);
3090+ for (size_t i = 1 ; i < next_ubatches.size (); ++i) {
3091+ llama_ubatch * ubatch_i = next_ubatches[i];
3092+ assert (ubatch_i);
3093+ assert (ubatch_i->n_tokens == res->n_tokens );
3094+ assert (ubatch_i->n_seq_tokens == res->n_seq_tokens );
3095+ assert (ubatch_i->n_seqs == res->n_seqs );
3096+ for (size_t j = 0 ; j < res->n_tokens ; ++j) {
3097+ assert (ubatch_i->token [j] == res->token [j]);
3098+ assert (ubatch_i->pos [j] == res->pos [j]);
3099+ assert (ubatch_i->output [j] == res->output [j]);
3100+ }
3101+ for (size_t j = 0 ; j < res->n_seqs ; ++j) {
3102+ assert (ubatch_i->n_seq_id [j] == res->n_seq_id [j]);
3103+ }
3104+ }
3105+
3106+ // return the first ubatch since they all match
3107+ return res;
3108+ }
3109+
3110+ std::vector<int64_t > & out_ids () override {
3111+ assert (status == LLAMA_MEMORY_STATUS_SUCCESS);
3112+
3113+ return decode_states[0 ]->out_ids ();
3114+ }
3115+
3116+ llama_memory_status get_status () const override {
3117+ return status;
3118+ }
3119+
3120+ private:
3121+ const llama_memory_status status;
3122+ std::vector<llama_memory_decode_state_ptr> decode_states;
3123+ };
3124+
3125+ llama_kv_cache_hybrid::llama_kv_cache_hybrid (std::vector<child_cache> children_) :
3126+ children(
30453127 [](std::vector<child_cache>& caches) -> std::set<std::unique_ptr<llama_kv_cache>> {
30463128 // Sort the caches by the lowest layer ID so the order is repeatable
30473129 for (auto & cache : caches) {
@@ -3056,26 +3138,26 @@ llama_kv_cache_hybrid::llama_kv_cache_hybrid(
30563138 unique_caches.emplace (cache.child .release ());
30573139 }
30583140 return unique_caches;
3059- }(children )
3141+ }(children_ )
30603142 ),
3061- m_has_recurrent (
3143+ has_recurrent (
30623144 [](const std::set<std::unique_ptr<llama_kv_cache>> & caches) -> bool {
30633145 for (const auto & cache : caches) {
30643146 if (dynamic_cast <llama_kv_cache_recurrent *>(cache.get ())) {
30653147 return true ;
30663148 }
30673149 }
30683150 return false ;
3069- }(m_children )
3151+ }(children )
30703152 )
30713153{
30723154 // Ensure at least one child
3073- GGML_ASSERT (m_children .size () > 0 );
3155+ GGML_ASSERT (children .size () > 0 );
30743156
30753157 // Ensure layers are not overlapping and are concurrent
30763158 std::set<size_t > seen_layers;
30773159 size_t max_layer = 0 ;
3078- for (const auto & cache : children ) {
3160+ for (const auto & cache : children_ ) {
30793161 for (const auto & layer_id : cache.layer_ids ) {
30803162 GGML_ASSERT (seen_layers.find (layer_id) == seen_layers.end ());
30813163 seen_layers.insert (layer_id);
@@ -3089,7 +3171,7 @@ llama_kv_cache_hybrid::llama_kv_cache_hybrid(
30893171}
30903172
30913173void llama_kv_cache_hybrid::clear () {
3092- for (const auto & cache : m_children ) {
3174+ for (const auto & cache : children ) {
30933175 cache->clear ();
30943176 }
30953177}
@@ -3102,40 +3184,40 @@ bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
31023184 }
31033185
31043186 // Do the removal from each child which should never fail
3105- for (const auto & cache : m_children ) {
3187+ for (const auto & cache : children ) {
31063188 const bool failed = cache->seq_rm (seq_id, p0, p1);
31073189 GGML_ASSERT (!failed);
31083190 }
31093191 return true ;
31103192}
31113193
31123194void llama_kv_cache_hybrid::seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
3113- for (const auto & cache : m_children ) {
3195+ for (const auto & cache : children ) {
31143196 cache->seq_cp (seq_id_src, seq_id_dst, p0, p1);
31153197 }
31163198}
31173199
31183200void llama_kv_cache_hybrid::seq_keep (llama_seq_id seq_id) {
3119- for (const auto & cache : m_children ) {
3201+ for (const auto & cache : children ) {
31203202 cache->seq_keep (seq_id);
31213203 }
31223204}
31233205
31243206void llama_kv_cache_hybrid::seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
3125- for (const auto & cache : m_children ) {
3207+ for (const auto & cache : children ) {
31263208 cache->seq_add (seq_id, p0, p1, delta);
31273209 }
31283210}
31293211
31303212void llama_kv_cache_hybrid::seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
3131- for (const auto & cache : m_children ) {
3213+ for (const auto & cache : children ) {
31323214 cache->seq_div (seq_id, p0, p1, d);
31333215 }
31343216}
31353217
31363218llama_pos llama_kv_cache_hybrid::seq_pos_min (llama_seq_id seq_id) const {
31373219 llama_pos min_pos = -1 ;
3138- for (const auto & cache : m_children ) {
3220+ for (const auto & cache : children ) {
31393221 const auto child_min_pos = cache->seq_pos_min (seq_id);
31403222 min_pos = min_pos == -1 ? child_min_pos : std::min (min_pos, child_min_pos);
31413223 }
@@ -3144,81 +3226,67 @@ llama_pos llama_kv_cache_hybrid::seq_pos_min(llama_seq_id seq_id) const {
31443226
31453227llama_pos llama_kv_cache_hybrid::seq_pos_max (llama_seq_id seq_id) const {
31463228 llama_pos max_pos = 0 ;
3147- for (const auto & cache : m_children ) {
3229+ for (const auto & cache : children ) {
31483230 max_pos = std::max (max_pos, cache->seq_pos_max (seq_id));
31493231 }
31503232 return max_pos;
31513233}
31523234
3153- void llama_kv_cache_hybrid::restore () {
3154- for (const auto & cache : m_children) {
3155- cache->restore ();
3156- }
3157- }
3235+ llama_memory_decode_state_ptr llama_kv_cache_hybrid::init (
3236+ const llama_batch & batch,
3237+ uint32_t n_ubatch,
3238+ bool embd_pooled,
3239+ bool logits_all,
3240+ bool split_equal) {
31583241
3159- void llama_kv_cache_hybrid::commit () {
3160- for (const auto & cache : m_children) {
3161- cache->commit ();
3242+ // recurrent children require equal splits
3243+ // TODO: just ignore this if set incorrectly?
3244+ assert (!has_recurrent || split_equal);
3245+
3246+ // init all children and capture their decode states
3247+ std::vector<llama_memory_decode_state_ptr> decode_states;
3248+ for (const auto & child : children) {
3249+ decode_states.emplace_back (
3250+ child->init (batch, n_ubatch, embd_pooled, logits_all, split_equal));
31623251 }
3252+
3253+ // return the hybrid decode state
3254+ return std::make_unique<llama_kv_cache_hybrid_decode_state_t >(std::move (decode_states));
31633255}
31643256
31653257bool llama_kv_cache_hybrid::update (llama_context & ctx) {
31663258 bool updated = false ;
3167- for (const auto & cache : m_children ) {
3259+ for (const auto & cache : children ) {
31683260 updated = cache->update (ctx) || updated;
31693261 }
31703262 return updated;
31713263}
31723264
31733265void llama_kv_cache_hybrid::defrag_sched (float thold) {
3174- for (const auto & cache : m_children ) {
3266+ for (const auto & cache : children ) {
31753267 cache->defrag_sched (thold);
31763268 }
31773269}
31783270
31793271void llama_kv_cache_hybrid::set_full () {
3180- for (const auto & cache : m_children ) {
3272+ for (const auto & cache : children ) {
31813273 cache->set_full ();
31823274 }
31833275}
31843276
31853277bool llama_kv_cache_hybrid::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
3186- for (const auto & cache : m_children ) {
3278+ for (const auto & cache : children ) {
31873279 if (!cache->can_seq_rm (seq_id, p0, p1)) {
31883280 return false ;
31893281 }
31903282 }
31913283 return true ;
31923284}
31933285
3194- llama_sbatch llama_kv_cache_hybrid::sbatch_init (const llama_batch & batch, bool logits_all) {
3195- // If any of the caches are recurrent, require equal split
3196- return llama_sbatch (batch, m_hparams.n_embd , !m_has_recurrent, logits_all);
3197- }
3198-
3199- llama_ubatch llama_kv_cache_hybrid::ubatch_next (llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
3200- if (embd_pooled) {
3201- // Pooled embeddings cannot be split across ubatches (yet)
3202- return sbatch.split_seq (n_ubatch);
3203- }
3204- if (m_has_recurrent) {
3205- return sbatch.split_equal (n_ubatch);
3206- }
3207- return sbatch.split_simple (n_ubatch);
3208- }
3209-
3210- bool llama_kv_cache_hybrid::find_slot (const llama_ubatch & batch) {
3211- bool found = true ;
3212- for (const auto & cache : m_children) {
3213- found = cache->find_slot (batch) && found;
3214- }
3215- return found;
3216- }
3217-
32183286bool llama_kv_cache_hybrid::get_can_shift () const {
32193287 // TODO: Is this correct?
32203288 // If any children can shift, return true
3221- for (const auto & cache : m_children ) {
3289+ for (const auto & cache : children ) {
32223290 if (cache->get_can_shift ()) {
32233291 return true ;
32243292 }
@@ -3229,15 +3297,15 @@ bool llama_kv_cache_hybrid::get_can_shift() const {
32293297void llama_kv_cache_hybrid::state_write (llama_io_write_i & io, llama_seq_id seq_id) const {
32303298 // Write each cache state in order. Note that order is guaranteed at
32313299 // initialization by using an ordered set sorted by lowest layer ID
3232- for (const auto & cache : m_children ) {
3300+ for (const auto & cache : children ) {
32333301 cache->state_write (io, seq_id);
32343302 }
32353303}
32363304
32373305void llama_kv_cache_hybrid::state_read (llama_io_read_i & io, llama_seq_id seq_id) {
32383306 // Read each cache state in order. Note that order is guaranteed at
32393307 // initialization by using an ordered set sorted by lowest layer ID
3240- for (const auto & cache : m_children ) {
3308+ for (const auto & cache : children ) {
32413309 cache->state_read (io, seq_id);
32423310 }
32433311}
0 commit comments